diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index 04138f57d0b568..3481e9bf7a6530 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -295,6 +295,19 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const DispatchMessage(packetHeader, payloadHeader, msgBuf); } +void ExchangeManager::OnConnectionExpired(Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) +{ + // Search for an existing exchange that the message applies to. If a match is found... + ExchangeContext * ec = ContextPool; + for (int i = 0; i < CHIP_CONFIG_MAX_EXCHANGE_CONTEXTS; i++, ec++) + { + if (ec->GetReferenceCount() > 0 && ec->mPeerNodeId == state->GetPeerNodeId()) + { + ec->Close(); + } + } +} + void ExchangeManager::IncrementContextsInUse() { mContextsInUse++; diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h index 2897065844aebc..ed38d4c141a4fa 100644 --- a/src/messaging/ExchangeMgr.h +++ b/src/messaging/ExchangeMgr.h @@ -201,6 +201,8 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, Transport::PeerConnectionState * state, System::PacketBuffer * msgBuf, SecureSessionMgrBase * msgLayer) override; + + void OnConnectionExpired(Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) override; }; } // namespace chip diff --git a/src/transport/PeerConnections.h b/src/transport/PeerConnections.h index f69a4c1f570f2d..79e42d084b603d 100644 --- a/src/transport/PeerConnections.h +++ b/src/transport/PeerConnections.h @@ -326,7 +326,7 @@ class PeerConnections * */ template - void SetConnectionExpiredHandler(void (*handler)(const PeerConnectionState &, T *), T * param) + void SetConnectionExpiredHandler(void (*handler)(PeerConnectionState &, T *), T * param) { mConnectionExpiredArgument = param; OnConnectionExpired = reinterpret_cast(handler); diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 8f944dddadedfc..253224557260c3 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -355,13 +355,18 @@ void SecureSessionMgrBase::HandleDataReceived(const PacketHeader & packetHeader, } } -void SecureSessionMgrBase::HandleConnectionExpired(const Transport::PeerConnectionState & state, SecureSessionMgrBase * mgr) +void SecureSessionMgrBase::HandleConnectionExpired(Transport::PeerConnectionState & state, SecureSessionMgrBase * mgr) { char addr[Transport::PeerAddress::kMaxToStringSize]; state.GetPeerAddress().ToString(addr, sizeof(addr)); ChipLogDetail(Inet, "Connection from '%s' expired", addr); + if (mgr->mCB != nullptr) + { + mgr->mCB->OnConnectionExpired(&state, mgr); + } + mgr->mTransport->Disconnect(state.GetPeerAddress()); } diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index e3304b693e7fc8..de1f27dde56d12 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -88,6 +88,15 @@ class DLL_EXPORT SecureSessionMgrDelegate */ virtual void OnNewConnection(Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) {} + /** + * @brief + * Called when a new connection is closing + * + * @param state connection state + * @param mgr A pointer to the SecureSessionMgr + */ + virtual void OnConnectionExpired(Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) {} + /** * @brief * Called when the peer address is resolved from NodeID. @@ -184,7 +193,7 @@ class DLL_EXPORT SecureSessionMgrBase : public Mdns::ResolveDelegate /** * Called when a specific connection expires. */ - static void HandleConnectionExpired(const Transport::PeerConnectionState & state, SecureSessionMgrBase * mgr); + static void HandleConnectionExpired(Transport::PeerConnectionState & state, SecureSessionMgrBase * mgr); /** * Callback for timer expiry check diff --git a/src/transport/tests/TestPeerConnections.cpp b/src/transport/tests/TestPeerConnections.cpp index 0601abbd19a7c6..634be48777a567 100644 --- a/src/transport/tests/TestPeerConnections.cpp +++ b/src/transport/tests/TestPeerConnections.cpp @@ -201,7 +201,7 @@ struct ExpiredCallInfo PeerAddress lastCallPeerAddress = PeerAddress::Uninitialized(); }; -void OnConnectionExpired(const PeerConnectionState & state, ExpiredCallInfo * info) +void OnConnectionExpired(PeerConnectionState & state, ExpiredCallInfo * info) { info->callCount++; info->lastCallNodeId = state.GetPeerNodeId();