diff --git a/src/main/java/net/schmizz/sshj/transport/kex/Curve25519DH.java b/src/main/java/net/schmizz/sshj/transport/kex/Curve25519DH.java index 860af0c5..d156cb32 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/Curve25519DH.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/Curve25519DH.java @@ -34,13 +34,14 @@ public class Curve25519DH extends DHBase { private static final String ALGORITHM = "X25519"; - private static final int ENCODED_ALGORITHM_ID_KEY_LENGTH = 44; + private static final int KEY_LENGTH = 32; - private static final int ALGORITHM_ID_LENGTH = 12; + private int encodedKeyLength; - private static final int KEY_LENGTH = 32; + private int algorithmIdLength; - private final byte[] algorithmId = new byte[ALGORITHM_ID_LENGTH]; + // Algorithm Identifier is set on Key Agreement Initialization + private byte[] algorithmId = new byte[KEY_LENGTH]; public Curve25519DH() { super(ALGORITHM, ALGORITHM); @@ -81,23 +82,24 @@ public void init(final AlgorithmParameterSpec params, final Factory rand private void setPublicKey(final PublicKey publicKey) { final byte[] encoded = publicKey.getEncoded(); + // Set key and algorithm identifier lengths based on initialized Public Key + encodedKeyLength = encoded.length; + algorithmIdLength = encodedKeyLength - KEY_LENGTH; + algorithmId = new byte[algorithmIdLength]; + // Encoded public key consists of the algorithm identifier and public key - if (encoded.length == ENCODED_ALGORITHM_ID_KEY_LENGTH) { - final byte[] publicKeyEncoded = new byte[KEY_LENGTH]; - System.arraycopy(encoded, ALGORITHM_ID_LENGTH, publicKeyEncoded, 0, KEY_LENGTH); - setE(publicKeyEncoded); - - // Save Algorithm Identifier byte array - System.arraycopy(encoded, 0, algorithmId, 0, ALGORITHM_ID_LENGTH); - } else { - throw new IllegalArgumentException(String.format("X25519 unsupported public key length [%d]", encoded.length)); - } + final byte[] publicKeyEncoded = new byte[KEY_LENGTH]; + System.arraycopy(encoded, algorithmIdLength, publicKeyEncoded, 0, KEY_LENGTH); + setE(publicKeyEncoded); + + // Save Algorithm Identifier byte array + System.arraycopy(encoded, 0, algorithmId, 0, algorithmIdLength); } private KeySpec getPeerPublicKeySpec(final byte[] peerPublicKey) { - final byte[] encodedKeySpec = new byte[ENCODED_ALGORITHM_ID_KEY_LENGTH]; - System.arraycopy(algorithmId, 0, encodedKeySpec, 0, ALGORITHM_ID_LENGTH); - System.arraycopy(peerPublicKey, 0, encodedKeySpec, ALGORITHM_ID_LENGTH, KEY_LENGTH); + final byte[] encodedKeySpec = new byte[encodedKeyLength]; + System.arraycopy(algorithmId, 0, encodedKeySpec, 0, algorithmIdLength); + System.arraycopy(peerPublicKey, 0, encodedKeySpec, algorithmIdLength, KEY_LENGTH); return new X509EncodedKeySpec(encodedKeySpec); } } diff --git a/src/test/java/net/schmizz/sshj/transport/kex/Curve25519DHTest.java b/src/test/java/net/schmizz/sshj/transport/kex/Curve25519DHTest.java index 3e2e2b3b..7c9cbd47 100644 --- a/src/test/java/net/schmizz/sshj/transport/kex/Curve25519DHTest.java +++ b/src/test/java/net/schmizz/sshj/transport/kex/Curve25519DHTest.java @@ -15,17 +15,24 @@ */ package net.schmizz.sshj.transport.kex; +import net.schmizz.sshj.common.SecurityUtils; import net.schmizz.sshj.transport.random.JCERandom; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.math.BigInteger; import java.security.GeneralSecurityException; +import java.security.KeyPairGenerator; +import java.security.Provider; +import java.security.Security; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; public class Curve25519DHTest { + private static final String ALGORITHM_FILTER = "KeyPairGenerator.X25519"; + private static final int KEY_LENGTH = 32; private static final byte[] PEER_PUBLIC_KEY = { @@ -35,8 +42,16 @@ public class Curve25519DHTest { 1, 2, 3, 4, 5, 6, 7, 8 }; + @BeforeEach + public void clearSecurityProvider() { + SecurityUtils.setSecurityProvider(null); + } + @Test public void testInitPublicKeyLength() throws GeneralSecurityException { + final boolean bouncyCastleRegistrationRequired = isAlgorithmUnsupported(); + SecurityUtils.setRegisterBouncyCastle(bouncyCastleRegistrationRequired); + final Curve25519DH dh = new Curve25519DH(); dh.init(null, new JCERandom.Factory()); @@ -48,6 +63,8 @@ public void testInitPublicKeyLength() throws GeneralSecurityException { @Test public void testInitComputeSharedSecretKey() throws GeneralSecurityException { + SecurityUtils.setRegisterBouncyCastle(true); + final Curve25519DH dh = new Curve25519DH(); dh.init(null, new JCERandom.Factory()); @@ -57,4 +74,9 @@ public void testInitComputeSharedSecretKey() throws GeneralSecurityException { assertNotNull(sharedSecretKey); assertEquals(BigInteger.ONE.signum(), sharedSecretKey.signum()); } + + private boolean isAlgorithmUnsupported() { + final Provider[] providers = Security.getProviders(ALGORITHM_FILTER); + return providers == null || providers.length == 0; + } }