diff --git a/integration_tests/ssh3_test.go b/integration_tests/ssh3_test.go index 942a32a..604ab8a 100644 --- a/integration_tests/ssh3_test.go +++ b/integration_tests/ssh3_test.go @@ -20,7 +20,8 @@ var ssh3ServerPath string const DEFAULT_URL_PATH = "/ssh3-tests" var serverCommand *exec.Cmd var serverSession *Session -var privKeyPath string +var rsaPrivKeyPath string +var ed25519PrivKeyPath string var attackerPrivKeyPath string var username string // must exist on the machine to successfully run the tests @@ -65,10 +66,11 @@ var _ = BeforeSuite(func() { serverSession, err = Start(serverCommand, GinkgoWriter, GinkgoWriter) Expect(err).ToNot(HaveOccurred()) - privKeyPath = os.Getenv("TESTUSER_PRIVKEY") + rsaPrivKeyPath = os.Getenv("TESTUSER_PRIVKEY") + ed25519PrivKeyPath = os.Getenv("TESTUSER_ED25519_PRIVKEY") attackerPrivKeyPath = os.Getenv("ATTACKER_PRIVKEY") username = os.Getenv("TESTUSER_USERNAME") - Expect(fileExists(privKeyPath)).To(BeTrue()) + Expect(fileExists(rsaPrivKeyPath)).To(BeTrue()) Expect(fileExists(attackerPrivKeyPath)).To(BeTrue()) } }) @@ -113,8 +115,17 @@ var _ = Describe("Testing the ssh3 cli", func() { } Context("Client behaviour", func() { - It("Should connect using privkey", func() { - clientArgs = append(getClientArgs(privKeyPath), "echo", "Hello, World!") + It("Should connect using an RSA privkey", func() { + clientArgs = append(getClientArgs(rsaPrivKeyPath), "echo", "Hello, World!") + command := exec.Command(ssh3Path, clientArgs...) + session, err := Start(command, GinkgoWriter, GinkgoWriter) + Expect(err).ToNot(HaveOccurred()) + Eventually(session).Should(Exit(0)) + Eventually(session).Should(Say("Hello, World!\n")) + }) + + It("Should connect using an ed25519 privkey", func() { + clientArgs = append(getClientArgs(ed25519PrivKeyPath), "echo", "Hello, World!") command := exec.Command(ssh3Path, clientArgs...) session, err := Start(command, GinkgoWriter, GinkgoWriter) Expect(err).ToNot(HaveOccurred()) @@ -123,10 +134,10 @@ var _ = Describe("Testing the ssh3 cli", func() { }) It("Should return the correct exit status", func() { - clientArgs0 := append(getClientArgs(privKeyPath), "exit", "0") - clientArgs1 := append(getClientArgs(privKeyPath), "exit", "1") - clientArgs255 := append(getClientArgs(privKeyPath), "exit", "255") - clientArgsMinus1 := append(getClientArgs(privKeyPath), "exit", "-1") + clientArgs0 := append(getClientArgs(rsaPrivKeyPath), "exit", "0") + clientArgs1 := append(getClientArgs(rsaPrivKeyPath), "exit", "1") + clientArgs255 := append(getClientArgs(rsaPrivKeyPath), "exit", "255") + clientArgsMinus1 := append(getClientArgs(rsaPrivKeyPath), "exit", "-1") command0 := exec.Command(ssh3Path, clientArgs0...) session, err := Start(command0, GinkgoWriter, GinkgoWriter) @@ -192,7 +203,7 @@ var _ = Describe("Testing the ssh3 cli", func() { Eventually(serverStarted).Should(Receive()) // Execute the client with TCP port forwarding - clientArgs := getClientArgs(privKeyPath, "-forward-tcp", fmt.Sprintf("%d/%s@%d", localPort, remoteAddr.IP, remoteAddr.Port)) + clientArgs := getClientArgs(rsaPrivKeyPath, "-forward-tcp", fmt.Sprintf("%d/%s@%d", localPort, remoteAddr.IP, remoteAddr.Port)) command := exec.Command(ssh3Path, clientArgs...) session, err := Start(command, GinkgoWriter, GinkgoWriter) Expect(err).ToNot(HaveOccurred()) @@ -301,7 +312,7 @@ var _ = Describe("Testing the ssh3 cli", func() { Eventually(serverStarted).Should(Receive()) // Execute the client with UDP port forwarding - clientArgs := getClientArgs(privKeyPath, "-forward-udp", fmt.Sprintf("%d/%s@%d", localPort, remoteAddr.IP, remoteAddr.Port)) + clientArgs := getClientArgs(rsaPrivKeyPath, "-forward-udp", fmt.Sprintf("%d/%s@%d", localPort, remoteAddr.IP, remoteAddr.Port)) command := exec.Command(ssh3Path, clientArgs...) session, err := Start(command, GinkgoWriter, GinkgoWriter) Expect(err).ToNot(HaveOccurred()) diff --git a/util/util.go b/util/util.go index 86f5f93..8d9c2a4 100644 --- a/util/util.go +++ b/util/util.go @@ -22,6 +22,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "golang.org/x/crypto/cryptobyte" cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1" ) @@ -164,13 +165,31 @@ func (q *DatagramsQueue) WaitNext(ctx context.Context) ([]byte, error) { } func JWTSigningMethodFromCryptoPubkey(pubkey crypto.PublicKey) (jwt.SigningMethod, error) { + log.Debug().Type("SigningMethodType", pubkey).Msg("fetching singing method from crypto.PublicKey") + switch pubkey.(type) { case *rsa.PublicKey: + log. + Trace(). + Type("SigningMethodType", pubkey). + Str("FoundSigningMethod", "RSA"). + Msg("found public key type") return jwt.SigningMethodRS256, nil - case *ed25519.PublicKey: + case ed25519.PublicKey: + log. + Trace(). + Type("SigningMethodType", pubkey). + Str("FoundSigningMethod", "ED25519"). + Msg("found public key type") return jwt.SigningMethodEdDSA, nil + default: + log. + Error(). + Type("SigningMethodType", pubkey). + Str("FoundSigningMethod", "unknown"). + Msg("did not find public key type") + return nil, UnknownSSHPubkeyType{pubkey: pubkey} } - return nil, UnknownSSHPubkeyType{pubkey: pubkey} } func Sha256Fingerprint(in []byte) string {