Skip to content

Commit

Permalink
Do not send msg3 if key confirm fails in PASE (#7502)
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-apple authored Jun 10, 2021
1 parent 1f1b715 commit e0c7b6d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/protocols/secure_channel/PASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,19 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(const System::PacketBufferHandle
VerifyOrExit(CanCastTo<uint16_t>(verifier_len_raw), err = CHIP_ERROR_INVALID_MESSAGE_LENGTH);
verifier_len = static_cast<uint16_t>(verifier_len_raw);

{
const uint8_t * hash = &buf[kMAX_Point_Length];
err = mSpake2p.KeyConfirm(hash, kMAX_Hash_Length);
if (err != CHIP_NO_ERROR)
{
spake2pErr = Spake2pErrorType::kInvalidKeyConfirmation;
SuccessOrExit(err);
}

err = mSpake2p.GetKeys(mKe, &mKeLen);
SuccessOrExit(err);
}

{
Encoding::PacketBufferWriter bbuf(System::PacketBufferHandle::New(verifier_len));
VerifyOrExit(!bbuf.IsNull(), err = CHIP_SYSTEM_ERROR_NO_MEMORY);
Expand All @@ -630,19 +643,6 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(const System::PacketBufferHandle

ChipLogDetail(SecureChannel, "Sent spake2p msg3");

{
const uint8_t * hash = &buf[kMAX_Point_Length];
err = mSpake2p.KeyConfirm(hash, kMAX_Hash_Length);
if (err != CHIP_NO_ERROR)
{
spake2pErr = Spake2pErrorType::kInvalidKeyConfirmation;
SuccessOrExit(err);
}

err = mSpake2p.GetKeys(mKe, &mKeLen);
SuccessOrExit(err);
}

mPairingComplete = true;

// Close the exchange, as no additional messages are expected from the peer
Expand Down
50 changes: 50 additions & 0 deletions src/protocols/secure_channel/tests/TestPASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,55 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo
NL_TEST_ASSERT(inSuite, gLoopback.mNumMessagesToDrop == 0);
}

void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext)
{
TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;

TestSecurePairingDelegate delegateAccessory;
PASESession pairingAccessory;

gLoopback.Reset();
gLoopback.mSentMessageCount = 0;

NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);

TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
ExchangeContext * contextCommissioner = ctx.NewExchangeToLocal(&pairingCommissioner);

pairingCommissioner.MessageDispatch().SetPeerAddress(PeerAddress(Type::kUdp));
pairingAccessory.MessageDispatch().SetPeerAddress(PeerAddress(Type::kUdp));

ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
ReliableMessageContext * rc = contextCommissioner->GetReliableMessageContext();
NL_TEST_ASSERT(inSuite, rm != nullptr);
NL_TEST_ASSERT(inSuite, rc != nullptr);

rc->SetConfig({
1, // CHIP_CONFIG_MRP_DEFAULT_INITIAL_RETRY_INTERVAL
1, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
});
gLoopback.mContext = &ctx;

NL_TEST_ASSERT(inSuite,
ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(
Protocols::SecureChannel::MsgType::PBKDFParamRequest, &pairingAccessory) == CHIP_NO_ERROR);

NL_TEST_ASSERT(inSuite,
pairingAccessory.WaitForPairing(1234, 500, (const uint8_t *) "saltSALT", 8, 0, &delegateAccessory) ==
CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite,
pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 4321, 0, contextCommissioner,
&delegateCommissioner) == CHIP_NO_ERROR);

gLoopback.mContext = nullptr;
NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 0);
NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingErrors == 1);
NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 0);
NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingErrors == 1);
}

void SecurePairingDeserialize(nlTestSuite * inSuite, void * inContext, PASESession & pairingCommissioner,
PASESession & deserialized)
{
Expand Down Expand Up @@ -323,6 +372,7 @@ static const nlTest sTests[] =
NL_TEST_DEF("Start", SecurePairingStartTest),
NL_TEST_DEF("Handshake", SecurePairingHandshakeTest),
NL_TEST_DEF("Handshake with packet loss", SecurePairingHandshakeWithPacketLossTest),
NL_TEST_DEF("Failed Handshake", SecurePairingFailedHandshake),
NL_TEST_DEF("Serialize", SecurePairingSerializeTest),

NL_TEST_SENTINEL()
Expand Down

0 comments on commit e0c7b6d

Please sign in to comment.