diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 433e027bdeb14a..881c92a4c81d1b 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -616,6 +616,19 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(const System::PacketBufferHandle VerifyOrExit(CanCastTo(verifier_len_raw), err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); verifier_len = static_cast(verifier_len_raw); + { + const uint8_t * hash = &buf[kMAX_Point_Length]; + err = mSpake2p.KeyConfirm(hash, kMAX_Hash_Length); + if (err != CHIP_NO_ERROR) + { + spake2pErr = Spake2pErrorType::kInvalidKeyConfirmation; + SuccessOrExit(err); + } + + err = mSpake2p.GetKeys(mKe, &mKeLen); + SuccessOrExit(err); + } + { Encoding::PacketBufferWriter bbuf(System::PacketBufferHandle::New(verifier_len)); VerifyOrExit(!bbuf.IsNull(), err = CHIP_SYSTEM_ERROR_NO_MEMORY); @@ -630,19 +643,6 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(const System::PacketBufferHandle ChipLogDetail(SecureChannel, "Sent spake2p msg3"); - { - const uint8_t * hash = &buf[kMAX_Point_Length]; - err = mSpake2p.KeyConfirm(hash, kMAX_Hash_Length); - if (err != CHIP_NO_ERROR) - { - spake2pErr = Spake2pErrorType::kInvalidKeyConfirmation; - SuccessOrExit(err); - } - - err = mSpake2p.GetKeys(mKe, &mKeLen); - SuccessOrExit(err); - } - mPairingComplete = true; // Close the exchange, as no additional messages are expected from the peer diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 48dfc4abfd5204..2176b7c328f98f 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -251,6 +251,55 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo NL_TEST_ASSERT(inSuite, gLoopback.mNumMessagesToDrop == 0); } +void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) +{ + TestSecurePairingDelegate delegateCommissioner; + PASESession pairingCommissioner; + + TestSecurePairingDelegate delegateAccessory; + PASESession pairingAccessory; + + gLoopback.Reset(); + gLoopback.mSentMessageCount = 0; + + NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + + TestContext & ctx = *reinterpret_cast(inContext); + ExchangeContext * contextCommissioner = ctx.NewExchangeToLocal(&pairingCommissioner); + + pairingCommissioner.MessageDispatch().SetPeerAddress(PeerAddress(Type::kUdp)); + pairingAccessory.MessageDispatch().SetPeerAddress(PeerAddress(Type::kUdp)); + + ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + ReliableMessageContext * rc = contextCommissioner->GetReliableMessageContext(); + NL_TEST_ASSERT(inSuite, rm != nullptr); + NL_TEST_ASSERT(inSuite, rc != nullptr); + + rc->SetConfig({ + 1, // CHIP_CONFIG_MRP_DEFAULT_INITIAL_RETRY_INTERVAL + 1, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); + gLoopback.mContext = &ctx; + + NL_TEST_ASSERT(inSuite, + ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType( + Protocols::SecureChannel::MsgType::PBKDFParamRequest, &pairingAccessory) == CHIP_NO_ERROR); + + NL_TEST_ASSERT(inSuite, + pairingAccessory.WaitForPairing(1234, 500, (const uint8_t *) "saltSALT", 8, 0, &delegateAccessory) == + CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, + pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 4321, 0, contextCommissioner, + &delegateCommissioner) == CHIP_NO_ERROR); + + gLoopback.mContext = nullptr; + NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 0); + NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingErrors == 1); + NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 0); + NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingErrors == 1); +} + void SecurePairingDeserialize(nlTestSuite * inSuite, void * inContext, PASESession & pairingCommissioner, PASESession & deserialized) { @@ -323,6 +372,7 @@ static const nlTest sTests[] = NL_TEST_DEF("Start", SecurePairingStartTest), NL_TEST_DEF("Handshake", SecurePairingHandshakeTest), NL_TEST_DEF("Handshake with packet loss", SecurePairingHandshakeWithPacketLossTest), + NL_TEST_DEF("Failed Handshake", SecurePairingFailedHandshake), NL_TEST_DEF("Serialize", SecurePairingSerializeTest), NL_TEST_SENTINEL()