diff --git a/examples/common/chip-app-server/Server.cpp b/examples/common/chip-app-server/Server.cpp index d3637fe9eb2f59..ec61b767abe18a 100644 --- a/examples/common/chip-app-server/Server.cpp +++ b/examples/common/chip-app-server/Server.cpp @@ -62,9 +62,10 @@ class ServerCallback : public SecureSessionMgrDelegate { public: void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, - const Transport::PeerConnectionState * state, System::PacketBufferHandle buffer, + SecureSessionHandle session, System::PacketBufferHandle buffer, SecureSessionMgr * mgr) override { + auto state = mgr->GetPeerConnectionState(session); const size_t data_len = buffer->DataLength(); char src_addr[PeerAddress::kMaxToStringSize]; @@ -92,7 +93,7 @@ class ServerCallback : public SecureSessionMgrDelegate } } - void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) override + void OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) override { ChipLogProgress(AppServer, "Received a new connection."); } @@ -181,11 +182,10 @@ void InitServer(AppDelegate * delegate) SuccessOrExit(err); #endif + gSessions.SetDelegate(&gCallbacks); err = gSessions.NewPairing(peer, chip::kTestControllerNodeId, &gTestPairing); SuccessOrExit(err); - gSessions.SetDelegate(&gCallbacks); - exit: if (err != CHIP_NO_ERROR) { diff --git a/src/app/util/chip-message-send.cpp b/src/app/util/chip-message-send.cpp index 12aa752106021b..5f29ed65b04fa6 100644 --- a/src/app/util/chip-message-send.cpp +++ b/src/app/util/chip-message-send.cpp @@ -73,7 +73,8 @@ EmberStatus chipSendUnicast(NodeId destination, EmberApsFrame * apsFrame, uint16 memcpy(buffer->Start() + frameSize, message, messageLength); buffer->SetDataLength(dataLength); - CHIP_ERROR err = SessionManager().SendMessage(destination, std::move(buffer)); + // TODO: temprary create a handle from node id, will be fix in PR 3602 + CHIP_ERROR err = SessionManager().SendMessage({destination, Transport::kAnyKeyId}, std::move(buffer)); if (err != CHIP_NO_ERROR) { // FIXME: Figure out better translations between our error types? diff --git a/src/controller/CHIPDevice.cpp b/src/controller/CHIPDevice.cpp index 4e83761ed49dd6..25f3dcac6eba87 100644 --- a/src/controller/CHIPDevice.cpp +++ b/src/controller/CHIPDevice.cpp @@ -74,7 +74,7 @@ CHIP_ERROR Device::SendMessage(System::PacketBufferHandle buffer) resend = buffer.Retain(); } - err = mSessionManager->SendMessage(mDeviceId, std::move(buffer)); + err = mSessionManager->SendMessage(mSecureSession, std::move(buffer)); buffer = nullptr; ChipLogDetail(Controller, "SendMessage returned %d", err); @@ -87,7 +87,7 @@ CHIP_ERROR Device::SendMessage(System::PacketBufferHandle buffer) err = LoadSecureSessionParameters(ResetTransport::kYes); SuccessOrExit(err); - err = mSessionManager->SendMessage(mDeviceId, std::move(resend)); + err = mSessionManager->SendMessage(mSecureSession, std::move(resend)); ChipLogDetail(Controller, "Re-SendMessage returned %d", err); SuccessOrExit(err); } @@ -175,8 +175,20 @@ CHIP_ERROR Device::Deserialize(const SerializedDevice & input) return error; } +void Device::OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) +{ + mState = ConnectionState::SecureConnected; + mSecureSession = session; +} + +void Device::OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) +{ + mState = ConnectionState::NotConnected; + mSecureSession = SecureSessionHandle{}; +} + void Device::OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, - const Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf, + SecureSessionHandle session, System::PacketBufferHandle msgBuf, SecureSessionMgr * mgr) { if (mState == ConnectionState::SecureConnected) @@ -255,8 +267,6 @@ CHIP_ERROR Device::LoadSecureSessionParameters(ResetTransport resetNeeded) &pairingSession); SuccessOrExit(err); - mState = ConnectionState::SecureConnected; - exit: if (err != CHIP_NO_ERROR) diff --git a/src/controller/CHIPDevice.h b/src/controller/CHIPDevice.h index 7484608624e5d6..3b9494feae525e 100644 --- a/src/controller/CHIPDevice.h +++ b/src/controller/CHIPDevice.h @@ -156,6 +156,26 @@ class DLL_EXPORT Device **/ CHIP_ERROR Deserialize(const SerializedDevice & input); + /** + * @brief + * Called when a new pairing is being established + * + * @param session A handle to the secure session + * @param mgr A pointer to the SecureSessionMgr + */ + void OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr); + + /** + * @brief + * Called when a connection is closing. + * + * The receiver should release all resources associated with the connection. + * + * @param session A handle to the secure session + * @param mgr A pointer to the SecureSessionMgr + */ + void OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr); + /** * @brief * This function is called when a message is received from the corresponding CHIP @@ -164,12 +184,12 @@ class DLL_EXPORT Device * * @param[in] header Reference to common packet header of the received message * @param[in] payloadHeader Reference to payload header in the message - * @param[in] state Pointer to the peer connection state on which message is received + * @param[in] session A handle to the secure session * @param[in] msgBuf The message buffer * @param[in] mgr Pointer to secure session manager which received the message */ void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, - const Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf, SecureSessionMgr * mgr); + SecureSessionHandle session, System::PacketBufferHandle msgBuf, SecureSessionMgr * mgr); /** * @brief @@ -180,6 +200,8 @@ class DLL_EXPORT Device void SetActive(bool active) { mActive = active; } + bool IsSecureConnected() const { return IsActive() && mState == ConnectionState::SecureConnected; } + void Reset() { SetActive(false); @@ -191,6 +213,8 @@ class DLL_EXPORT Device NodeId GetDeviceId() const { return mDeviceId; } + bool MatchesSession(SecureSessionHandle session) const { return mSecureSession == session; } + void SetAddress(const Inet::IPAddress & deviceAddr) { mDeviceAddr = deviceAddr; } SecurePairingSessionSerializable & GetPairing() { return mPairing; } @@ -242,6 +266,8 @@ class DLL_EXPORT Device DeviceTransportMgr * mTransportMgr; + SecureSessionHandle mSecureSession = {}; + /* Track all outstanding response callbacks for this device. The callbacks are registered when a command is sent to the device, to get notified with the results. */ Callback::CallbackDeque mResponses; diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 3d55deb6b29250..9e1b466fab1ea2 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -334,24 +334,58 @@ CHIP_ERROR DeviceController::ServiceEventSignal() return err; } -void DeviceController::OnNewConnection(const Transport::PeerConnectionState * peerConnection, SecureSessionMgr * mgr) {} +void DeviceController::OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + uint16_t index = 0; + + VerifyOrExit(mState == State::Initialized, err = CHIP_ERROR_INCORRECT_STATE); + + index = FindDeviceIndex(mgr->GetPeerConnectionState(session)->GetPeerNodeId()); + VerifyOrExit(index < kNumMaxActiveDevices, err = CHIP_ERROR_INVALID_DEVICE_DESCRIPTOR); + + mActiveDevices[index].OnNewConnection(session, mgr); + +exit: + if (err != CHIP_NO_ERROR) + { + ChipLogError(Controller, "Failed to process received message: err %d", err); + } +} + +void DeviceController::OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + uint16_t index = 0; + + VerifyOrExit(mState == State::Initialized, err = CHIP_ERROR_INCORRECT_STATE); + + index = FindDeviceIndex(session); + VerifyOrExit(index < kNumMaxActiveDevices, err = CHIP_ERROR_INVALID_DEVICE_DESCRIPTOR); + + mActiveDevices[index].OnConnectionExpired(session, mgr); + +exit: + if (err != CHIP_NO_ERROR) + { + ChipLogError(Controller, "Failed to process received message: err %d", err); + } +} void DeviceController::OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, - const Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf, + SecureSessionHandle session, System::PacketBufferHandle msgBuf, SecureSessionMgr * mgr) { CHIP_ERROR err = CHIP_NO_ERROR; uint16_t index = 0; - NodeId peer; VerifyOrExit(mState == State::Initialized, err = CHIP_ERROR_INCORRECT_STATE); VerifyOrExit(header.GetSourceNodeId().HasValue(), err = CHIP_ERROR_INVALID_ARGUMENT); - peer = header.GetSourceNodeId().Value(); - index = FindDeviceIndex(peer); + index = FindDeviceIndex(session); VerifyOrExit(index < kNumMaxActiveDevices, err = CHIP_ERROR_INVALID_DEVICE_DESCRIPTOR); - mActiveDevices[index].OnMessageReceived(header, payloadHeader, state, std::move(msgBuf), mgr); + mActiveDevices[index].OnMessageReceived(header, payloadHeader, session, std::move(msgBuf), mgr); exit: if (err != CHIP_NO_ERROR) @@ -395,6 +429,20 @@ void DeviceController::ReleaseAllDevices() } } +uint16_t DeviceController::FindDeviceIndex(SecureSessionHandle session) +{ + uint16_t i = 0; + while (i < kNumMaxActiveDevices) + { + if (mActiveDevices[i].IsActive() && mActiveDevices[i].IsSecureConnected() && mActiveDevices[i].MatchesSession(session)) + { + return i; + } + i++; + } + return i; +} + uint16_t DeviceController::FindDeviceIndex(NodeId id) { uint16_t i = 0; diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index c12c2881117496..340e7803e5a91a 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -192,17 +192,19 @@ class DLL_EXPORT DeviceController : public SecureSessionMgrDelegate, public Pers uint16_t mListenPort; uint16_t GetInactiveDeviceIndex(); - uint16_t FindDeviceIndex(NodeId id); + uint16_t FindDeviceIndex(SecureSessionHandle session); + [[deprecated("only peer node id is not sufficient to identify a device")]] uint16_t FindDeviceIndex(NodeId id); void ReleaseDevice(uint16_t index); CHIP_ERROR SetPairedDeviceList(const char * pairedDeviceSerializedSet); private: //////////// SecureSessionMgrDelegate Implementation /////////////// void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, - const Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf, + SecureSessionHandle session, System::PacketBufferHandle msgBuf, SecureSessionMgr * mgr) override; - void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) override; + void OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) override; + void OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) override; //////////// PersistentStorageResultDelegate Implementation /////////////// void OnValue(const char * key, const char * value) override; diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index 4390c062cba0ba..a22a6d9cebce06 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -110,7 +110,7 @@ CHIP_ERROR ExchangeContext::SendMessage(uint16_t protocolId, uint8_t msgType, Pa payloadHeader.SetInitiator(IsInitiator()); - err = mExchangeMgr->GetSessionMgr()->SendMessage(payloadHeader, mPeerNodeId, std::move(msgBuf)); + err = mExchangeMgr->GetSessionMgr()->SendMessage(mSecureSession, payloadHeader, std::move(msgBuf)); SuccessOrExit(err); exit: @@ -191,7 +191,7 @@ void ExchangeContext::Reset() *this = ExchangeContext(); } -ExchangeContext * ExchangeContext::Alloc(ExchangeManager * em, uint16_t ExchangeId, uint64_t PeerNodeId, bool Initiator, +ExchangeContext * ExchangeContext::Alloc(ExchangeManager * em, uint16_t ExchangeId, SecureSessionHandle session, bool Initiator, ExchangeDelegate * delegate) { VerifyOrDie(delegate != nullptr); @@ -201,8 +201,8 @@ ExchangeContext * ExchangeContext::Alloc(ExchangeManager * em, uint16_t Exchange Retain(); mExchangeMgr = em; em->IncrementContextsInUse(); - mExchangeId = ExchangeId; - mPeerNodeId = PeerNodeId; + mExchangeId = ExchangeId; + mSecureSession = session; mFlags.Set(ExFlagValues::kFlagInitiator, Initiator); mDelegate = delegate; @@ -236,7 +236,8 @@ void ExchangeContext::Free() SYSTEM_STATS_DECREMENT(chip::System::Stats::kExchangeMgr_NumContexts); } -bool ExchangeContext::MatchExchange(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader) +bool ExchangeContext::MatchExchange(SecureSessionHandle session, const PacketHeader & packetHeader, + const PayloadHeader & payloadHeader) { // A given message is part of a particular exchange if... return @@ -244,8 +245,8 @@ bool ExchangeContext::MatchExchange(const PacketHeader & packetHeader, const Pay // The exchange identifier of the message matches the exchange identifier of the context. (mExchangeId == payloadHeader.GetExchangeID()) - // AND The message was received from the peer node associated with the exchange, or the peer node identifier is 'any'. - && ((mPeerNodeId == kAnyNodeId) || (mPeerNodeId == packetHeader.GetSourceNodeId().Value())) + // AND The message was received from the peer node associated with the exchange + && (mSecureSession == session) // AND The message was sent by an initiator and the exchange context is a responder (IsInitiator==false) // OR The message was sent by a responder and the exchange context is an initiator (IsInitiator==true) (for the broadcast diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index bd85fe28a5fbd6..cf1537c183f63c 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -129,7 +129,7 @@ class DLL_EXPORT ExchangeContext : public ReferenceCounted mFlags; // Internal state flags /** * Search for an existing exchange that the message applies to. * + * @param[in] session The secure session of the received message. + * * @param[in] packetHeader A reference to the PacketHeader object. * * @param[in] payloadHeader A reference to the PayloadHeader object. @@ -174,7 +176,7 @@ class DLL_EXPORT ExchangeContext : public ReferenceCounted 0 && ec.GetPeerNodeId() == peerNodeId && ec.GetDelegate() == delegate && - ec.IsInitiator() == isInitiator) - return &ec; - } - - return nullptr; + return AllocContext(mNextExchangeId++, session, true, delegate); } CHIP_ERROR ExchangeManager::RegisterUnsolicitedMessageHandler(uint32_t protocolId, ExchangeDelegate * delegate) @@ -141,8 +129,7 @@ void ExchangeManager::OnReceiveError(CHIP_ERROR error, const Transport::PeerAddr ChipLogError(ExchangeManager, "Accept FAILED, err = %s", ErrorStr(error)); } -ExchangeContext * ExchangeManager::AllocContext(uint16_t ExchangeId, uint64_t PeerNodeId, bool Initiator, - ExchangeDelegate * delegate) +ExchangeContext * ExchangeManager::AllocContext(uint16_t ExchangeId, SecureSessionHandle session, bool Initiator, ExchangeDelegate * delegate) { CHIP_FAULT_INJECT(FaultInjection::kFault_AllocExchangeContext, return nullptr); @@ -150,7 +137,7 @@ ExchangeContext * ExchangeManager::AllocContext(uint16_t ExchangeId, uint64_t Pe { if (ec.GetReferenceCount() == 0) { - return ec.Alloc(this, ExchangeId, PeerNodeId, Initiator, delegate); + return ec.Alloc(this, ExchangeId, session, Initiator, delegate); } } @@ -158,7 +145,7 @@ ExchangeContext * ExchangeManager::AllocContext(uint16_t ExchangeId, uint64_t Pe return nullptr; } -void ExchangeManager::DispatchMessage(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, +void ExchangeManager::DispatchMessage(SecureSessionHandle session, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, System::PacketBufferHandle msgBuf) { UnsolicitedMessageHandler * umh = nullptr; @@ -168,7 +155,7 @@ void ExchangeManager::DispatchMessage(const PacketHeader & packetHeader, const P // Search for an existing exchange that the message applies to. If a match is found... for (auto & ec : ContextPool) { - if (ec.GetReferenceCount() > 0 && ec.MatchExchange(packetHeader, payloadHeader)) + if (ec.GetReferenceCount() > 0 && ec.MatchExchange(session, packetHeader, payloadHeader)) { // Matched ExchangeContext; send to message handler. ec.HandleMessage(packetHeader, payloadHeader, std::move(msgBuf)); @@ -213,7 +200,7 @@ void ExchangeManager::DispatchMessage(const PacketHeader & packetHeader, const P if (matchingUMH != nullptr) { auto * ec = - AllocContext(payloadHeader.GetExchangeID(), packetHeader.GetSourceNodeId().Value(), false, matchingUMH->Delegate); + AllocContext(payloadHeader.GetExchangeID(), session, false, matchingUMH->Delegate); VerifyOrExit(ec != nullptr, err = CHIP_ERROR_NO_MEMORY); ChipLogProgress(ExchangeManager, "ec pos: %d, id: %d, Delegate: 0x%x", ec - ContextPool.begin(), ec->GetExchangeId(), @@ -278,17 +265,17 @@ CHIP_ERROR ExchangeManager::UnregisterUMH(uint32_t protocolId, int16_t msgType) } void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, - const Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf, + SecureSessionHandle session, System::PacketBufferHandle msgBuf, SecureSessionMgr * msgLayer) { - DispatchMessage(packetHeader, payloadHeader, std::move(msgBuf)); + DispatchMessage(session, packetHeader, payloadHeader, std::move(msgBuf)); } -void ExchangeManager::OnConnectionExpired(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) +void ExchangeManager::OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) { for (auto & ec : ContextPool) { - if (ec.GetReferenceCount() > 0 && ec.mPeerNodeId == state->GetPeerNodeId()) + if (ec.GetReferenceCount() > 0 && ec.mSecureSession == session) { ec.Close(); // Continue iterate because there can be multiple contexts associated with the connection. diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h index d433d592ffabb5..64df641b458896 100644 --- a/src/messaging/ExchangeMgr.h +++ b/src/messaging/ExchangeMgr.h @@ -90,7 +90,7 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate * @return A pointer to the created ExchangeContext object On success. Otherwise NULL if no object * can be allocated or is available. */ - ExchangeContext * NewContext(const NodeId & peerNodeId, ExchangeDelegate * delegate); + ExchangeContext * NewContext(SecureSessionHandle session, ExchangeDelegate * delegate); /** * Find the ExchangeContext from a pool matching a given set of parameters. @@ -191,9 +191,9 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate UnsolicitedMessageHandler UMHandlerPool[CHIP_CONFIG_MAX_UNSOLICITED_MESSAGE_HANDLERS]; void (*OnExchangeContextChanged)(size_t numContextsInUse); - ExchangeContext * AllocContext(uint16_t ExchangeId, uint64_t PeerNodeId, bool Initiator, ExchangeDelegate * delegate); + ExchangeContext * AllocContext(uint16_t ExchangeId, SecureSessionHandle session, bool Initiator, ExchangeDelegate * delegate); - void DispatchMessage(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, System::PacketBufferHandle msgBuf); + void DispatchMessage(SecureSessionHandle session, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, System::PacketBufferHandle msgBuf); CHIP_ERROR RegisterUMH(uint32_t protocolId, int16_t msgType, ExchangeDelegate * delegate); CHIP_ERROR UnregisterUMH(uint32_t protocolId, int16_t msgType); @@ -201,10 +201,10 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgr * msgLayer) override; void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, - const Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf, + SecureSessionHandle session, System::PacketBufferHandle msgBuf, SecureSessionMgr * msgLayer) override; - void OnConnectionExpired(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) override; + void OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) override; }; } // namespace Messaging diff --git a/src/messaging/tests/TestExchangeMgr.cpp b/src/messaging/tests/TestExchangeMgr.cpp index 0e50f54a19d524..1a4ad0a4e45592 100644 --- a/src/messaging/tests/TestExchangeMgr.cpp +++ b/src/messaging/tests/TestExchangeMgr.cpp @@ -103,6 +103,27 @@ void CheckSimpleInitTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); } +class TestSessMgrCallback : public SecureSessionMgrDelegate +{ +public: + void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, SecureSessionHandle session, + System::PacketBufferHandle msgBuf, SecureSessionMgr * mgr) override + { + ReceiveHandlerCallCount++; + } + + void OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) override + { + mSecureSession = session; + NewConnectionHandlerCallCount++; + } + void OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) override {} + + SecureSessionHandle mSecureSession; + int ReceiveHandlerCallCount = 0; + int NewConnectionHandlerCallCount = 0; +}; + void CheckNewContextTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); @@ -111,6 +132,23 @@ void CheckNewContextTest(nlTestSuite * inSuite, void * inContext) SecureSessionMgr secureSessionMgr; CHIP_ERROR err; + TestSessMgrCallback callback; + secureSessionMgr.SetDelegate(&callback); + + IPAddress addr; + IPAddress::FromString("127.0.0.1", addr); + SecurePairingUsingTestSecret pairing1(Optional::Value(kSourceNodeId), 1, 2); + Optional peer1(Transport::PeerAddress::UDP(addr, 1)); + err = secureSessionMgr.NewPairing(peer1, kDestinationNodeId, &pairing1); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + SecureSessionHandle sessionFromSourceToDestination = callback.mSecureSession; + + SecurePairingUsingTestSecret pairing2(Optional::Value(kDestinationNodeId), 2, 1); + Optional peer2(Transport::PeerAddress::UDP(addr, 2)); + err = secureSessionMgr.NewPairing(peer2, kSourceNodeId, &pairing2); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + SecureSessionHandle sessionFromDestinationToSource = callback.mSecureSession; + ctx.GetInetLayer().SystemLayer()->Init(nullptr); err = transportMgr.Init("LOOPBACK"); @@ -124,48 +162,17 @@ void CheckNewContextTest(nlTestSuite * inSuite, void * inContext) MockAppDelegate mockAppDelegate; - ExchangeContext * ec1 = exchangeMgr.NewContext(kSourceNodeId, &mockAppDelegate); + ExchangeContext * ec1 = exchangeMgr.NewContext(sessionFromDestinationToSource, &mockAppDelegate); NL_TEST_ASSERT(inSuite, ec1 != nullptr); NL_TEST_ASSERT(inSuite, ec1->IsInitiator() == true); NL_TEST_ASSERT(inSuite, ec1->GetExchangeId() != 0); - NL_TEST_ASSERT(inSuite, ec1->GetPeerNodeId() == kSourceNodeId); + NL_TEST_ASSERT(inSuite, ec1->GetSecureSession() == sessionFromDestinationToSource); NL_TEST_ASSERT(inSuite, ec1->GetDelegate() == &mockAppDelegate); - ExchangeContext * ec2 = exchangeMgr.NewContext(kDestinationNodeId, &mockAppDelegate); + ExchangeContext * ec2 = exchangeMgr.NewContext(sessionFromSourceToDestination, &mockAppDelegate); NL_TEST_ASSERT(inSuite, ec2 != nullptr); NL_TEST_ASSERT(inSuite, ec2->GetExchangeId() > ec1->GetExchangeId()); - NL_TEST_ASSERT(inSuite, ec2->GetPeerNodeId() == kDestinationNodeId); -} - -void CheckFindContextTest(nlTestSuite * inSuite, void * inContext) -{ - TestContext & ctx = *reinterpret_cast(inContext); - - TransportMgr transportMgr; - SecureSessionMgr secureSessionMgr; - CHIP_ERROR err; - - ctx.GetInetLayer().SystemLayer()->Init(nullptr); - - err = transportMgr.Init("LOOPBACK"); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - ExchangeManager exchangeMgr; - err = exchangeMgr.Init(&secureSessionMgr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - MockAppDelegate mockAppDelegate; - - ExchangeContext * ec = exchangeMgr.NewContext(kDestinationNodeId, &mockAppDelegate); - NL_TEST_ASSERT(inSuite, ec != nullptr); - - bool result = exchangeMgr.FindContext(kDestinationNodeId, &mockAppDelegate, true); - NL_TEST_ASSERT(inSuite, result == true); - - result = exchangeMgr.FindContext(kDestinationNodeId, nullptr, false); - NL_TEST_ASSERT(inSuite, result == false); + NL_TEST_ASSERT(inSuite, ec2->GetSecureSession() == sessionFromSourceToDestination); } void CheckUmhRegistrationTest(nlTestSuite * inSuite, void * inContext) @@ -211,12 +218,26 @@ void CheckUmhRegistrationTest(nlTestSuite * inSuite, void * inContext) void CheckExchangeMessages(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - CHIP_ERROR err; TransportMgr transportMgr; SecureSessionMgr secureSessionMgr; + CHIP_ERROR err; + + TestSessMgrCallback callback; + secureSessionMgr.SetDelegate(&callback); + IPAddress addr; IPAddress::FromString("127.0.0.1", addr); + SecurePairingUsingTestSecret pairing1(Optional::Value(kSourceNodeId), 1, 2); + Optional peer1(Transport::PeerAddress::UDP(addr, 1)); + err = secureSessionMgr.NewPairing(peer1, kDestinationNodeId, &pairing1); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + SecureSessionHandle sessionFromSourceToDestination = callback.mSecureSession; + + SecurePairingUsingTestSecret pairing2(Optional::Value(kDestinationNodeId), 2, 1); + Optional peer2(Transport::PeerAddress::UDP(addr, 2)); + err = secureSessionMgr.NewPairing(peer2, kSourceNodeId, &pairing2); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); ctx.GetInetLayer().SystemLayer()->Init(nullptr); @@ -229,18 +250,9 @@ void CheckExchangeMessages(nlTestSuite * inSuite, void * inContext) err = exchangeMgr.Init(&secureSessionMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret pairing1(Optional::Value(kSourceNodeId), 1, 2); - Optional peer1(Transport::PeerAddress::UDP(addr, 1)); - err = secureSessionMgr.NewPairing(peer1, kSourceNodeId, &pairing1); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret pairing2(Optional::Value(kDestinationNodeId), 2, 1); - Optional peer2(Transport::PeerAddress::UDP(addr, 2)); - err = secureSessionMgr.NewPairing(peer2, kDestinationNodeId, &pairing2); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - // create solicited exchange MockAppDelegate mockSolicitedAppDelegate; - ExchangeContext * ec1 = exchangeMgr.NewContext(kDestinationNodeId, &mockSolicitedAppDelegate); + ExchangeContext * ec1 = exchangeMgr.NewContext(sessionFromSourceToDestination, &mockSolicitedAppDelegate); // create unsolicited exchange MockAppDelegate mockUnsolicitedAppDelegate; @@ -265,7 +277,6 @@ const nlTest sTests[] = { NL_TEST_DEF("Test ExchangeMgr::Init", CheckSimpleInitTest), NL_TEST_DEF("Test ExchangeMgr::NewContext", CheckNewContextTest), - NL_TEST_DEF("Test ExchangeMgr::FindContext", CheckFindContextTest), NL_TEST_DEF("Test ExchangeMgr::CheckUmhRegistrationTest", CheckUmhRegistrationTest), NL_TEST_DEF("Test ExchangeMgr::CheckExchangeMessages", CheckExchangeMessages), diff --git a/src/messaging/tests/echo/echo_requester.cpp b/src/messaging/tests/echo/echo_requester.cpp index 96a26376667947..31bbc73ef8dbfe 100644 --- a/src/messaging/tests/echo/echo_requester.cpp +++ b/src/messaging/tests/echo/echo_requester.cpp @@ -101,7 +101,7 @@ CHIP_ERROR SendEchoRequest(void) printf("\nSend echo request message to Node: %" PRIu64 "\n", chip::kTestDeviceNodeId); - err = gEchoClient.SendEchoRequest(chip::kTestDeviceNodeId, std::move(payloadBuf)); + err = gEchoClient.SendEchoRequest(std::move(payloadBuf)); if (err == CHIP_NO_ERROR) { @@ -143,7 +143,7 @@ CHIP_ERROR EstablishSecureSession() return err; } -void HandleEchoResponseReceived(chip::NodeId nodeId, chip::System::PacketBufferHandle payload) +void HandleEchoResponseReceived(chip::Messaging::ExchangeContext * ec, chip::System::PacketBufferHandle payload) { uint32_t respTime = chip::System::Timer::GetCurrentEpoch(); uint32_t transitTime = respTime - gLastEchoTime; @@ -151,11 +151,21 @@ void HandleEchoResponseReceived(chip::NodeId nodeId, chip::System::PacketBufferH gWaitingForEchoResp = false; gEchoRespCount++; - printf("Echo Response from node %" PRIu64 ": %" PRIu64 "/%" PRIu64 "(%.2f%%) len=%u time=%.3fms\n", nodeId, gEchoRespCount, + printf("Echo Response: %" PRIu64 "/%" PRIu64 "(%.2f%%) len=%u time=%.3fms\n", gEchoRespCount, gEchoCount, static_cast(gEchoRespCount) * 100 / gEchoCount, payload->DataLength(), static_cast(transitTime) / 1000); } +class TestSecureSessionMgrDelegate : public chip::SecureSessionMgrDelegate +{ +public: + void OnNewConnection(chip::SecureSessionHandle session, chip::SecureSessionMgr * mgr) override { + mSecureSession = session; + } + + chip::SecureSessionHandle mSecureSession; +} gTestSecureSessionMgrDelegate; + } // namespace int main(int argc, char * argv[]) @@ -184,19 +194,21 @@ int main(int argc, char * argv[]) err = gSessionManager.Init(chip::kTestControllerNodeId, &chip::DeviceLayer::SystemLayer, &gTransportManager); SuccessOrExit(err); + gSessionManager.SetDelegate(&gTestSecureSessionMgrDelegate); + err = gExchangeManager.Init(&gSessionManager); SuccessOrExit(err); - err = gEchoClient.Init(&gExchangeManager); + // Start the CHIP connection to the CHIP echo responder. + err = EstablishSecureSession(); + SuccessOrExit(err); + + err = gEchoClient.Init(&gExchangeManager, gTestSecureSessionMgrDelegate.mSecureSession); SuccessOrExit(err); // Arrange to get a callback whenever an Echo Response is received. gEchoClient.SetEchoResponseReceived(HandleEchoResponseReceived); - // Start the CHIP connection to the CHIP echo responder. - err = EstablishSecureSession(); - SuccessOrExit(err); - // Connection has been established. Now send the EchoRequests. for (unsigned int i = 0; i < kMaxEchoCount; i++) { diff --git a/src/messaging/tests/echo/echo_responder.cpp b/src/messaging/tests/echo/echo_responder.cpp index afc4c4c23263fd..b7f9e8a2249edd 100644 --- a/src/messaging/tests/echo/echo_responder.cpp +++ b/src/messaging/tests/echo/echo_responder.cpp @@ -47,9 +47,9 @@ chip::SecureSessionMgr gSessionManager; chip::SecurePairingUsingTestSecret gTestPairing; // Callback handler when a CHIP EchoRequest is received. -void HandleEchoRequestReceived(chip::NodeId nodeId, chip::System::PacketBufferHandle payload) +void HandleEchoRequestReceived(chip::Messaging::ExchangeContext * ec, chip::System::PacketBufferHandle payload) { - printf("Echo Request from node %" PRIu64 ", len=%u ... sending response.\n", nodeId, payload->DataLength()); + printf("Echo Request, len=%u ... sending response.\n", payload->DataLength()); } } // namespace diff --git a/src/protocols/echo/Echo.h b/src/protocols/echo/Echo.h index fa6d41b61abb54..695c1d9ebe0f7d 100644 --- a/src/protocols/echo/Echo.h +++ b/src/protocols/echo/Echo.h @@ -43,11 +43,12 @@ enum kEchoMessageType_EchoResponse = 2 }; -typedef void (*EchoFunct)(NodeId nodeId, System::PacketBufferHandle payload); +using EchoFunct = void (*)(Messaging::ExchangeContext * ec, System::PacketBufferHandle payload); class DLL_EXPORT EchoClient : public Messaging::ExchangeDelegate { public: + // TODO: Init function will take a Channel instead a SecureSessionHandle, when Channel API is ready /** * Initialize the EchoClient object. Within the lifetime * of this instance, this method is invoked once after object @@ -55,13 +56,14 @@ class DLL_EXPORT EchoClient : public Messaging::ExchangeDelegate * instance. * * @param[in] exchangeMgr A pointer to the ExchangeManager object. + * @param[in] sessoin A handle to the session. * * @retval #CHIP_ERROR_INCORRECT_STATE If the state is not equal to * kState_NotInitialized. * @retval #CHIP_NO_ERROR On success. * */ - CHIP_ERROR Init(Messaging::ExchangeManager * exchangeMgr); + CHIP_ERROR Init(Messaging::ExchangeManager * exchangeMgr, SecureSessionHandle session); /** * Shutdown the EchoClient. This terminates this instance @@ -88,14 +90,14 @@ class DLL_EXPORT EchoClient : public Messaging::ExchangeDelegate * Other CHIP_ERROR codes as returned by the lower layers. * */ - CHIP_ERROR SendEchoRequest(NodeId nodeId, System::PacketBufferHandle payload); + CHIP_ERROR SendEchoRequest(System::PacketBufferHandle payload); private: Messaging::ExchangeManager * mExchangeMgr = nullptr; Messaging::ExchangeContext * mExchangeCtx = nullptr; EchoFunct OnEchoResponseReceived = nullptr; + SecureSessionHandle mSecureSession; - CHIP_ERROR SendEchoRequest(System::PacketBufferHandle payload); void OnMessageReceived(Messaging::ExchangeContext * ec, const PacketHeader & packetHeader, uint32_t protocolId, uint8_t msgType, System::PacketBufferHandle payload) override; void OnResponseTimeout(Messaging::ExchangeContext * ec) override; diff --git a/src/protocols/echo/EchoClient.cpp b/src/protocols/echo/EchoClient.cpp index a6389f41699139..561b51755be95c 100644 --- a/src/protocols/echo/EchoClient.cpp +++ b/src/protocols/echo/EchoClient.cpp @@ -28,13 +28,14 @@ namespace chip { namespace Protocols { -CHIP_ERROR EchoClient::Init(Messaging::ExchangeManager * exchangeMgr) +CHIP_ERROR EchoClient::Init(Messaging::ExchangeManager * exchangeMgr, SecureSessionHandle session) { // Error if already initialized. if (mExchangeMgr != nullptr) return CHIP_ERROR_INCORRECT_STATE; mExchangeMgr = exchangeMgr; + mSecureSession = session; OnEchoResponseReceived = nullptr; mExchangeCtx = nullptr; @@ -50,8 +51,10 @@ void EchoClient::Shutdown() } } -CHIP_ERROR EchoClient::SendEchoRequest(NodeId nodeId, System::PacketBufferHandle payload) +CHIP_ERROR EchoClient::SendEchoRequest(System::PacketBufferHandle payload) { + CHIP_ERROR err = CHIP_NO_ERROR; + // Discard any existing exchange context. Effectively we can only have one Echo exchange with // a single node at any one time. if (mExchangeCtx != nullptr) @@ -61,19 +64,12 @@ CHIP_ERROR EchoClient::SendEchoRequest(NodeId nodeId, System::PacketBufferHandle } // Create a new exchange context. - mExchangeCtx = mExchangeMgr->NewContext(nodeId, this); + mExchangeCtx = mExchangeMgr->NewContext(mSecureSession, this); if (mExchangeCtx == nullptr) { return CHIP_ERROR_NO_MEMORY; } - return SendEchoRequest(std::move(payload)); -} - -CHIP_ERROR EchoClient::SendEchoRequest(System::PacketBufferHandle payload) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - // Send an Echo Request message. Discard the exchange context if the send fails. err = mExchangeCtx->SendMessage(kProtocol_Echo, kEchoMessageType_EchoRequest, std::move(payload), Messaging::SendFlags(Messaging::SendMessageFlags::kSendFlag_None)); @@ -115,13 +111,13 @@ void EchoClient::OnMessageReceived(Messaging::ExchangeContext * ec, const Packet // Call the registered OnEchoResponseReceived handler, if any. if (OnEchoResponseReceived != nullptr) { - OnEchoResponseReceived(packetHeader.GetSourceNodeId().ValueOr(0), std::move(payload)); + OnEchoResponseReceived(ec, std::move(payload)); } } void EchoClient::OnResponseTimeout(Messaging::ExchangeContext * ec) { - ChipLogProgress(Echo, "Time out! failed to receive echo response from Node: %llu", ec->GetPeerNodeId()); + ChipLogProgress(Echo, "Time out! failed to receive echo response from Exchange: %p", ec); } } // namespace Protocols diff --git a/src/protocols/echo/EchoServer.cpp b/src/protocols/echo/EchoServer.cpp index dedc22b9d1981a..95584037e93db2 100644 --- a/src/protocols/echo/EchoServer.cpp +++ b/src/protocols/echo/EchoServer.cpp @@ -64,7 +64,7 @@ void EchoServer::OnMessageReceived(Messaging::ExchangeContext * ec, const Packet if (OnEchoRequestReceived != nullptr) { response = payload.Retain(); - OnEchoRequestReceived(ec->GetPeerNodeId(), std::move(payload)); + OnEchoRequestReceived(ec, std::move(payload)); } else { diff --git a/src/transport/PeerConnections.h b/src/transport/PeerConnections.h index 439b208d181e9a..95938da5ff603e 100644 --- a/src/transport/PeerConnections.h +++ b/src/transport/PeerConnections.h @@ -24,6 +24,11 @@ namespace chip { namespace Transport { +// TODO; use 0xffff to match any key id, this is a temporary solution for +// InteractionModel, where key id is not obtainable. This will be removed when +// InteractionModel is migrated to messaging layer +constexpr const uint16_t kAnyKeyId = 0xffff; + /** * Handles a set of peer connection states. * @@ -219,7 +224,7 @@ class PeerConnections { continue; } - if (iter->GetPeerKeyID() == peerKeyId) + if (peerKeyId == kAnyKeyId || iter->GetPeerKeyID() == peerKeyId) { if (!nodeId.HasValue() || iter->GetPeerNodeId() == kUndefinedNodeId || iter->GetPeerNodeId() == nodeId.Value()) { diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 35f13b18f46f5f..2b2fc6a3cf0272 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -85,17 +85,17 @@ CHIP_ERROR SecureSessionMgr::Init(NodeId localNodeId, System::Layer * systemLaye return err; } -CHIP_ERROR SecureSessionMgr::SendMessage(NodeId peerNodeId, System::PacketBufferHandle msgBuf) +CHIP_ERROR SecureSessionMgr::SendMessage(SecureSessionHandle session, System::PacketBufferHandle msgBuf) { PayloadHeader payloadHeader; - return SendMessage(payloadHeader, peerNodeId, std::move(msgBuf)); + return SendMessage(session, payloadHeader, std::move(msgBuf)); } -CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf) +CHIP_ERROR SecureSessionMgr::SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle msgBuf) { CHIP_ERROR err = CHIP_NO_ERROR; - PeerConnectionState * state = mPeerConnections.FindPeerConnectionState(peerNodeId, nullptr); + PeerConnectionState * state = GetPeerConnectionState(session); VerifyOrExit(mState == State::kInitialized, err = CHIP_ERROR_INCORRECT_STATE); @@ -125,15 +125,14 @@ CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId p payloadLength = static_cast(headerSize + msgBuf->TotalLength()); VerifyOrExit(CanCastTo(payloadLength), err = CHIP_ERROR_NO_MEMORY); - packetHeader - .SetSourceNodeId(mLocalNodeId) // - .SetDestinationNodeId(peerNodeId) // - .SetMessageId(state->GetSendMessageIndex()) // - .SetEncryptionKeyID(state->GetLocalKeyID()) // + packetHeader.SetSourceNodeId(mLocalNodeId) + .SetDestinationNodeId(state->GetPeerNodeId()) + .SetMessageId(state->GetSendMessageIndex()) + .SetEncryptionKeyID(state->GetLocalKeyID()) .SetPayloadLength(static_cast(payloadLength)); packetHeader.GetFlags().Set(Header::FlagValues::kSecure); - ChipLogProgress(Inet, "Sending msg from %llu to %llu", mLocalNodeId, peerNodeId); + ChipLogProgress(Inet, "Sending msg from %llu to %llu", mLocalNodeId, state->GetPeerNodeId()); VerifyOrExit(msgBuf->EnsureReservedSize(headerSize), err = CHIP_ERROR_NO_MEMORY); @@ -219,6 +218,10 @@ CHIP_ERROR SecureSessionMgr::NewPairing(const Optional & { err = pairing->DeriveSecureSession(reinterpret_cast(kSpake2pI2RSessionInfo), strlen(kSpake2pI2RSessionInfo), state->GetSecureSession()); + if (mCB != nullptr) + { + mCB->OnNewConnection({ state->GetPeerNodeId(), state->GetPeerKeyID() }, this); + } } exit: @@ -347,7 +350,7 @@ void SecureSessionMgr::OnMessageReceived(const PacketHeader & packetHeader, cons if (mCB != nullptr) { - mCB->OnMessageReceived(packetHeader, payloadHeader, state, std::move(msg), this); + mCB->OnMessageReceived(packetHeader, payloadHeader, { state->GetPeerNodeId(), state->GetPeerKeyID() }, std::move(msg), this); } } @@ -367,7 +370,7 @@ void SecureSessionMgr::HandleConnectionExpired(const Transport::PeerConnectionSt if (mCB != nullptr) { - mCB->OnConnectionExpired(&state, this); + mCB->OnConnectionExpired({ state.GetPeerNodeId(), state.GetPeerKeyID() }, this); } mTransportMgr->Disconnect(state.GetPeerAddress()); @@ -386,4 +389,9 @@ void SecureSessionMgr::ExpiryTimerCallback(System::Layer * layer, void * param, mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer } +PeerConnectionState * SecureSessionMgr::GetPeerConnectionState(SecureSessionHandle session) +{ + return mPeerConnections.FindPeerConnectionState(Optional::Value(session.mPeerNodeId), session.mPeerKeyId, nullptr); +} + } // namespace chip diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index ec7400e083be89..0327391949deb6 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -44,6 +44,23 @@ namespace chip { class SecureSessionMgr; + +class SecureSessionHandle +{ +public: + SecureSessionHandle() : mPeerNodeId(kAnyNodeId), mPeerKeyId(0) {} + SecureSessionHandle(NodeId peerNodeId, uint16_t peerKeyId) : mPeerNodeId(peerNodeId), mPeerKeyId(peerKeyId) {} + + bool operator==(const SecureSessionHandle & that) const + { + return mPeerNodeId == that.mPeerNodeId && mPeerKeyId == that.mPeerKeyId; + } +private: + friend class SecureSessionMgr; + NodeId mPeerNodeId; + uint16_t mPeerKeyId; +}; + /** * @brief * This class provides a skeleton for the callback functions. The functions will be @@ -61,12 +78,12 @@ class DLL_EXPORT SecureSessionMgrDelegate * * @param packetHeader The message header * @param payloadHeader The payload header - * @param state The connection state + * @param session The handle to the secure session * @param msgBuf The received message * @param mgr A pointer to the SecureSessionMgr */ virtual void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, - const Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf, + SecureSessionHandle session, System::PacketBufferHandle msgBuf, SecureSessionMgr * mgr) {} @@ -84,19 +101,19 @@ class DLL_EXPORT SecureSessionMgrDelegate * @brief * Called when a new pairing is being established * - * @param state connection state + * @param session The handle to the secure session * @param mgr A pointer to the SecureSessionMgr */ - virtual void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) {} + virtual void OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) {} /** * @brief * Called when a new connection is closing * - * @param state connection state + * @param session The handle to the secure session * @param mgr A pointer to the SecureSessionMgr */ - virtual void OnConnectionExpired(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) {} + virtual void OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) {} /** * @brief @@ -118,8 +135,9 @@ class DLL_EXPORT SecureSessionMgr : public Mdns::ResolveDelegate, public Transpo * @brief * Send a message to a currently connected peer */ - CHIP_ERROR SendMessage(NodeId peerNodeId, System::PacketBufferHandle msgBuf); - CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf); + CHIP_ERROR SendMessage(SecureSessionHandle session, System::PacketBufferHandle msgBuf); + CHIP_ERROR SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle msgBuf); + Transport::PeerConnectionState * GetPeerConnectionState(SecureSessionHandle session); SecureSessionMgr(); ~SecureSessionMgr() override; diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp index 7c72b78156e8e9..4af0836243a3cb 100644 --- a/src/transport/tests/TestSecureSessionMgr.cpp +++ b/src/transport/tests/TestSecureSessionMgr.cpp @@ -67,12 +67,12 @@ class LoopbackTransport : public Transport::Base class TestSessMgrCallback : public SecureSessionMgrDelegate { public: - void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, const PeerConnectionState * state, + void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, SecureSessionHandle session, System::PacketBufferHandle msgBuf, SecureSessionMgr * mgr) override { NL_TEST_ASSERT(mSuite, header.GetSourceNodeId() == Optional::Value(kSourceNodeId)); NL_TEST_ASSERT(mSuite, header.GetDestinationNodeId() == Optional::Value(kDestinationNodeId)); - NL_TEST_ASSERT(mSuite, state->GetPeerNodeId() == kSourceNodeId); + NL_TEST_ASSERT(mSuite, session == mRemoteToLocalSession); // Packet received by remote peer size_t data_len = msgBuf->DataLength(); @@ -82,9 +82,19 @@ class TestSessMgrCallback : public SecureSessionMgrDelegate ReceiveHandlerCallCount++; } - void OnNewConnection(const PeerConnectionState * state, SecureSessionMgr * mgr) override { NewConnectionHandlerCallCount++; } + void OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) override + { + if (NewConnectionHandlerCallCount == 0) + mRemoteToLocalSession = session; + if (NewConnectionHandlerCallCount == 1) + mLocalToRemoteSession = session; + NewConnectionHandlerCallCount++; + } + void OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) override {} - nlTestSuite * mSuite = nullptr; + nlTestSuite * mSuite = nullptr; + SecureSessionHandle mRemoteToLocalSession; + SecureSessionHandle mLocalToRemoteSession; int ReceiveHandlerCallCount = 0; int NewConnectionHandlerCallCount = 0; }; @@ -137,20 +147,22 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) secureSessionMgr.SetDelegate(&callback); - SecurePairingUsingTestSecret pairing1(Optional::Value(kSourceNodeId), 1, 2); Optional peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); - err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing1); + SecurePairingUsingTestSecret pairing1(Optional::Value(kSourceNodeId), 1, 2); + err = secureSessionMgr.NewPairing(peer, kSourceNodeId, &pairing1); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); SecurePairingUsingTestSecret pairing2(Optional::Value(kDestinationNodeId), 2, 1); - err = secureSessionMgr.NewPairing(peer, kSourceNodeId, &pairing2); + err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing2); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + SecureSessionHandle localToRemoteSessoin = callback.mLocalToRemoteSession; + // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; - err = secureSessionMgr.SendMessage(kDestinationNodeId, std::move(buffer)); + err = secureSessionMgr.SendMessage(localToRemoteSessoin, std::move(buffer)); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 0; });