diff --git a/examples/shell/shell_common/include/Globals.h b/examples/shell/shell_common/include/Globals.h index d91620ca0ad171..69162be92f5204 100644 --- a/examples/shell/shell_common/include/Globals.h +++ b/examples/shell/shell_common/include/Globals.h @@ -24,7 +24,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include #if INET_CONFIG_ENABLE_TCP_ENDPOINT diff --git a/src/app/CASESessionManager.cpp b/src/app/CASESessionManager.cpp index 382d2620fcfa6a..dd9a75ada0ac08 100644 --- a/src/app/CASESessionManager.cpp +++ b/src/app/CASESessionManager.cpp @@ -30,62 +30,57 @@ CHIP_ERROR CASESessionManager::Init(chip::System::Layer * systemLayer, const CAS } void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - Callback::Callback * onFailure + Callback::Callback * onFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES -) + TransportPayloadCapability transportPayloadCapability) { FindOrEstablishSessionHelper(peerId, onConnection, onFailure, nullptr #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES , attemptCount, onRetry #endif - ); + , + transportPayloadCapability); } void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - Callback::Callback * onSetupFailure + Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif -) + TransportPayloadCapability transportPayloadCapability) { - FindOrEstablishSessionHelper(peerId, onConnection, nullptr, onSetupFailure + FindOrEstablishSessionHelper(peerId, onConnection, nullptr, onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - attemptCount, onRetry + attemptCount, onRetry, #endif - ); + transportPayloadCapability); } void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - nullptr_t + nullptr_t, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif -) + TransportPayloadCapability transportPayloadCapability) { - FindOrEstablishSessionHelper(peerId, onConnection, nullptr, nullptr + FindOrEstablishSessionHelper(peerId, onConnection, nullptr, nullptr, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - attemptCount, onRetry + attemptCount, onRetry, #endif - ); + transportPayloadCapability); } void CASESessionManager::FindOrEstablishSessionHelper(const ScopedNodeId & peerId, Callback::Callback * onConnection, Callback::Callback * onFailure, - Callback::Callback * onSetupFailure + Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif -) + TransportPayloadCapability transportPayloadCapability) { ChipLogDetail(CASESessionManager, "FindOrEstablishSession: PeerId = [%d:" ChipLogFormatX64 "]", peerId.GetFabricIndex(), ChipLogValueX64(peerId.GetNodeId())); @@ -124,12 +119,12 @@ void CASESessionManager::FindOrEstablishSessionHelper(const ScopedNodeId & peerI if (onFailure != nullptr) { - session->Connect(onConnection, onFailure); + session->Connect(onConnection, onFailure, transportPayloadCapability); } if (onSetupFailure != nullptr) { - session->Connect(onConnection, onSetupFailure); + session->Connect(onConnection, onSetupFailure, transportPayloadCapability); } } @@ -143,10 +138,11 @@ void CASESessionManager::ReleaseAllSessions() mConfig.sessionSetupPool->ReleaseAllSessionSetup(); } -CHIP_ERROR CASESessionManager::GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr) +CHIP_ERROR CASESessionManager::GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr, + TransportPayloadCapability transportPayloadCapability) { ReturnErrorOnFailure(mConfig.sessionInitParams.Validate()); - auto optionalSessionHandle = FindExistingSession(peerId); + auto optionalSessionHandle = FindExistingSession(peerId, transportPayloadCapability); ReturnErrorCodeIf(!optionalSessionHandle.HasValue(), CHIP_ERROR_NOT_CONNECTED); addr = optionalSessionHandle.Value()->AsSecureSession()->GetPeerAddress(); return CHIP_NO_ERROR; @@ -182,10 +178,11 @@ OperationalSessionSetup * CASESessionManager::FindExistingSessionSetup(const Sco return mConfig.sessionSetupPool->FindSessionSetup(peerId, forAddressUpdate); } -Optional CASESessionManager::FindExistingSession(const ScopedNodeId & peerId) const +Optional CASESessionManager::FindExistingSession(const ScopedNodeId & peerId, + const TransportPayloadCapability transportPayloadCapability) const { - return mConfig.sessionInitParams.sessionManager->FindSecureSessionForNode(peerId, - MakeOptional(Transport::SecureSession::Type::kCASE)); + return mConfig.sessionInitParams.sessionManager->FindSecureSessionForNode( + peerId, MakeOptional(Transport::SecureSession::Type::kCASE), transportPayloadCapability); } void CASESessionManager::ReleaseSession(OperationalSessionSetup * session) diff --git a/src/app/CASESessionManager.h b/src/app/CASESessionManager.h index 3e668fa905cbdf..ac193bdc44ba70 100644 --- a/src/app/CASESessionManager.h +++ b/src/app/CASESessionManager.h @@ -26,6 +26,7 @@ #include #include #include +#include #include namespace chip { @@ -83,7 +84,8 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess , uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - ); + , + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); /** * Find an existing session for the given node ID or trigger a new session request. @@ -106,6 +108,7 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * @param onSetupFailure A callback to be called upon an extended device connection failure. * @param attemptCount The number of retry attempts if session setup fails (default is 1). * @param onRetry A callback to be called on a retry attempt (enabled by a config flag). + * @param transportPayloadCapability An indicator of what payload types the session needs to be able to transport. */ void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, Callback::Callback * onSetupFailure @@ -113,7 +116,8 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess , uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - ); + , + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); /** * Find an existing session for the given node ID or trigger a new session request. @@ -134,13 +138,15 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * @param onConnection A callback to be called upon successful connection establishment. * @param attemptCount The number of retry attempts if session setup fails (default is 1). * @param onRetry A callback to be called on a retry attempt (enabled by a config flag). + * @param transportPayloadCapability An indicator of what payload types the session needs to be able to transport. */ void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, nullptr_t #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES , uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - ); + , + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); void ReleaseSessionsForFabric(FabricIndex fabricIndex); @@ -154,7 +160,8 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * an ongoing session with the peer node. If the session doesn't exist, the API will return * `CHIP_ERROR_NOT_CONNECTED` error. */ - CHIP_ERROR GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr); + CHIP_ERROR GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); //////////// OperationalSessionReleaseDelegate Implementation /////////////// void ReleaseSession(OperationalSessionSetup * device) override; @@ -165,15 +172,17 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess private: OperationalSessionSetup * FindExistingSessionSetup(const ScopedNodeId & peerId, bool forAddressUpdate = false) const; - Optional FindExistingSession(const ScopedNodeId & peerId) const; + Optional FindExistingSession( + const ScopedNodeId & peerId, + const TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload) const; void FindOrEstablishSessionHelper(const ScopedNodeId & peerId, Callback::Callback * onConnection, Callback::Callback * onFailure, Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif - ); + TransportPayloadCapability transportPayloadCapability); CASESessionManagerConfig mConfig; }; diff --git a/src/app/OperationalSessionSetup.cpp b/src/app/OperationalSessionSetup.cpp index 6df07ca6b0568e..15eae094fc92d0 100644 --- a/src/app/OperationalSessionSetup.cpp +++ b/src/app/OperationalSessionSetup.cpp @@ -76,8 +76,8 @@ bool OperationalSessionSetup::AttachToExistingSecureSession() mState == State::WaitingForRetry, false); - auto sessionHandle = - mInitParams.sessionManager->FindSecureSessionForNode(mPeerId, MakeOptional(Transport::SecureSession::Type::kCASE)); + auto sessionHandle = mInitParams.sessionManager->FindSecureSessionForNode( + mPeerId, MakeOptional(Transport::SecureSession::Type::kCASE), mTransportPayloadCapability); if (!sessionHandle.HasValue()) return false; @@ -93,11 +93,13 @@ bool OperationalSessionSetup::AttachToExistingSecureSession() void OperationalSessionSetup::Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, - Callback::Callback * onSetupFailure) + Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability) { CHIP_ERROR err = CHIP_NO_ERROR; bool isConnected = false; + mTransportPayloadCapability = transportPayloadCapability; // // Always enqueue our user provided callbacks into our callback list. // If anything goes wrong below, we'll trigger failures (including any queued from @@ -180,15 +182,17 @@ void OperationalSessionSetup::Connect(Callback::Callback * on } void OperationalSessionSetup::Connect(Callback::Callback * onConnection, - Callback::Callback * onFailure) + Callback::Callback * onFailure, + TransportPayloadCapability transportPayloadCapability) { - Connect(onConnection, onFailure, nullptr); + Connect(onConnection, onFailure, nullptr, transportPayloadCapability); } void OperationalSessionSetup::Connect(Callback::Callback * onConnection, - Callback::Callback * onSetupFailure) + Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability) { - Connect(onConnection, nullptr, onSetupFailure); + Connect(onConnection, nullptr, onSetupFailure, transportPayloadCapability); } void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & addr, const ReliableMessageProtocolConfig & config) @@ -288,6 +292,16 @@ void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & ad CHIP_ERROR OperationalSessionSetup::EstablishConnection(const ReliableMessageProtocolConfig & config) { +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // TODO: Combine LargePayload flag with DNS-SD advertisements from peer. + // Issue #32348. + if (mTransportPayloadCapability == TransportPayloadCapability::kLargePayload) + { + // Set the transport type for carrying large payloads + mDeviceAddress.SetTransportType(chip::Transport::Type::kTcp); + } +#endif + mCASEClient = mClientPool->Allocate(); ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY); diff --git a/src/app/OperationalSessionSetup.h b/src/app/OperationalSessionSetup.h index 9d24faad5efabb..7810fb8709b610 100644 --- a/src/app/OperationalSessionSetup.h +++ b/src/app/OperationalSessionSetup.h @@ -202,8 +202,12 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * `onFailure` may be called before the Connect call returns, for error * cases that are detected synchronously (e.g. inability to start an address * lookup). + * + * `transportPayloadCapability` is set to kLargePayload when the session needs to be established + * over a transport that allows large payloads to be transferred, e.g., TCP. */ - void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure); + void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); /* * This function can be called to establish a secure session with the device. @@ -219,8 +223,12 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * * `onSetupFailure` may be called before the Connect call returns, for error cases that are detected synchronously * (e.g. inability to start an address lookup). + * + * `transportPayloadCapability` is set to kLargePayload when the session needs to be established + * over a transport that allows large payloads to be transferred, e.g., TCP. */ - void Connect(Callback::Callback * onConnection, Callback::Callback * onSetupFailure); + void Connect(Callback::Callback * onConnection, Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); bool IsForAddressUpdate() const { return mPerformingAddressUpdate; } @@ -306,6 +314,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, bool mPerformingAddressUpdate = false; + TransportPayloadCapability mTransportPayloadCapability = TransportPayloadCapability::kMRPPayload; + #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES // When we TryNextResult on the resolver, it will synchronously call back // into our OnNodeAddressResolved when it succeeds. We need to track @@ -341,7 +351,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, void CleanupCASEClient(); void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, - Callback::Callback * onSetupFailure); + Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); void EnqueueConnectionCallbacks(Callback::Callback * onConnection, Callback::Callback * onFailure, diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index d2983636450691..693bd235bbecad 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -71,6 +71,9 @@ using chip::Transport::BleListenParameters; #endif using chip::Transport::PeerAddress; using chip::Transport::UdpListenParameters; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +using chip::Transport::TcpListenParameters; +#endif namespace { @@ -201,6 +204,12 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) #if CONFIG_NETWORK_LAYER_BLE , BleListenParameters(DeviceLayer::ConnectivityMgr().GetBleLayer()) +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + TcpListenParameters(DeviceLayer::TCPEndPointManager()) + .SetAddressType(IPAddressType::kIPv6) + .SetListenPort(mOperationalServicePort) #endif ); diff --git a/src/app/server/Server.h b/src/app/server/Server.h index 302471b8bebb89..fdd9008165c570 100644 --- a/src/app/server/Server.h +++ b/src/app/server/Server.h @@ -77,6 +77,12 @@ namespace chip { inline constexpr size_t kMaxBlePendingPackets = 1; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +inline constexpr size_t kMaxTcpActiveConnectionCount = CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS; + +inline constexpr size_t kMaxTcpPendingPackets = CHIP_CONFIG_MAX_TCP_PENDING_PACKETS; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + // // NOTE: Please do not alter the order of template specialization here as the logic // in the Server impl depends on this. @@ -89,6 +95,10 @@ using ServerTransportMgr = chip::TransportMgr +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + chip::Transport::TCP #endif >; diff --git a/src/app/tests/TestCommissionManager.cpp b/src/app/tests/TestCommissionManager.cpp index a12ec42ec3c49a..82c4d23a349182 100644 --- a/src/app/tests/TestCommissionManager.cpp +++ b/src/app/tests/TestCommissionManager.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -99,6 +100,9 @@ void InitializeChip(nlTestSuite * suite) static chip::app::reporting::ReportSchedulerImpl sReportScheduler(&sTimerDelegate); initParams.reportScheduler = &sReportScheduler; (void) initParams.InitializeStaticResourcesBeforeServerInit(); + // Set a randomized server port(slightly shifted from CHIP_PORT) for testing + initParams.operationalServicePort = static_cast(initParams.operationalServicePort + chip::Crypto::GetRandU16() % 20); + err = chip::Server::GetInstance().Init(initParams); NL_TEST_ASSERT(suite, err == CHIP_NO_ERROR); diff --git a/src/controller/CHIPDeviceControllerFactory.cpp b/src/controller/CHIPDeviceControllerFactory.cpp index 25bcd12b190a09..5af0424265dc2f 100644 --- a/src/controller/CHIPDeviceControllerFactory.cpp +++ b/src/controller/CHIPDeviceControllerFactory.cpp @@ -164,6 +164,12 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params) #if CONFIG_NETWORK_LAYER_BLE , Transport::BleListenParameters(stateParams.bleLayer) +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + Transport::TcpListenParameters(stateParams.tcpEndPointManager) + .SetAddressType(IPAddressType::kIPv6) + .SetListenPort(params.listenPort) #endif )); diff --git a/src/controller/CHIPDeviceControllerSystemState.h b/src/controller/CHIPDeviceControllerSystemState.h index bc89d6c1d21ec2..539c2c4159eed5 100644 --- a/src/controller/CHIPDeviceControllerSystemState.h +++ b/src/controller/CHIPDeviceControllerSystemState.h @@ -57,16 +57,25 @@ namespace chip { inline constexpr size_t kMaxDeviceTransportBlePendingPackets = 1; -using DeviceTransportMgr = TransportMgr /* BLE */ + , + Transport::BLE /* BLE */ +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + Transport::TCP #endif - >; + >; namespace Controller { diff --git a/src/include/platform/internal/GenericPlatformManagerImpl.ipp b/src/include/platform/internal/GenericPlatformManagerImpl.ipp index d6f29ed515e034..64878f5928cd9b 100644 --- a/src/include/platform/internal/GenericPlatformManagerImpl.ipp +++ b/src/include/platform/internal/GenericPlatformManagerImpl.ipp @@ -89,6 +89,16 @@ CHIP_ERROR GenericPlatformManagerImpl::_InitChipStack() } SuccessOrExit(err); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Initialize the CHIP TCP layer. + err = TCPEndPointManager()->Init(SystemLayer()); + if (err != CHIP_NO_ERROR) + { + ChipLogError(DeviceLayer, "TCP initialization failed: %" CHIP_ERROR_FORMAT, err.Format()); + } + SuccessOrExit(err); +#endif + // TODO Perform dynamic configuration of the core CHIP objects based on stored settings. // Initialize the CHIP BLE manager. @@ -132,6 +142,10 @@ void GenericPlatformManagerImpl::_Shutdown() ChipLogError(DeviceLayer, "Inet Layer shutdown"); UDPEndPointManager()->Shutdown(); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + TCPEndPointManager()->Shutdown(); +#endif + #if CHIP_DEVICE_CONFIG_ENABLE_CHIPOBLE ChipLogError(DeviceLayer, "BLE shutdown"); BLEMgr().Shutdown(); diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index 74a861c1170b7f..24d50c5f4f844d 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -663,5 +663,12 @@ void ExchangeContext::ExchangeSessionHolder::GrabExpiredSession(const SessionHan GrabUnchecked(session); } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void ExchangeContext::OnSessionConnectionClosed(CHIP_ERROR conErr) +{ + // TODO: Handle connection closure at the ExchangeContext level. +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index fc20a5aace5273..47cf0ddbef2783 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -86,6 +86,9 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, NewSessionHandlingPolicy GetNewSessionHandlingPolicy() override { return NewSessionHandlingPolicy::kStayAtOldSession; } void OnSessionReleased() override; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + void OnSessionConnectionClosed(CHIP_ERROR conErr) override; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT /** * Send a CHIP message on this exchange. * diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index 3971864bb7e06e..a184723726781e 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -77,6 +77,9 @@ CHIP_ERROR ExchangeManager::Init(SessionManager * sessionManager) sessionManager->SetMessageDelegate(this); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + sessionManager->SetConnectionDelegate(this); +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT mReliableMessageMgr.Init(sessionManager->SystemLayer()); mState = State::kState_Initialized; @@ -413,5 +416,18 @@ void ExchangeManager::CloseAllContextsForDelegate(const ExchangeDelegate * deleg }); } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void ExchangeManager::OnTCPConnectionClosed(const SessionHandle & session, CHIP_ERROR conErr) +{ + mContextPool.ForEachActiveObject([&](auto * ec) { + if (ec->HasSessionHandle() && ec->GetSessionHandle() == session) + { + ec->OnSessionConnectionClosed(conErr); + } + return Loop::Continue; + }); +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h index 48c6f8df673fb6..b6e416cdbf74e7 100644 --- a/src/messaging/ExchangeMgr.h +++ b/src/messaging/ExchangeMgr.h @@ -49,6 +49,10 @@ static constexpr int16_t kAnyMessageType = -1; * handling the registration/unregistration of unsolicited message handlers. */ class DLL_EXPORT ExchangeManager : public SessionMessageDelegate +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + public SessionConnectionDelegate +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT { friend class ExchangeContext; @@ -242,6 +246,9 @@ class DLL_EXPORT ExchangeManager : public SessionMessageDelegate DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override; void SendStandaloneAckIfNeeded(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const SessionHandle & session, MessageFlags msgFlags, System::PacketBufferHandle && msgBuf); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + void OnTCPConnectionClosed(const SessionHandle & session, CHIP_ERROR conErr) override; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; } // namespace Messaging diff --git a/src/messaging/tests/echo/common.cpp b/src/messaging/tests/echo/common.cpp index 80befc92d27ae6..10eeeb67738024 100644 --- a/src/messaging/tests/echo/common.cpp +++ b/src/messaging/tests/echo/common.cpp @@ -51,10 +51,6 @@ void InitializeChip() err = chip::DeviceLayer::PlatformMgr().InitChipStack(); SuccessOrExit(err); - // Initialize TCP. - err = chip::DeviceLayer::TCPEndPointManager()->Init(chip::DeviceLayer::SystemLayer()); - SuccessOrExit(err); - exit: if (err != CHIP_NO_ERROR) { @@ -68,6 +64,5 @@ void ShutdownChip() gMessageCounterManager.Shutdown(); gExchangeManager.Shutdown(); gSessionManager.Shutdown(); - (void) chip::DeviceLayer::TCPEndPointManager()->Shutdown(); chip::DeviceLayer::PlatformMgr().Shutdown(); } diff --git a/src/messaging/tests/echo/echo_requester.cpp b/src/messaging/tests/echo/echo_requester.cpp index 016ad60358c18a..58dfa1ee980ea9 100644 --- a/src/messaging/tests/echo/echo_requester.cpp +++ b/src/messaging/tests/echo/echo_requester.cpp @@ -35,7 +35,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include #include @@ -49,6 +51,9 @@ namespace { // Max value for the number of EchoRequests sent. constexpr size_t kMaxEchoCount = 3; +// Max value for the number of tcp connect attempts. +constexpr size_t kMaxTCPConnectAttempts = 3; + // The CHIP Echo interval time. constexpr chip::System::Clock::Timeout gEchoInterval = chip::System::Clock::Seconds16(1); @@ -62,9 +67,21 @@ chip::TransportMgrAsSecureSession()->SetTCPConnection(conn); + } + + printf("Successfully established secure session with peer at %s\n", peerAddrBuf); } return err; } +void CloseConnection() +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + chip::Transport::PeerAddress peerAddr = chip::Transport::PeerAddress::TCP(gDestAddr, CHIP_PORT); + + gSessionManager.TCPDisconnect(peerAddr); + + peerAddr.ToString(peerAddrBuf); + printf("Connection closed to peer at %s\n", peerAddrBuf); + + gClientConEstablished = false; + gClientConInProgress = false; +} + +void HandleConnectionAttemptComplete(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR err) +{ + chip::DeviceLayer::PlatformMgr().StopEventLoopTask(); + + if (err != CHIP_NO_ERROR) + { + printf("Connection FAILED with err: %s\n", chip::ErrorStr(err)); + + gLastEchoTime = chip::System::SystemClock().GetMonotonicTimestamp(); + CloseConnection(); + gTCPConnAttemptCount++; + return; + } + + err = EstablishSecureSession(conn); + if (err != CHIP_NO_ERROR) + { + printf("Secure session FAILED with err: %s\n", chip::ErrorStr(err)); + + gLastEchoTime = chip::System::SystemClock().GetMonotonicTimestamp(); + CloseConnection(); + return; + } + + gClientConEstablished = true; + gClientConInProgress = false; +} + +void HandleConnectionClosed(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + CloseConnection(); +} + +void EstablishTCPConnection() +{ + CHIP_ERROR err = CHIP_NO_ERROR; + // Previous connection attempt underway. + if (gClientConInProgress) + { + return; + } + + gClientConEstablished = false; + + chip::Transport::PeerAddress peerAddr = chip::Transport::PeerAddress::TCP(gDestAddr, CHIP_PORT); + + // Connect to the peer + err = gSessionManager.TCPConnect(peerAddr, &gAppTCPConnCbCtxt, &gActiveTCPConnState); + if (err != CHIP_NO_ERROR) + { + printf("Connection FAILED with err: %s\n", chip::ErrorStr(err)); + + gLastEchoTime = chip::System::SystemClock().GetMonotonicTimestamp(); + CloseConnection(); + gTCPConnAttemptCount++; + return; + } + + gClientConInProgress = true; +} + void HandleEchoResponseReceived(chip::Messaging::ExchangeContext * ec, chip::System::PacketBufferHandle && payload) { chip::System::Clock::Timestamp respTime = chip::System::SystemClock().GetMonotonicTimestamp(); @@ -236,6 +342,10 @@ int main(int argc, char * argv[]) err = gSessionManager.Init(&chip::DeviceLayer::SystemLayer(), &gTCPManager, &gMessageCounterManager, &gStorage, &gFabricTable, gSessionKeystore); SuccessOrExit(err); + + gAppTCPConnCbCtxt.appContext = nullptr; + gAppTCPConnCbCtxt.connCompleteCb = HandleConnectionAttemptComplete; + gAppTCPConnCbCtxt.connClosedCb = HandleConnectionClosed; } else { @@ -255,9 +365,29 @@ int main(int argc, char * argv[]) err = gMessageCounterManager.Init(&gExchangeManager); SuccessOrExit(err); - // Start the CHIP connection to the CHIP echo responder. - err = EstablishSecureSession(); - SuccessOrExit(err); + if (gUseTCP) + { + + while (!gClientConEstablished) + { + // For TCP transport, attempt to establish the connection to the CHIP echo responder. + // On Connection completion, call EstablishSecureSession(conn); + EstablishTCPConnection(); + + chip::DeviceLayer::PlatformMgr().RunEventLoop(); + + if (gTCPConnAttemptCount > kMaxTCPConnectAttempts) + { + ExitNow(); + } + } + } + else + { + // Start the CHIP session to the CHIP echo responder. + err = EstablishSecureSession(nullptr); + SuccessOrExit(err); + } err = gEchoClient.Init(&gExchangeManager, gSession.Get().Value()); SuccessOrExit(err); @@ -274,14 +404,14 @@ int main(int argc, char * argv[]) if (gUseTCP) { - gTCPManager.Disconnect(chip::Transport::PeerAddress::TCP(gDestAddr)); + gTCPManager.TCPDisconnect(chip::Transport::PeerAddress::TCP(gDestAddr)); } gTCPManager.Close(); Shutdown(); exit: - if ((err != CHIP_NO_ERROR) || (gEchoRespCount != kMaxEchoCount)) + if ((err != CHIP_NO_ERROR) || (gEchoRespCount != kMaxEchoCount) || (gTCPConnAttemptCount > kMaxTCPConnectAttempts)) { printf("ChipEchoClient failed: %s\n", chip::ErrorStr(err)); exit(EXIT_FAILURE); diff --git a/src/messaging/tests/echo/echo_responder.cpp b/src/messaging/tests/echo/echo_responder.cpp index 4383479406648c..ffe0837fa82586 100644 --- a/src/messaging/tests/echo/echo_responder.cpp +++ b/src/messaging/tests/echo/echo_responder.cpp @@ -35,7 +35,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include namespace { diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 6a969c09afecdb..177704bf869f31 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -419,6 +419,19 @@ void CASESession::Clear() mPeerNodeId = kUndefinedNodeId; mFabricsTable = nullptr; mFabricIndex = kUndefinedFabricIndex; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Clear the connection callback context to null. + mTCPConnCbCtxt.appContext = nullptr; + mTCPConnCbCtxt.connCompleteCb = nullptr; + mTCPConnCbCtxt.connClosedCb = nullptr; + mTCPConnCbCtxt.connReceivedCb = nullptr; + if (mPeerConnState && mPeerConnState->mConnectionState != Transport::TCPState::kConnected) + { + // Abort the connection if the CASESession is being destroyed and the + // connection is in the middle of being set up. + mSessionManager->TCPDisconnect(mPeerConnState, true /* shouldAbort = true */); + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT } void CASESession::InvalidateIfPendingEstablishmentOnFabric(FabricIndex fabricIndex) @@ -446,7 +459,9 @@ CHIP_ERROR CASESession::Init(SessionManager & sessionManager, Credentials::Certi ReturnErrorOnFailure(mCommissioningHash.Begin()); - mDelegate = delegate; + mDelegate = delegate; + mSessionManager = &sessionManager; + ReturnErrorOnFailure(AllocateSecureSession(sessionManager, sessionEvictionHint)); mValidContext.Reset(); @@ -454,6 +469,12 @@ CHIP_ERROR CASESession::Init(SessionManager & sessionManager, Credentials::Certi mValidContext.mRequiredKeyPurposes.Set(KeyPurposeFlags::kServerAuth); mValidContext.mValidityPolicy = policy; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + mTCPConnCbCtxt.appContext = this; + mTCPConnCbCtxt.connCompleteCb = HandleConnectionAttemptComplete; + mTCPConnCbCtxt.connClosedCb = HandleConnectionClosed; + mTCPConnCbCtxt.connReceivedCb = HandleConnectionReceived; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT return CHIP_NO_ERROR; } @@ -497,6 +518,7 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric { MATTER_TRACE_SCOPE("EstablishSession", "CASESession"); CHIP_ERROR err = CHIP_NO_ERROR; + Transport::PeerAddress peerAddress; // Return early on error here, as we have not initialized any state yet ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); @@ -516,6 +538,11 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric // This is to make sure the exchange will get closed if Init() returned an error. mExchangeCtxt.Emplace(*exchangeCtxt); + // Set the PeerAddress in the secure session up front to indicate the + // Transport Type of the session that is being set up. + peerAddress = mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress(); + mSecureSessionHolder->AsSecureSession()->SetPeerAddress(peerAddress); + // From here onwards, let's go to exit on error, as some state might have already // been initialized SuccessOrExit(err); @@ -534,8 +561,18 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric ChipLogProgress(SecureChannel, "Initiating session on local FabricIndex %u from 0x" ChipLogFormatX64 " -> 0x" ChipLogFormatX64, static_cast(mFabricIndex), ChipLogValueX64(mLocalNodeId), ChipLogValueX64(mPeerNodeId)); - err = SendSigma1(); - SuccessOrExit(err); + if (peerAddress.GetTransportType() == Transport::Type::kTcp) + { +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + err = sessionManager.TCPConnect(peerAddress, &mTCPConnCbCtxt, &mPeerConnState); + SuccessOrExit(err); +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } + else + { + err = SendSigma1(); + SuccessOrExit(err); + } exit: if (err != CHIP_NO_ERROR) @@ -647,6 +684,74 @@ CHIP_ERROR CASESession::RecoverInitiatorIpk() return CHIP_NO_ERROR; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void CASESession::HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR err) +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + VerifyOrReturn(conn != nullptr); + // conn->mAppState should not be NULL. SessionManager has already checked + // before calling this callback. + VerifyOrDie(conn->mAppState != nullptr); + + Transport::PeerAddress peerAddr = conn->mPeerAddr; + peerAddr.ToString(peerAddrBuf); + + CASESession * caseSession = reinterpret_cast(conn->mAppState->appContext); + VerifyOrReturn(caseSession != nullptr); + + // Exit and disconnect if connection setup encountered an error. + SuccessOrExit(err); + + ChipLogDetail(SecureChannel, "TCP Connection established with %s before session establishment", peerAddrBuf); + + // Associate the connection with the current unauthenticated session for the + // CASE exchange. + caseSession->mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetTCPConnection(conn); + + // Associate the connection with the current secure session that is being + // set up. + caseSession->mSecureSessionHolder.Get().Value()->AsSecureSession()->SetTCPConnection(conn); + + // Send Sigma1 after connection is established for sessions over TCP + err = caseSession->SendSigma1(); + SuccessOrExit(err); + +exit: + if (err != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, "Connection establishment failed: %s", ErrorStr(err)); + + // Close the underlying connection + caseSession->mSessionManager->TCPDisconnect(conn); + + caseSession->Clear(); + } +} + +void CASESession::HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + VerifyOrReturn(conn != nullptr); + // conn->mAppState should not be NULL. SessionManager has already checked + // before calling this callback. + VerifyOrDie(conn->mAppState != nullptr); + + CASESession * caseSession = reinterpret_cast(conn->mAppState->appContext); + VerifyOrReturn(caseSession != nullptr); + + ChipLogDetail(SecureChannel, "TCP Connection for this session has closed"); +} + +void CASESession::HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) +{ + VerifyOrReturn(conn != nullptr); + + CASESession * caseSession = reinterpret_cast(conn->mAppState->appContext); + VerifyOrReturn(caseSession != nullptr); + + ChipLogDetail(SecureChannel, "Incoming TCP Connection received"); +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR CASESession::SendSigma1() { MATTER_TRACE_SCOPE("SendSigma1", "CASESession"); @@ -2139,9 +2244,11 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST #if CHIP_CONFIG_SLOW_CRYPTO - if (msgType == Protocols::SecureChannel::MsgType::CASE_Sigma1 || msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2 || - msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2Resume || - msgType == Protocols::SecureChannel::MsgType::CASE_Sigma3) + if ((msgType == Protocols::SecureChannel::MsgType::CASE_Sigma1 || msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2 || + msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2Resume || + msgType == Protocols::SecureChannel::MsgType::CASE_Sigma3) && + mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress().GetTransportType() != + Transport::Type::kTcp) { SuccessOrExit(err = mExchangeCtxt.Value()->FlushAcks()); } diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 9e41f6c69fbe84..1dad96af71885c 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -286,6 +286,18 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, void InvalidateIfPendingEstablishmentOnFabric(FabricIndex fabricIndex); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + static void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + static void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + static void HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn); + + // Context to pass down when connecting to peer + Transport::AppTCPConnectionCallbackCtxt mTCPConnCbCtxt; + // Pointer to the underlying TCP connection state. Returned by the + // TCPConnect() method when an ActiveTCPConenctionState object is allocated. + Transport::ActiveTCPConnectionState * mPeerConnState = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + #if CONFIG_BUILD_FOR_HOST_UNIT_TEST void SetStopSigmaHandshakeAt(Optional state) { mStopHandshakeAtState = state; } #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST @@ -301,6 +313,7 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, uint8_t mIPK[kIPKSize]; SessionResumptionStorage * mSessionResumptionStorage = nullptr; + SessionManager * mSessionManager = nullptr; FabricTable * mFabricsTable = nullptr; FabricIndex mFabricIndex = kUndefinedFabricIndex; diff --git a/src/protocols/secure_channel/PairingSession.cpp b/src/protocols/secure_channel/PairingSession.cpp index 1f7874bdf115dc..93b0f9696704aa 100644 --- a/src/protocols/secure_channel/PairingSession.cpp +++ b/src/protocols/secure_channel/PairingSession.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace chip { @@ -58,6 +59,18 @@ void PairingSession::Finish() { Transport::PeerAddress address = mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress(); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (address.GetTransportType() == Transport::Type::kTcp) + { + // Fetch the connection for the unauthenticated session used to setup + // the secure session. + Transport::ActiveTCPConnectionState * conn = + mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetTCPConnection(); + + // Associate the connection with the secure session being activated. + mSecureSessionHolder->AsSecureSession()->SetTCPConnection(conn); + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT // Discard the exchange so that Clear() doesn't try closing it. The exchange will handle that. DiscardExchange(); diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h b/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h index 7f22c92be0b71c..99dfe10d4e812c 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h @@ -515,7 +515,8 @@ class DLL_EXPORT UserDirectedCommissioningClient : public TransportMgrDelegate } private: - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) override; CommissionerDeclarationHandler * mCommissionerDeclarationHandler = nullptr; }; @@ -625,7 +626,8 @@ class DLL_EXPORT UserDirectedCommissioningServer : public TransportMgrDelegate InstanceNameResolver * mInstanceNameResolver = nullptr; UserConfirmationProvider * mUserConfirmationProvider = nullptr; - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) override; UDCClients mUdcClients; // < Active UDC clients diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp b/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp index bf7db199ca2d53..131518f2b8125f 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp @@ -24,6 +24,7 @@ */ #include "UserDirectedCommissioning.h" +#include #ifdef __ZEPHYR__ #include @@ -235,7 +236,8 @@ CHIP_ERROR CommissionerDeclaration::ReadPayload(uint8_t * udcPayload, size_t pay return CHIP_NO_ERROR; } -void UserDirectedCommissioningClient::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg) +void UserDirectedCommissioningClient::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { char addrBuffer[chip::Transport::PeerAddress::kMaxToStringSize]; source.ToString(addrBuffer); diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp b/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp index cd20c5fc8734ac..a37684e6411516 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp @@ -26,6 +26,7 @@ #include "UserDirectedCommissioning.h" #include #include +#include #include @@ -33,7 +34,8 @@ namespace chip { namespace Protocols { namespace UserDirectedCommissioning { -void UserDirectedCommissioningServer::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg) +void UserDirectedCommissioningServer::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { char addrBuffer[chip::Transport::PeerAddress::kMaxToStringSize]; source.ToString(addrBuffer); diff --git a/src/transport/BUILD.gn b/src/transport/BUILD.gn index fa556d40afd5b0..c4649c9b92d27c 100644 --- a/src/transport/BUILD.gn +++ b/src/transport/BUILD.gn @@ -37,6 +37,7 @@ static_library("transport") { "SecureSessionTable.h", "Session.cpp", "Session.h", + "SessionConnectionDelegate.h", "SessionDelegate.h", "SessionHolder.cpp", "SessionHolder.h", diff --git a/src/transport/Session.h b/src/transport/Session.h index 0b6048d5c077c0..4e5376c31edc03 100644 --- a/src/transport/Session.h +++ b/src/transport/Session.h @@ -28,6 +28,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { namespace Transport { @@ -225,6 +228,15 @@ class Session bool IsUnauthenticatedSession() const { return GetSessionType() == SessionType::kUnauthenticated; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // This API is used to associate the connection with the session when the + // latter is about to be marked active. It is also used to reset the + // connection to a nullptr when the connection is lost and the session + // is marked as Defunct. + ActiveTCPConnectionState * GetTCPConnection() const { return mTCPConnection; } + void SetTCPConnection(ActiveTCPConnectionState * conn) { mTCPConnection = conn; } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + void DispatchSessionEvent(SessionDelegate::Event event) { // Holders might remove themselves when notified. @@ -264,6 +276,15 @@ class Session private: FabricIndex mFabricIndex = kUndefinedFabricIndex; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // The underlying TCP connection object over which the session is + // established. + // The lifetime of this member connection pointer is, essentially the same + // as that of the underlying connection with the peer. + // It would remain as a nullptr for all sessions that are not set up over + // a TCP connection. + ActiveTCPConnectionState * mTCPConnection = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; // diff --git a/src/transport/SessionConnectionDelegate.h b/src/transport/SessionConnectionDelegate.h new file mode 100644 index 00000000000000..39769801aa729d --- /dev/null +++ b/src/transport/SessionConnectionDelegate.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace chip { + +/** + * @brief + * This class provides a skeleton for the callback functions related to sessions over TCP, + * when the underlying connection changes. The functions will be called by SessionManager + * object on specific events. If the user of SessionManager is interested in receiving + * these callbacks, they can specialize this class and handle each trigger in their + * implementation of this class. + */ +class DLL_EXPORT SessionConnectionDelegate +{ +public: + virtual ~SessionConnectionDelegate() {} + + /** + * @brief + * Called when the underlying connection for the session is closed. + * + * @param session The handle to the secure session + * @param conErr The connection error code + */ + virtual void OnTCPConnectionClosed(const SessionHandle & session, CHIP_ERROR conErr) = 0; +}; + +} // namespace chip diff --git a/src/transport/SessionDelegate.h b/src/transport/SessionDelegate.h index b9e0a7b8b38b0c..503aaa2b0c4f5c 100644 --- a/src/transport/SessionDelegate.h +++ b/src/transport/SessionDelegate.h @@ -66,6 +66,10 @@ class DLL_EXPORT SessionDelegate * SessionManager to allocate a new session. If they desire to do so, it MUST be done asynchronously. */ virtual void OnSessionHang() {} + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + virtual void OnSessionConnectionClosed(CHIP_ERROR conErr) {} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; } // namespace chip diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index d38a32bb508b83..f74039ec8b081f 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -122,6 +122,11 @@ CHIP_ERROR SessionManager::Init(System::Layer * systemLayer, TransportMgrBase * mTransportMgr->SetSessionManager(this); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + mConnCompleteCb = nullptr; + mConnClosedCb = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + return CHIP_NO_ERROR; } @@ -568,7 +573,8 @@ CHIP_ERROR SessionManager::InjectCaseSessionWithTestKey(SessionHolder & sessionH return CHIP_NO_ERROR; } -void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle && msg) +void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { PacketHeader partialPacketHeader; @@ -587,20 +593,155 @@ void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System:: } else { - SecureUnicastMessageDispatch(partialPacketHeader, peerAddress, std::move(msg)); + SecureUnicastMessageDispatch(partialPacketHeader, peerAddress, std::move(msg), ctxt); } } else { - UnauthenticatedMessageDispatch(partialPacketHeader, peerAddress, std::move(msg)); + UnauthenticatedMessageDispatch(partialPacketHeader, peerAddress, std::move(msg), ctxt); + } +} + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void SessionManager::HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + + VerifyOrReturn(conn != nullptr); + Transport::PeerAddress peerAddr = conn->mPeerAddr; + peerAddr.ToString(peerAddrBuf); + ChipLogProgress(Inet, "Received TCP connection request from %s.", peerAddrBuf); + + Transport::AppTCPConnectionCallbackCtxt * appTCPConnCbCtxt = conn->mAppState; + if (appTCPConnCbCtxt != nullptr && appTCPConnCbCtxt->connReceivedCb != nullptr) + { + appTCPConnCbCtxt->connReceivedCb(conn); + } +} + +void SessionManager::HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + + VerifyOrReturn(conn != nullptr); + + Transport::AppTCPConnectionCallbackCtxt * appTCPConnCbCtxt = conn->mAppState; + if (appTCPConnCbCtxt == nullptr) + { + TCPDisconnect(conn, true /* shouldAbort = true */); + return; + } + + Transport::PeerAddress peerAddr = conn->mPeerAddr; + peerAddr.ToString(peerAddrBuf); + + if (appTCPConnCbCtxt->connCompleteCb != nullptr) + { + appTCPConnCbCtxt->connCompleteCb(conn, conErr); + } + else + { + ChipLogProgress(Inet, "TCP Connection established with peer %s, but no registered handler. Disconnecting.", peerAddrBuf); + + // Close the connection + TCPDisconnect(conn, true /* shouldAbort = true */); } } +void SessionManager::HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + VerifyOrReturn(conn != nullptr); + + // Mark the corresponding secure sessions as defunct + mSecureSessions.ForEachSession([&](auto session) { + if (session->IsActiveSession() && session->GetTCPConnection() == conn) + { + SessionHandle handle(*session); + // Notify the SessionConnection delegate of the connection + // closure. + if (mConnDelegate != nullptr) + { + mConnDelegate->OnTCPConnectionClosed(handle, conErr); + } + + // Dis-associate the connection from session by setting it to a + // nullptr. + session->SetTCPConnection(nullptr); + // Mark session as defunct + session->MarkAsDefunct(); + } + + return Loop::Continue; + }); + + // TODO: A mechanism to mark an unauthenticated session as unusable when + // the underlying connection is broken. Issue #32323 + + Transport::AppTCPConnectionCallbackCtxt * appTCPConnCbCtxt = conn->mAppState; + VerifyOrReturn(appTCPConnCbCtxt != nullptr); + + if (appTCPConnCbCtxt->connClosedCb != nullptr) + { + appTCPConnCbCtxt->connClosedCb(conn, conErr); + } +} + +CHIP_ERROR SessionManager::TCPConnect(const PeerAddress & peerAddress, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState **peerConnState) +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + peerAddress.ToString(peerAddrBuf); + if (mTransportMgr != nullptr) + { + ChipLogProgress(Inet, "Connecting over TCP with peer at %s.", peerAddrBuf); + return mTransportMgr->TCPConnect(peerAddress, appState, peerConnState); + } + + ChipLogError(Inet, "The transport manager is not initialized. Unable to connect to peer at %s.", peerAddrBuf); + + return CHIP_ERROR_INCORRECT_STATE; +} + +CHIP_ERROR SessionManager::TCPDisconnect(const PeerAddress & peerAddress) +{ + if (mTransportMgr != nullptr) + { + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + peerAddress.ToString(peerAddrBuf); + ChipLogProgress(Inet, "Disconnecting TCP connection from peer at %s.", peerAddrBuf); + mTransportMgr->TCPDisconnect(peerAddress); + } + + return CHIP_NO_ERROR; +} + +void SessionManager::TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort) +{ + if (mTransportMgr != nullptr) + { + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + Transport::PeerAddress peerAddr = conn->mPeerAddr; + peerAddr.ToString(peerAddrBuf); + ChipLogProgress(Inet, "Disconnecting TCP connection from peer at %s.", peerAddrBuf); + mTransportMgr->TCPDisconnect(conn, shouldAbort); + } +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partialPacketHeader, - const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) + const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { MATTER_TRACE_SCOPE("Unauthenticated Message Dispatch", "SessionManager"); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (peerAddress.GetTransportType() == Transport::Type::kTcp && ctxt->conn == nullptr) + { + ChipLogError(Inet, "Connection object is missing for received message."); + return; + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + // Drop unsecured messages with privacy enabled. if (partialPacketHeader.HasPrivacyFlag()) { @@ -626,7 +767,7 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial if (source.HasValue()) { // Assume peer is the initiator, we are the responder. - optionalSession = mUnauthenticatedSessions.FindOrAllocateResponder(source.Value(), GetDefaultMRPConfig()); + optionalSession = mUnauthenticatedSessions.FindOrAllocateResponder(source.Value(), GetDefaultMRPConfig(), peerAddress); if (!optionalSession.HasValue()) { ChipLogError(Inet, "UnauthenticatedSession exhausted"); @@ -636,7 +777,7 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial else { // Assume peer is the responder, we are the initiator. - optionalSession = mUnauthenticatedSessions.FindInitiator(destination.Value()); + optionalSession = mUnauthenticatedSessions.FindInitiator(destination.Value(), peerAddress); if (!optionalSession.HasValue()) { ChipLogProgress(Inet, "Received unknown unsecure packet for initiator 0x" ChipLogFormatX64, @@ -650,7 +791,27 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial Transport::PeerAddress mutablePeerAddress = peerAddress; CorrectPeerAddressInterfaceID(mutablePeerAddress); unsecuredSession->SetPeerAddress(mutablePeerAddress); + unsecuredSession->SetPeerAddress(peerAddress); SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Associate the unauthenticated session with the connection, if not done already. + if (peerAddress.GetTransportType() == Transport::Type::kTcp) + { + Transport::ActiveTCPConnectionState * sessionConn = unsecuredSession->GetTCPConnection(); + if (sessionConn == nullptr) + { + unsecuredSession->SetTCPConnection(ctxt->conn); + } + else + { + if (sessionConn != ctxt->conn) + { + ChipLogError(Inet, "Data received over wrong connection %p. Dropping it!", ctxt->conn); + return; + } + } + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT unsecuredSession->MarkActiveRx(); @@ -689,13 +850,50 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial } void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & partialPacketHeader, - const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) + const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { MATTER_TRACE_SCOPE("Secure Unicast Message Dispatch", "SessionManager"); CHIP_ERROR err = CHIP_NO_ERROR; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (peerAddress.GetTransportType() == Transport::Type::kTcp && ctxt->conn == nullptr) + { + ChipLogError(Inet, "Connection object is missing for received message."); + return; + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + Optional session = mSecureSessions.FindSecureSessionByLocalKey(partialPacketHeader.GetSessionId()); + if (!session.HasValue()) + { + ChipLogError(Inet, "Data received on an unknown session (LSID=%d). Dropping it!", partialPacketHeader.GetSessionId()); + return; + } + + Transport::SecureSession * secureSession = session.Value()->AsSecureSession(); + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Associate the secure session with the connection, if not done already. + if (peerAddress.GetTransportType() == Transport::Type::kTcp) + { + Transport::ActiveTCPConnectionState * sessionConn = secureSession->GetTCPConnection(); + if (sessionConn == nullptr) + { + secureSession->SetTCPConnection(ctxt->conn); + secureSession->SetPeerAddress(peerAddress); + } + else + { + if (sessionConn != ctxt->conn) + { + ChipLogError(Inet, "Data received over wrong connection %p. Dropping it!", ctxt->conn); + return; + } + } + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT PayloadHeader payloadHeader; @@ -717,14 +915,6 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & partialPa return; } - if (!session.HasValue()) - { - ChipLogError(Inet, "Data received on an unknown session (LSID=%d). Dropping it!", packetHeader.GetSessionId()); - return; - } - - Transport::SecureSession * secureSession = session.Value()->AsSecureSession(); - // We need to allow through messages even on sessions that are pending // evictions, because for some cases (UpdateNOC, RemoveFabric, etc) there // can be a single exchange alive on the session waiting for a MRP ack, and @@ -1021,13 +1211,19 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & partialPack } Optional SessionManager::FindSecureSessionForNode(ScopedNodeId peerNodeId, - const Optional & type) + const Optional & type, + TransportPayloadCapability transportPayloadCapability) { SecureSession * found = nullptr; - mSecureSessions.ForEachSession([&peerNodeId, &type, &found](auto session) { + mSecureSessions.ForEachSession([&peerNodeId, &type, &found, &transportPayloadCapability](auto session) { if (session->IsActiveSession() && session->GetPeer() == peerNodeId && - (!type.HasValue() || type.Value() == session->GetSecureSessionType())) + (!type.HasValue() || type.Value() == session->GetSecureSessionType()) +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + && (transportPayloadCapability == TransportPayloadCapability::kLargePayload ? (session->GetTCPConnection() != nullptr) + : true) +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + ) { // // Select the active session with the most recent activity to return back to the caller. diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 5f2e6f7603cad0..c1aa95633043f3 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -52,8 +52,30 @@ #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + namespace chip { +/* + * This enum indicates whether a session needs to be established over a + * suitable transport that meets certain payload size requirements for + * transmitted messages. + * + */ +enum class TransportPayloadCapability : uint8_t +{ + kMRPPayload, // Transport requires the maximum payload size to fit within a single + // IPv6 packet(1280 bytes). + kLargePayload, // Transport needs to handle payloads larger than the single IPv6 + // packet, as supported by MRP. The transport of choice, in this + // case, is TCP. + kMRPOrTCPCompatiblePayload // This option provides the ability to use MRP + // as the preferred transport, but use a large + // payload transport if that is already + // available. +}; /** * @brief * Tracks ownership of a encrypted packet buffer. @@ -151,6 +173,10 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl /// ExchangeManager) void SetMessageDelegate(SessionMessageDelegate * cb) { mCB = cb; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + void SetConnectionDelegate(SessionConnectionDelegate * cb) { mConnDelegate = cb; } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + // Test-only: create a session on the fly. CHIP_ERROR InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabricIndex, @@ -413,8 +439,34 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl * * @param source the source address of the package * @param msgBuf the buffer containing a full CHIP message (except for the optional length field). + * @param ctxt pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) override; + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TCPConnect(const Transport::PeerAddress & peerAddress, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState); + + CHIP_ERROR TCPDisconnect(const Transport::PeerAddress & peerAddress); + + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0); + + void HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) override; + + void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; + + void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; + + // Functors for callbacks into higher layers + using OnTCPConnectionReceivedCallback = void (*)(Transport::ActiveTCPConnectionState * conn); + + using OnTCPConnectionCompleteCallback = void (*)(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + + using OnTCPConnectionClosedCallback = void (*)(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT Optional CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & config) @@ -436,8 +488,9 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl // is returned. Otherwise, an Optional with no value set is returned. // // - Optional FindSecureSessionForNode(ScopedNodeId peerNodeId, - const Optional & type = NullOptional); + Optional + FindSecureSessionForNode(ScopedNodeId peerNodeId, const Optional & type = NullOptional, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); using SessionHandleCallback = bool (*)(void * context, SessionHandle & sessionHandle); CHIP_ERROR ForEachSessionHandle(void * context, SessionHandleCallback callback); @@ -477,8 +530,18 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl State mState; // < Initialization state of the object chip::Transport::GroupOutgoingCounters mGroupClientCounter; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + OnTCPConnectionReceivedCallback mConnReceivedCb = nullptr; + OnTCPConnectionCompleteCallback mConnCompleteCb = nullptr; + OnTCPConnectionClosedCallback mConnClosedCb = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + SessionMessageDelegate * mCB = nullptr; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + SessionConnectionDelegate * mConnDelegate = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + TransportMgrBase * mTransportMgr = nullptr; Transport::MessageCounterManagerInterface * mMessageCounterManager = nullptr; @@ -491,9 +554,11 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl * If the message decrypts successfully, this will be filled with a fully decoded PacketHeader. * @param[in] peerAddress The PeerAddress of the message as provided by the receiving Transport Endpoint. * @param msg The full message buffer, including header fields. + * @param ctxt The pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ void SecureUnicastMessageDispatch(const PacketHeader & partialPacketHeader, const Transport::PeerAddress & peerAddress, - System::PacketBufferHandle && msg); + System::PacketBufferHandle && msg, Transport::MessageTransportContext * ctxt = nullptr); /** * @brief Parse, decrypt, validate, and dispatch a secure group message. @@ -511,9 +576,11 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl * @param partialPacketHeader The partial PacketHeader of the message after processing with DecodeFixed. * @param peerAddress The PeerAddress of the message as provided by the receiving Transport Endpoint. * @param msg The full message buffer, including header fields. + * @param ctxt The pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ void UnauthenticatedMessageDispatch(const PacketHeader & partialPacketHeader, const Transport::PeerAddress & peerAddress, - System::PacketBufferHandle && msg); + System::PacketBufferHandle && msg, Transport::MessageTransportContext * ctxt = nullptr); void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source); diff --git a/src/transport/TransportMgr.h b/src/transport/TransportMgr.h index 494db39de964fd..2c42c7602bbffe 100644 --- a/src/transport/TransportMgr.h +++ b/src/transport/TransportMgr.h @@ -34,6 +34,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { @@ -49,8 +52,26 @@ class TransportMgrDelegate * * @param source the source address of the package * @param msgBuf the buffer containing a full CHIP message (except for the optional length field). + * @param ctxt the pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ - virtual void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) = 0; + virtual void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) = 0; + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + /** + * @brief + * Handle connection completion. + * + * @param conn the connection object + * @param conErr the connection error on the attempt. + */ + virtual void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; + + virtual void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; + + virtual void HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn){}; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; template diff --git a/src/transport/TransportMgrBase.cpp b/src/transport/TransportMgrBase.cpp index bde54a96c195b6..8db933a62f85c4 100644 --- a/src/transport/TransportMgrBase.cpp +++ b/src/transport/TransportMgrBase.cpp @@ -28,11 +28,24 @@ CHIP_ERROR TransportMgrBase::SendMessage(const Transport::PeerAddress & address, return mTransport->SendMessage(address, std::move(msgBuf)); } -void TransportMgrBase::Disconnect(const Transport::PeerAddress & address) +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +CHIP_ERROR TransportMgrBase::TCPConnect(const Transport::PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) { - mTransport->Disconnect(address); + return mTransport->TCPConnect(address, appState, peerConnState); } +void TransportMgrBase::TCPDisconnect(const Transport::PeerAddress & address) +{ + mTransport->TCPDisconnect(address); +} + +void TransportMgrBase::TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort) +{ + mTransport->TCPDisconnect(conn, shouldAbort); +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TransportMgrBase::Init(Transport::Base * transport) { if (mTransport != nullptr) @@ -41,6 +54,7 @@ CHIP_ERROR TransportMgrBase::Init(Transport::Base * transport) } mTransport = transport; mTransport->SetDelegate(this); + ChipLogDetail(Inet, "TransportMgr initialized"); return CHIP_NO_ERROR; } @@ -56,7 +70,8 @@ CHIP_ERROR TransportMgrBase::MulticastGroupJoinLeave(const Transport::PeerAddres return mTransport->MulticastGroupJoinLeave(address, join); } -void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) +void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { // This is the first point all incoming messages funnel through. Ensure // that our message receipts are all synchronized correctly. @@ -73,7 +88,7 @@ void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peer if (mSessionManager != nullptr) { - mSessionManager->OnMessageReceived(peerAddress, std::move(msg)); + mSessionManager->OnMessageReceived(peerAddress, std::move(msg), ctxt); } else { @@ -83,4 +98,60 @@ void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peer } } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void TransportMgrBase::HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) +{ + if (mSessionManager != nullptr) + { + mSessionManager->HandleConnectionReceived(conn); + } + else + { + Transport::TCPBase * tcp = reinterpret_cast(conn->mEndPoint->mAppState); + + // Close connection here since no upper layer is interested in the + // connection. + if (tcp) + { + tcp->TCPDisconnect(conn, true /* shouldAbort = true */); + } + } +} + +void TransportMgrBase::HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + if (mSessionManager != nullptr) + { + mSessionManager->HandleConnectionAttemptComplete(conn, conErr); + } + else + { + Transport::TCPBase * tcp = reinterpret_cast(conn->mEndPoint->mAppState); + + // Close connection here since no upper layer is interested in the + // connection. + if (tcp) + { + tcp->TCPDisconnect(conn, true /* shouldAbort = true */); + } + } +} + +void TransportMgrBase::HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + if (mSessionManager != nullptr) + { + mSessionManager->HandleConnectionClosed(conn, conErr); + } + else + { + Transport::TCPBase * tcp = reinterpret_cast(conn->mEndPoint->mAppState); + if (tcp) + { + tcp->TCPDisconnect(conn, true /* shouldAbort = true */); + } + } +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } // namespace chip diff --git a/src/transport/TransportMgrBase.h b/src/transport/TransportMgrBase.h index e4942ca6ecd90b..2b0f33bfb9e3a7 100644 --- a/src/transport/TransportMgrBase.h +++ b/src/transport/TransportMgrBase.h @@ -21,6 +21,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { @@ -39,13 +42,29 @@ class TransportMgrBase : public Transport::RawTransportDelegate void Close(); - void Disconnect(const Transport::PeerAddress & address); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TCPConnect(const Transport::PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState); + + void TCPDisconnect(const Transport::PeerAddress & address); + + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0); + + void HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) override; + + void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; + + void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT void SetSessionManager(TransportMgrDelegate * sessionManager) { mSessionManager = sessionManager; } + TransportMgrDelegate * GetSessionManager() { return mSessionManager; }; + CHIP_ERROR MulticastGroupJoinLeave(const Transport::PeerAddress & address, bool join); - void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) override; + void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt = nullptr) override; private: TransportMgrDelegate * mSessionManager = nullptr; diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index f5ba35df533de5..adb86d1ecfad41 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -45,9 +45,10 @@ class UnauthenticatedSession : public Session, public ReferenceCounted & sessionTable) : - UnauthenticatedSession(sessionRole, ephemeralInitiatorNodeID, config) + UnauthenticatedSession(sessionRole, ephemeralInitiatorNodeID, peerAddress, config) #if CHIP_SYSTEM_CONFIG_POOL_USE_HEAP , mSessionTable(sessionTable) @@ -224,13 +225,16 @@ class UnauthenticatedSessionTable * @return the session found or allocated, or Optional::Missing if not found and allocation failed. */ CHECK_RETURN_VALUE - Optional FindOrAllocateResponder(NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config) + Optional FindOrAllocateResponder(NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config, + const Transport::PeerAddress & peerAddress) { - UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID); + UnauthenticatedSession * result = + FindEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, peerAddress); if (result != nullptr) return MakeOptional(*result); - CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, config, result); + CHIP_ERROR err = + AllocEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, peerAddress, config, result); if (err == CHIP_NO_ERROR) { return MakeOptional(*result); @@ -239,9 +243,11 @@ class UnauthenticatedSessionTable return Optional::Missing(); } - CHECK_RETURN_VALUE Optional FindInitiator(NodeId ephemeralInitiatorNodeID) + CHECK_RETURN_VALUE Optional FindInitiator(NodeId ephemeralInitiatorNodeID, + const Transport::PeerAddress & peerAddress) { - UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID); + UnauthenticatedSession * result = + FindEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, peerAddress); if (result != nullptr) { return MakeOptional(*result); @@ -254,7 +260,8 @@ class UnauthenticatedSessionTable const ReliableMessageProtocolConfig & config) { UnauthenticatedSession * result = nullptr; - CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, config, result); + CHIP_ERROR err = + AllocEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, peerAddress, config, result); if (err == CHIP_NO_ERROR) { result->SetPeerAddress(peerAddress); @@ -276,9 +283,10 @@ class UnauthenticatedSessionTable */ CHECK_RETURN_VALUE CHIP_ERROR AllocEntry(UnauthenticatedSession::SessionRole sessionRole, NodeId ephemeralInitiatorNodeID, - const ReliableMessageProtocolConfig & config, UnauthenticatedSession *& entry) + const PeerAddress & peerAddress, const ReliableMessageProtocolConfig & config, + UnauthenticatedSession *& entry) { - auto entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, config, *this); + auto entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, peerAddress, config, *this); if (entryToUse != nullptr) { entry = entryToUse; @@ -294,7 +302,7 @@ class UnauthenticatedSessionTable // Drop the least recent entry to allow for a new alloc. mEntries.ReleaseObject(entryToUse); - entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, config, *this); + entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, peerAddress, config, *this); if (entryToUse == nullptr) { @@ -308,11 +316,13 @@ class UnauthenticatedSessionTable } CHECK_RETURN_VALUE UnauthenticatedSession * FindEntry(UnauthenticatedSession::SessionRole sessionRole, - NodeId ephemeralInitiatorNodeID) + NodeId ephemeralInitiatorNodeID, + const Transport::PeerAddress & peerAddress) { UnauthenticatedSession * result = nullptr; mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) { - if (entry->GetSessionRole() == sessionRole && entry->GetEphemeralInitiatorNodeID() == ephemeralInitiatorNodeID) + if (entry->GetSessionRole() == sessionRole && entry->GetEphemeralInitiatorNodeID() == ephemeralInitiatorNodeID && + entry->GetPeerAddress().GetTransportType() == peerAddress.GetTransportType()) { result = entry; return Loop::Break; diff --git a/src/transport/raw/ActiveTCPConnectionState.h b/src/transport/raw/ActiveTCPConnectionState.h new file mode 100644 index 00000000000000..2085fa5fce0cd9 --- /dev/null +++ b/src/transport/raw/ActiveTCPConnectionState.h @@ -0,0 +1,119 @@ +/* + * + * Copyright (c) 2023 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines the CHIP Active Connection object that maintains TCP connections. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace chip { +namespace Transport { + +/** + * The State of the TCP connection + */ +enum class TCPState +{ + kNotReady = 0, /**< State before initialization. */ + kInitialized = 1, /**< State after class is listening and ready. */ + kConnecting = 3, /**< Connection with peer has been initiated. */ + kConnected = 4, /**< Connected with peer and ready for Send/Receive. */ + kClosed = 5, /**< Connection is closed. */ +}; + +struct AppTCPConnectionCallbackCtxt; +/** + * State for each active TCP connection + */ +struct ActiveTCPConnectionState +{ + + void Init(Inet::TCPEndPoint * endPoint, const PeerAddress & peerAddr) + { + mEndPoint = endPoint; + mPeerAddr = peerAddr; + mReceived = nullptr; + } + + void Free() + { + mEndPoint->Free(); + mPeerAddr = PeerAddress::Uninitialized(); + mEndPoint = nullptr; + mReceived = nullptr; + } + + bool InUse() const { return mEndPoint != nullptr; } + + // Associated endpoint. + Inet::TCPEndPoint * mEndPoint; + + // Peer Node Address + PeerAddress mPeerAddr; + + // Buffers received but not yet consumed. + System::PacketBufferHandle mReceived; + + // Current state of the connection + TCPState mConnectionState; + + // A pointer to an application-specific state object. It should + // represent an object that is at a layer above the SessionManager. The + // SessionManager would accept this object at the time of connecting to + // the peer, and percolate it down to the TransportManager that then, + // should store this state in the corresponding connection object that + // is created. + // At various connection events, this state is passed back to the + // corresponding application. + AppTCPConnectionCallbackCtxt * mAppState = nullptr; + + // KeepAlive interval in seconds + uint16_t mTCPKeepAliveIntervalSecs = CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS; + uint16_t mTCPMaxNumKeepAliveProbes = CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES; +}; + +// Functors for callbacks into higher layers +using OnTCPConnectionReceivedCallback = void (*)(ActiveTCPConnectionState * conn); + +using OnTCPConnectionCompleteCallback = void (*)(ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + +using OnTCPConnectionClosedCallback = void (*)(ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + +/* + * Application callback state that is passed down at connection establishment + * stage. + * */ +struct AppTCPConnectionCallbackCtxt +{ + void * appContext = nullptr; // A pointer to an application context object. + OnTCPConnectionReceivedCallback connReceivedCb = nullptr; + OnTCPConnectionCompleteCallback connCompleteCb = nullptr; + OnTCPConnectionClosedCallback connClosedCb = nullptr; +}; + +} // namespace Transport +} // namespace chip diff --git a/src/transport/raw/BUILD.gn b/src/transport/raw/BUILD.gn index 736b16cdb08477..3e8d3d4761dca3 100644 --- a/src/transport/raw/BUILD.gn +++ b/src/transport/raw/BUILD.gn @@ -14,6 +14,7 @@ import("//build_overrides/chip.gni") import("${chip_root}/src/ble/ble.gni") +import("${chip_root}/src/inet/inet.gni") static_library("raw") { output_name = "libRawTransport" @@ -23,13 +24,20 @@ static_library("raw") { "MessageHeader.cpp", "MessageHeader.h", "PeerAddress.h", - "TCP.cpp", - "TCP.h", "Tuple.h", "UDP.cpp", "UDP.h", ] + if (chip_inet_config_enable_tcp_endpoint) { + sources += [ + "ActiveTCPConnectionState.h", + "TCP.cpp", + "TCP.h", + "TCPConfig.h", + ] + } + if (chip_config_network_layer_ble) { sources += [ "BLE.cpp", diff --git a/src/transport/raw/Base.h b/src/transport/raw/Base.h index 66c01a5fcc3d5d..920f932b82be41 100644 --- a/src/transport/raw/Base.h +++ b/src/transport/raw/Base.h @@ -24,20 +24,38 @@ #pragma once #include +#include #include #include #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { namespace Transport { +struct MessageTransportContext +{ +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + ActiveTCPConnectionState * conn = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT +}; + class RawTransportDelegate { public: virtual ~RawTransportDelegate() {} - virtual void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) = 0; + virtual void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + MessageTransportContext * ctxt = nullptr) = 0; + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + virtual void HandleConnectionReceived(ActiveTCPConnectionState * conn){}; + virtual void HandleConnectionAttemptComplete(ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; + virtual void HandleConnectionClosed(ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; /** @@ -77,10 +95,26 @@ class Base */ virtual bool CanListenMulticast() { return false; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + /** + * Connect to the specified peer. + */ + virtual CHIP_ERROR TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + return CHIP_NO_ERROR; + } + /** * Handle disconnection from the specified peer if currently connected to it. */ - virtual void Disconnect(const PeerAddress & address) {} + virtual void TCPDisconnect(const PeerAddress & address) {} + + /** + * Disconnect on the active connection that is passed in. + */ + virtual void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) {} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT /** * Enable Listening for multicast messages ( IPV6 UDP only) @@ -97,12 +131,31 @@ class Base * Method used by subclasses to notify that a packet has been received after * any associated headers have been decoded. */ - void HandleMessageReceived(const PeerAddress & source, System::PacketBufferHandle && buffer) + void HandleMessageReceived(const PeerAddress & source, System::PacketBufferHandle && buffer, + MessageTransportContext * ctxt = nullptr) + { + mDelegate->HandleMessageReceived(source, std::move(buffer), ctxt); + } + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Handle an incoming connection request from a peer. + void HandleConnectionReceived(ActiveTCPConnectionState * conn) { mDelegate->HandleConnectionReceived(conn); } + + // Callback during connection establishment to notify of success or any + // error. + void HandleConnectionAttemptComplete(ActiveTCPConnectionState * conn, CHIP_ERROR conErr) + { + mDelegate->HandleConnectionAttemptComplete(conn, conErr); + } + + // Callback to notify the higher layer of an unexpected connection closure. + void HandleConnectionClosed(ActiveTCPConnectionState * conn, CHIP_ERROR conErr) { - mDelegate->HandleMessageReceived(source, std::move(buffer)); + mDelegate->HandleConnectionClosed(conn, conErr); } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT - RawTransportDelegate * mDelegate; + RawTransportDelegate * mDelegate = nullptr; }; } // namespace Transport diff --git a/src/transport/raw/TCP.cpp b/src/transport/raw/TCP.cpp index a1b1df1fe45a2d..f3506762371d0c 100644 --- a/src/transport/raw/TCP.cpp +++ b/src/transport/raw/TCP.cpp @@ -67,8 +67,7 @@ void TCPBase::CloseActiveConnections() { if (mActiveConnections[i].InUse()) { - mActiveConnections[i].Free(); - mUsedEndPointCount--; + CloseConnectionInternal(&mActiveConnections[i], CHIP_NO_ERROR, SuppressCallback::Yes); } } } @@ -77,7 +76,7 @@ CHIP_ERROR TCPBase::Init(TcpListenParameters & params) { CHIP_ERROR err = CHIP_NO_ERROR; - VerifyOrExit(mState == State::kNotReady, err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrExit(mState == TCPState::kNotReady, err = CHIP_ERROR_INCORRECT_STATE); #if INET_CONFIG_ENABLE_TCP_ENDPOINT err = params.GetEndPointManager()->NewEndPoint(&mListenSocket); @@ -90,23 +89,21 @@ CHIP_ERROR TCPBase::Init(TcpListenParameters & params) params.GetInterfaceId().IsPresent()); SuccessOrExit(err); + mListenSocket->mAppState = reinterpret_cast(this); + mListenSocket->OnConnectionReceived = HandleIncomingConnection; + mListenSocket->OnAcceptError = HandleAcceptError; + + mEndpointType = params.GetAddressType(); + err = mListenSocket->Listen(kListenBacklogSize); SuccessOrExit(err); - mListenSocket->mAppState = reinterpret_cast(this); - mListenSocket->OnDataReceived = OnTcpReceive; - mListenSocket->OnConnectComplete = OnConnectionComplete; - mListenSocket->OnConnectionClosed = OnConnectionClosed; - mListenSocket->OnConnectionReceived = OnConnectionReceived; - mListenSocket->OnAcceptError = OnAcceptError; - mEndpointType = params.GetAddressType(); - - mState = State::kInitialized; + mState = TCPState::kInitialized; exit: if (err != CHIP_NO_ERROR) { - ChipLogError(Inet, "Failed to initialize TCP transport: %s", ErrorStr(err)); + ChipLogError(Inet, "Failed to initialize TCP transport: %" CHIP_ERROR_FORMAT, err.Format()); if (mListenSocket) { mListenSocket->Free(); @@ -124,10 +121,24 @@ void TCPBase::Close() mListenSocket->Free(); mListenSocket = nullptr; } - mState = State::kNotReady; + mState = TCPState::kNotReady; +} + +ActiveTCPConnectionState * TCPBase::AllocateConnection() +{ + for (size_t i = 0; i < mActiveConnectionsSize; i++) + { + if (!mActiveConnections[i].InUse()) + { + return &mActiveConnections[i]; + } + } + + return nullptr; } -TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const PeerAddress & address) +// Find an ActiveTCPConnectionState corresponding to a peer address +ActiveTCPConnectionState * TCPBase::FindActiveConnection(const PeerAddress & address) { if (address.GetTransportType() != Type::kTcp) { @@ -153,7 +164,8 @@ TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const PeerAddress return nullptr; } -TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const Inet::TCPEndPoint * endPoint) +// Find the ActiveTCPConnectionState for a given TCPEndPoint +ActiveTCPConnectionState * TCPBase::FindActiveConnection(const Inet::TCPEndPoint * endPoint) { for (size_t i = 0; i < mActiveConnectionsSize; i++) { @@ -172,7 +184,7 @@ CHIP_ERROR TCPBase::SendMessage(const Transport::PeerAddress & address, System:: // - actual data VerifyOrReturnError(address.GetTransportType() == Type::kTcp, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(mState == State::kInitialized, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mState == TCPState::kInitialized, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(kPacketSizeBytes + msgBuf->DataLength() <= std::numeric_limits::max(), CHIP_ERROR_INVALID_ARGUMENT); @@ -186,9 +198,9 @@ CHIP_ERROR TCPBase::SendMessage(const Transport::PeerAddress & address, System:: // Reuse existing connection if one exists, otherwise a new one // will be established - ActiveConnectionState * connection = FindActiveConnection(address); + ActiveTCPConnectionState * connection = FindActiveConnection(address); - if (connection != nullptr) + if (connection != nullptr && connection->mConnectionState == TCPState::kConnected) { return connection->mEndPoint->Send(std::move(msgBuf)); } @@ -196,8 +208,46 @@ CHIP_ERROR TCPBase::SendMessage(const Transport::PeerAddress & address, System:: return SendAfterConnect(address, std::move(msgBuf)); } +CHIP_ERROR TCPBase::StartConnect(const PeerAddress & addr, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) +{ +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + ActiveTCPConnectionState * activeConnection = nullptr; + Inet::TCPEndPoint * endPoint = nullptr; + *peerConnState = nullptr; + ReturnErrorOnFailure(mListenSocket->GetEndPointManager().NewEndPoint(&endPoint)); + + auto EndPointDeletor = [](Inet::TCPEndPoint * e) { e->Free(); }; + std::unique_ptr endPointHolder(endPoint, EndPointDeletor); + + endPoint->mAppState = reinterpret_cast(this); + endPoint->OnConnectComplete = HandleTCPEndPointConnectComplete; + endPoint->SetConnectTimeout(mConnectTimeout); + + activeConnection = AllocateConnection(); + VerifyOrReturnError(activeConnection != nullptr, CHIP_ERROR_NO_MEMORY); + activeConnection->mEndPoint = endPoint; + activeConnection->mAppState = appState; + activeConnection->mConnectionState = TCPState::kConnecting; + // Set the return value of the peer connection state to the allocated + // connection. + *peerConnState = activeConnection; + + ReturnErrorOnFailure(endPoint->Connect(addr.GetIPAddress(), addr.GetPort(), addr.GetInterface())); + + mUsedEndPointCount++; + + endPointHolder.release(); + + return CHIP_NO_ERROR; +#else + return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; +#endif +} + CHIP_ERROR TCPBase::SendAfterConnect(const PeerAddress & addr, System::PacketBufferHandle && msg) { +#if INET_CONFIG_ENABLE_TCP_ENDPOINT // This will initiate a connection to the specified peer bool alreadyConnecting = false; @@ -224,28 +274,13 @@ CHIP_ERROR TCPBase::SendAfterConnect(const PeerAddress & addr, System::PacketBuf // Ensures sufficient active connections size exist VerifyOrReturnError(mUsedEndPointCount < mActiveConnectionsSize, CHIP_ERROR_NO_MEMORY); -#if INET_CONFIG_ENABLE_TCP_ENDPOINT - Inet::TCPEndPoint * endPoint = nullptr; - ReturnErrorOnFailure(mListenSocket->GetEndPointManager().NewEndPoint(&endPoint)); - auto EndPointDeletor = [](Inet::TCPEndPoint * e) { e->Free(); }; - std::unique_ptr endPointHolder(endPoint, EndPointDeletor); - - endPoint->mAppState = reinterpret_cast(this); - endPoint->OnDataReceived = OnTcpReceive; - endPoint->OnConnectComplete = OnConnectionComplete; - endPoint->OnConnectionClosed = OnConnectionClosed; - endPoint->OnConnectionReceived = OnConnectionReceived; - endPoint->OnAcceptError = OnAcceptError; - endPoint->OnPeerClose = OnPeerClosed; - - ReturnErrorOnFailure(endPoint->Connect(addr.GetIPAddress(), addr.GetPort(), addr.GetInterface())); + Transport::ActiveTCPConnectionState * peerConnState = nullptr; + ReturnErrorOnFailure(StartConnect(addr, nullptr, &peerConnState)); // enqueue the packet once the connection succeeds VerifyOrReturnError(mPendingPackets.CreateObject(addr, std::move(msg)) != nullptr, CHIP_ERROR_NO_MEMORY); mUsedEndPointCount++; - endPointHolder.release(); - return CHIP_NO_ERROR; #else return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; @@ -255,7 +290,7 @@ CHIP_ERROR TCPBase::SendAfterConnect(const PeerAddress & addr, System::PacketBuf CHIP_ERROR TCPBase::ProcessReceivedBuffer(Inet::TCPEndPoint * endPoint, const PeerAddress & peerAddress, System::PacketBufferHandle && buffer) { - ActiveConnectionState * state = FindActiveConnection(endPoint); + ActiveTCPConnectionState * state = FindActiveConnection(endPoint); VerifyOrReturnError(state != nullptr, CHIP_ERROR_INTERNAL); state->mReceived.AddToEnd(std::move(buffer)); @@ -275,6 +310,7 @@ CHIP_ERROR TCPBase::ProcessReceivedBuffer(Inet::TCPEndPoint * endPoint, const Pe uint16_t messageSize = LittleEndian::Get16(messageSizeBuf); if (messageSize >= kMaxMessageSize) { + // This message is too long for upper layers. return CHIP_ERROR_MESSAGE_TOO_LONG; } @@ -291,12 +327,15 @@ CHIP_ERROR TCPBase::ProcessReceivedBuffer(Inet::TCPEndPoint * endPoint, const Pe return CHIP_NO_ERROR; } -CHIP_ERROR TCPBase::ProcessSingleMessage(const PeerAddress & peerAddress, ActiveConnectionState * state, uint16_t messageSize) +CHIP_ERROR TCPBase::ProcessSingleMessage(const PeerAddress & peerAddress, ActiveTCPConnectionState * state, uint16_t messageSize) { // We enter with `state->mReceived` containing at least one full message, perhaps in a chain. // `state->mReceived->Start()` currently points to the message data. // On exit, `state->mReceived` will have had `messageSize` bytes consumed, no matter what. System::PacketBufferHandle message; + MessageTransportContext msgContext; + msgContext.conn = state; + if (state->mReceived->DataLength() == messageSize) { // In this case, the head packet buffer contains exactly the message. @@ -321,23 +360,53 @@ CHIP_ERROR TCPBase::ProcessSingleMessage(const PeerAddress & peerAddress, Active message->SetDataLength(messageSize); } - HandleMessageReceived(peerAddress, std::move(message)); + HandleMessageReceived(peerAddress, std::move(message), &msgContext); return CHIP_NO_ERROR; } -void TCPBase::ReleaseActiveConnection(Inet::TCPEndPoint * endPoint) +void TCPBase::CloseConnectionInternal(ActiveTCPConnectionState * connection, CHIP_ERROR err, SuppressCallback suppressCallback) { - for (size_t i = 0; i < mActiveConnectionsSize; i++) + TCPState prevState; + + if (connection == nullptr) { - if (mActiveConnections[i].mEndPoint == endPoint) + return; + } + + if (connection->mConnectionState != TCPState::kClosed && connection->mEndPoint) + { + if (err == CHIP_NO_ERROR) + { + connection->mEndPoint->Close(); + } + else { - mActiveConnections[i].Free(); - mUsedEndPointCount--; + connection->mEndPoint->Abort(); } + + prevState = connection->mConnectionState; + connection->mConnectionState = TCPState::kClosed; + + if (suppressCallback == SuppressCallback::No) + { + if (prevState == TCPState::kConnecting) + { + // Call upper layer connection attempt complete handler + HandleConnectionAttemptComplete(connection, err); + } + else + { + // Call upper layer connection closed handler + HandleConnectionClosed(connection, err); + } + } + + connection->Free(); + mUsedEndPointCount--; } } -CHIP_ERROR TCPBase::OnTcpReceive(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer) +CHIP_ERROR TCPBase::HandleTCPEndPointDataReceived(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer) { Inet::IPAddress ipAddress; uint16_t port; @@ -353,13 +422,13 @@ CHIP_ERROR TCPBase::OnTcpReceive(Inet::TCPEndPoint * endPoint, System::PacketBuf if (err != CHIP_NO_ERROR) { // Connection could need to be closed at this point - ChipLogError(Inet, "Failed to accept received TCP message: %s", ErrorStr(err)); + ChipLogError(Inet, "Failed to accept received TCP message: %" CHIP_ERROR_FORMAT, err.Format()); return CHIP_ERROR_UNEXPECTED_EVENT; } return CHIP_NO_ERROR; } -void TCPBase::OnConnectionComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR inetErr) +void TCPBase::HandleTCPEndPointConnectComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR conErr) { CHIP_ERROR err = CHIP_NO_ERROR; bool foundPendingPacket = false; @@ -367,119 +436,166 @@ void TCPBase::OnConnectionComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR inet Inet::IPAddress ipAddress; uint16_t port; Inet::InterfaceId interfaceId; + ActiveTCPConnectionState * activeConnection = nullptr; endPoint->GetPeerInfo(&ipAddress, &port); endPoint->GetInterfaceId(&interfaceId); PeerAddress addr = PeerAddress::TCP(ipAddress, port, interfaceId); + char ipAddrStr[Inet::IPAddress::kMaxStringLength]; + ipAddress.ToString(ipAddrStr); - // Send any pending packets - tcp->mPendingPackets.ForEachActiveObject([&](PendingPacket * pending) { - if (pending->mPeerAddress == addr) + if (conErr == CHIP_NO_ERROR) + { + // Set the Data received handler when connection completes + endPoint->OnDataReceived = HandleTCPEndPointDataReceived; + endPoint->OnDataSent = nullptr; + endPoint->OnConnectionClosed = HandleTCPEndPointConnectionClosed; + + activeConnection = tcp->FindActiveConnection(endPoint); + VerifyOrDie(activeConnection != nullptr); + + activeConnection->Init(endPoint, addr); + activeConnection->mConnectionState = TCPState::kConnected; + + // Disable TCP Nagle buffering by setting TCP_NODELAY socket option to true + err = endPoint->EnableNoDelay(); + if (err != CHIP_NO_ERROR) { - foundPendingPacket = true; - System::PacketBufferHandle buffer = std::move(pending->mPacketBuffer); - tcp->mPendingPackets.ReleaseObject(pending); + tcp->CloseConnectionInternal(activeConnection, err, SuppressCallback::No); + return; + } - if ((inetErr == CHIP_NO_ERROR) && (err == CHIP_NO_ERROR)) + // Send any pending packets that are queued for this connection + tcp->mPendingPackets.ForEachActiveObject([&](PendingPacket * pending) { + if (pending->mPeerAddress == addr) { - err = endPoint->Send(std::move(buffer)); + foundPendingPacket = true; + System::PacketBufferHandle buffer = std::move(pending->mPacketBuffer); + tcp->mPendingPackets.ReleaseObject(pending); + + if ((conErr == CHIP_NO_ERROR) && (err == CHIP_NO_ERROR)) + { + err = endPoint->Send(std::move(buffer)); + } } - } - return Loop::Continue; - }); + return Loop::Continue; + }); - if (err == CHIP_NO_ERROR) - { - err = inetErr; - } + // Set the TCPKeepalive configurations on the established connection + endPoint->EnableKeepAlive(activeConnection->mTCPKeepAliveIntervalSecs, activeConnection->mTCPMaxNumKeepAliveProbes); - if (!foundPendingPacket && (err == CHIP_NO_ERROR)) - { - // Force a close: new connections are only expected when a - // new buffer is being sent. - ChipLogError(Inet, "Connection accepted without pending buffers"); - err = CHIP_ERROR_CONNECTION_CLOSED_UNEXPECTEDLY; - } + ChipLogProgress(Inet, "Connection established successfully with [%s:%d]", ipAddrStr, port); - // cleanup packets or mark as free - if (err != CHIP_NO_ERROR) - { - ChipLogError(Inet, "Connection complete encountered an error: %s", ErrorStr(err)); - endPoint->Free(); - tcp->mUsedEndPointCount--; + // Let higher layer/delegate know that connection is successfully + // established + tcp->HandleConnectionAttemptComplete(activeConnection, CHIP_NO_ERROR); } else { - bool connectionStored = false; - for (size_t i = 0; i < tcp->mActiveConnectionsSize; i++) - { - if (!tcp->mActiveConnections[i].InUse()) - { - tcp->mActiveConnections[i].Init(endPoint); - connectionStored = true; - break; - } - } - - // since we track end points counts, we always expect to store the - // connection. - if (!connectionStored) - { - endPoint->Free(); - ChipLogError(Inet, "Internal logic error: insufficient space to store active connection"); - } + ChipLogError(Inet, "Connection establishment with [%s:%d] encountered an error: %" CHIP_ERROR_FORMAT, ipAddrStr, port, + err.Format()); + endPoint->Free(); + tcp->mUsedEndPointCount--; } } -void TCPBase::OnConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) +void TCPBase::HandleTCPEndPointConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) { - TCPBase * tcp = reinterpret_cast(endPoint->mAppState); + TCPBase * tcp = reinterpret_cast(endPoint->mAppState); + ActiveTCPConnectionState * activeConnection = tcp->FindActiveConnection(endPoint); + Inet::IPAddress ipAddress; + uint16_t port; + + endPoint->GetPeerInfo(&ipAddress, &port); + + if (activeConnection == nullptr) + { + endPoint->Free(); + return; + } + + if (err == CHIP_NO_ERROR && activeConnection->mConnectionState == TCPState::kConnected) + { + err = CHIP_ERROR_CONNECTION_CLOSED_UNEXPECTEDLY; + } - ChipLogProgress(Inet, "Connection closed."); + tcp->CloseConnectionInternal(activeConnection, err, SuppressCallback::No); - ChipLogProgress(Inet, "Freeing closed connection."); - tcp->ReleaseActiveConnection(endPoint); + char ipAddrStr[Inet::IPAddress::kMaxStringLength]; + ipAddress.ToString(ipAddrStr); + ChipLogProgress(Inet, "Connection closed with [%s:%d]", ipAddrStr, port); } -void TCPBase::OnConnectionReceived(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, - const Inet::IPAddress & peerAddress, uint16_t peerPort) +// Handler for incoming connection requests from peer nodes +void TCPBase::HandleIncomingConnection(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, + const Inet::IPAddress & peerAddress, uint16_t peerPort) { - TCPBase * tcp = reinterpret_cast(listenEndPoint->mAppState); + TCPBase * tcp = reinterpret_cast(listenEndPoint->mAppState); + ActiveTCPConnectionState * activeConnection = nullptr; + Inet::InterfaceId interfaceId; + Inet::IPAddress ipAddress; + uint16_t port; + + endPoint->GetPeerInfo(&ipAddress, &port); + endPoint->GetInterfaceId(&interfaceId); + PeerAddress addr = PeerAddress::TCP(ipAddress, port, interfaceId); if (tcp->mUsedEndPointCount < tcp->mActiveConnectionsSize) { - // have space to use one more (even if considering pending connections) - for (size_t i = 0; i < tcp->mActiveConnectionsSize; i++) - { - if (!tcp->mActiveConnections[i].InUse()) - { - tcp->mActiveConnections[i].Init(endPoint); - tcp->mUsedEndPointCount++; - break; - } - } + activeConnection = tcp->AllocateConnection(); + + endPoint->mAppState = listenEndPoint->mAppState; + endPoint->OnDataReceived = HandleTCPEndPointDataReceived; + endPoint->OnDataSent = nullptr; + endPoint->OnConnectionClosed = HandleTCPEndPointConnectionClosed; - endPoint->mAppState = listenEndPoint->mAppState; - endPoint->OnDataReceived = OnTcpReceive; - endPoint->OnConnectComplete = OnConnectionComplete; - endPoint->OnConnectionClosed = OnConnectionClosed; - endPoint->OnConnectionReceived = OnConnectionReceived; - endPoint->OnAcceptError = OnAcceptError; - endPoint->OnPeerClose = OnPeerClosed; + // By default, disable TCP Nagle buffering by setting TCP_NODELAY socket option to true + endPoint->EnableNoDelay(); + + // Update state for the active connection + activeConnection->Init(endPoint, addr); + tcp->mUsedEndPointCount++; + activeConnection->mConnectionState = TCPState::kConnected; + + char ipAddrStr[Inet::IPAddress::kMaxStringLength]; + peerAddress.ToString(ipAddrStr); + ChipLogProgress(Inet, "Incoming connection established with peer at [%s:%d]", ipAddrStr, port); + + // Call the upper layer handler for incoming connection received. + tcp->HandleConnectionReceived(activeConnection); } else { - ChipLogError(Inet, "Insufficient connection space to accept new connections"); + ChipLogError(Inet, "Insufficient connection space to accept new connections."); endPoint->Free(); + listenEndPoint->OnAcceptError(endPoint, CHIP_ERROR_TOO_MANY_CONNECTIONS); } } -void TCPBase::OnAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) +void TCPBase::HandleAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) { - ChipLogError(Inet, "Accept error: %s", ErrorStr(err)); + endPoint->Free(); + ChipLogError(Inet, "Accept error: %" CHIP_ERROR_FORMAT, err.Format()); } -void TCPBase::Disconnect(const PeerAddress & address) +CHIP_ERROR TCPBase::TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) +{ + VerifyOrReturnError(mState == TCPState::kInitialized, CHIP_ERROR_INCORRECT_STATE); + + // Verify that PeerAddress AddressType is TCP + VerifyOrReturnError(address.GetTransportType() == Transport::Type::kTcp, CHIP_ERROR_INVALID_ARGUMENT); + + VerifyOrReturnError(mUsedEndPointCount < mActiveConnectionsSize, CHIP_ERROR_NO_MEMORY); + + ChipLogProgress(Inet, "Con start to peer"); + + ReturnErrorOnFailure(StartConnect(address, appState, peerConnState)); + + return CHIP_NO_ERROR; +} + +void TCPBase::TCPDisconnect(const PeerAddress & address) { // Closes an existing connection for (size_t i = 0; i < mActiveConnectionsSize; i++) @@ -497,20 +613,32 @@ void TCPBase::Disconnect(const PeerAddress & address) // NOTE: this leaves the socket in TIME_WAIT. // Calling Abort() would clean it since SO_LINGER would be set to 0, // however this seems not to be useful. - mActiveConnections[i].Free(); - mUsedEndPointCount--; + CloseConnectionInternal(&mActiveConnections[i], CHIP_NO_ERROR, SuppressCallback::Yes); } } } } -void TCPBase::OnPeerClosed(Inet::TCPEndPoint * endPoint) +void TCPBase::TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort) { - TCPBase * tcp = reinterpret_cast(endPoint->mAppState); - ChipLogProgress(Inet, "Freeing connection: connection closed by peer"); + if (conn == nullptr) + { + ChipLogError(Inet, "Failed to Disconnect. Passed in Connection is null."); + return; + } - tcp->ReleaseActiveConnection(endPoint); + if (conn->InUse()) + { + if (shouldAbort) + { + CloseConnectionInternal(conn, CHIP_ERROR_CONNECTION_ABORTED, SuppressCallback::Yes); + } + else + { + CloseConnectionInternal(conn, CHIP_NO_ERROR, SuppressCallback::Yes); + } + } } bool TCPBase::HasActiveConnections() const diff --git a/src/transport/raw/TCP.h b/src/transport/raw/TCP.h index d9f78be1771b0f..2df0b9e510dbe7 100644 --- a/src/transport/raw/TCP.h +++ b/src/transport/raw/TCP.h @@ -34,7 +34,9 @@ #include #include #include +#include #include +#include namespace chip { namespace Transport { @@ -96,45 +98,23 @@ struct PendingPacket /** Implements a transport using TCP. */ class DLL_EXPORT TCPBase : public Base { - /** - * The State of the TCP connection - */ - enum class State - { - kNotReady = 0, /**< State before initialization. */ - kInitialized = 1, /**< State after class is listening and ready. */ - }; protected: - /** - * State for each active connection - */ - struct ActiveConnectionState + enum class ShouldAbort : uint8_t { - void Init(Inet::TCPEndPoint * endPoint) - { - mEndPoint = endPoint; - mReceived = nullptr; - } - - void Free() - { - mEndPoint->Free(); - mEndPoint = nullptr; - mReceived = nullptr; - } - bool InUse() const { return mEndPoint != nullptr; } - - // Associated endpoint. - Inet::TCPEndPoint * mEndPoint; + Yes, + No + }; - // Buffers received but not yet consumed. - System::PacketBufferHandle mReceived; + enum class SuppressCallback : uint8_t + { + Yes, + No }; public: using PendingPacketPoolType = PoolInterface; - TCPBase(ActiveConnectionState * activeConnectionsBuffer, size_t bufferSize, PendingPacketPoolType & packetBuffers) : + TCPBase(ActiveTCPConnectionState * activeConnectionsBuffer, size_t bufferSize, PendingPacketPoolType & packetBuffers) : mActiveConnections(activeConnectionsBuffer), mActiveConnectionsSize(bufferSize), mPendingPackets(packetBuffers) { // activeConnectionsBuffer must be initialized by the caller. @@ -153,6 +133,8 @@ class DLL_EXPORT TCPBase : public Base */ CHIP_ERROR Init(TcpListenParameters & params); + void SetConnectTimeout(const uint32_t connTimeoutMsecs) { mConnectTimeout = connTimeoutMsecs; } + /** * Close the open endpoint without destroying the object */ @@ -160,14 +142,29 @@ class DLL_EXPORT TCPBase : public Base CHIP_ERROR SendMessage(const PeerAddress & address, System::PacketBufferHandle && msgBuf) override; - void Disconnect(const PeerAddress & address) override; + CHIP_ERROR TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) override; + + void TCPDisconnect(const PeerAddress & address) override; + + // Close an active connection (corresponding to the passed + // ActiveTCPConnectionState object) + // and release from the pool. + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) override; bool CanSendToPeer(const PeerAddress & address) override { - return (mState == State::kInitialized) && (address.GetTransportType() == Type::kTcp) && + return (mState == TCPState::kInitialized) && (address.GetTransportType() == Type::kTcp) && (address.GetIPAddress().Type() == mEndpointType); } + const Optional GetConnectionPeerAddress(const Inet::TCPEndPoint * con) + { + ActiveTCPConnectionState * activeConState = FindActiveConnection(con); + + return activeConState != nullptr ? MakeOptional(activeConState->mPeerAddr) : Optional::Missing(); + } + /** * Helper method to determine if IO processing is still required for a TCP transport * before everything is cleaned up (socket closing is async, so after calling 'Close' on @@ -183,12 +180,17 @@ class DLL_EXPORT TCPBase : public Base private: friend class TCPTest; + /** + * Allocate an unused connection from the pool + * + */ + ActiveTCPConnectionState * AllocateConnection(); /** * Find an active connection to the given peer or return nullptr if * no active connection exists. */ - ActiveConnectionState * FindActiveConnection(const PeerAddress & addr); - ActiveConnectionState * FindActiveConnection(const Inet::TCPEndPoint * endPoint); + ActiveTCPConnectionState * FindActiveConnection(const PeerAddress & addr); + ActiveTCPConnectionState * FindActiveConnection(const Inet::TCPEndPoint * endPoint); /** * Sends the specified message once a connection has been established. @@ -223,46 +225,63 @@ class DLL_EXPORT TCPBase : public Base * is no other data). * @param[in] messageSize Size of the single message. */ - CHIP_ERROR ProcessSingleMessage(const PeerAddress & peerAddress, ActiveConnectionState * state, uint16_t messageSize); + CHIP_ERROR ProcessSingleMessage(const PeerAddress & peerAddress, ActiveTCPConnectionState * state, uint16_t messageSize); + + /** + * Initiate a connection to the given peer. On connection completion, + * HandleTCPConnectComplete callback would be called. + * + */ + CHIP_ERROR StartConnect(const PeerAddress & addr, AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState); + + /** + * Gracefully Close or Abort a given connection. + * + */ + void CloseConnectionInternal(ActiveTCPConnectionState * connection, CHIP_ERROR err, SuppressCallback suppressCallback); - // Release an active connection (corresponding to the passed TCPEndPoint) - // from the pool. - void ReleaseActiveConnection(Inet::TCPEndPoint * endPoint); + // Close the listening socket endpoint + void CloseListeningSocket(); // Callback handler for TCPEndPoint. TCP message receive handler. // @see TCPEndpoint::OnDataReceivedFunct - static CHIP_ERROR OnTcpReceive(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer); + static CHIP_ERROR HandleTCPEndPointDataReceived(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer); // Callback handler for TCPEndPoint. Called when a connection has been completed. // @see TCPEndpoint::OnConnectCompleteFunct - static void OnConnectionComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); + static void HandleTCPEndPointConnectComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); // Callback handler for TCPEndPoint. Called when a connection has been closed. // @see TCPEndpoint::OnConnectionClosedFunct - static void OnConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); - - // Callback handler for TCPEndPoint. Callend when a peer closes the connection. - // @see TCPEndpoint::OnPeerCloseFunct - static void OnPeerClosed(Inet::TCPEndPoint * endPoint); + static void HandleTCPEndPointConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); // Callback handler for TCPEndPoint. Called when a connection is received on the listening port. // @see TCPEndpoint::OnConnectionReceivedFunct - static void OnConnectionReceived(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, - const Inet::IPAddress & peerAddress, uint16_t peerPort); + static void HandleIncomingConnection(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, + const Inet::IPAddress & peerAddress, uint16_t peerPort); - // Called on accept error + // Callback handler for handling accept error // @see TCPEndpoint::OnAcceptErrorFunct - static void OnAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); + static void HandleAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); Inet::TCPEndPoint * mListenSocket = nullptr; ///< TCP socket used by the transport Inet::IPAddressType mEndpointType = Inet::IPAddressType::kUnknown; ///< Socket listening type - State mState = State::kNotReady; ///< State of the TCP transport + TCPState mState = TCPState::kNotReady; ///< State of the TCP transport + + // The configured timeout for the connection attempt to the peer, before + // giving up. + uint32_t mConnectTimeout = CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS; + + // The max payload size of data over a TCP connection that is transmissible + // at a time. + uint32_t mMaxTCPPayloadSize = CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES; // Number of active and 'pending connection' endpoints size_t mUsedEndPointCount = 0; // Currently active connections - ActiveConnectionState * mActiveConnections; + ActiveTCPConnectionState * mActiveConnections; const size_t mActiveConnectionsSize; // Data to be sent when connections succeed @@ -277,14 +296,15 @@ class TCP : public TCPBase { for (size_t i = 0; i < kActiveConnectionsSize; ++i) { - mConnectionsBuffer[i].Init(nullptr); + mConnectionsBuffer[i].Init(nullptr, PeerAddress::Uninitialized()); } } + ~TCP() override { mPendingPackets.ReleaseAll(); } private: friend class TCPTest; - TCPBase::ActiveConnectionState mConnectionsBuffer[kActiveConnectionsSize]; + ActiveTCPConnectionState mConnectionsBuffer[kActiveConnectionsSize]; PoolImpl mPendingPackets; }; diff --git a/src/transport/raw/TCPConfig.h b/src/transport/raw/TCPConfig.h new file mode 100644 index 00000000000000..d54a9466b4d294 --- /dev/null +++ b/src/transport/raw/TCPConfig.h @@ -0,0 +1,127 @@ +/* + * + * Copyright (c) 2023 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines default compile-time configuration constants + * for CHIP. + * + * Package integrators that wish to override these values should + * either use preprocessor definitions or create a project- + * specific chipProjectConfig.h header and then assert + * HAVE_CHIPPROJECTCONFIG_H via the package configuration tool + * via --with-chip-project-includes=DIR where DIR is the + * directory that contains the header. + * + * + */ + +#pragma once + +#include + +namespace chip { + +/** + * @def CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS + * + * @brief Maximum Number of TCP connections a device can simultaneously have + */ +#ifndef CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS +#define CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS 4 +#endif + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT && CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS < 1 +#error "If TCP is enabled, the device needs to support at least 1 TCP connection" +#endif + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT && CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS > INET_CONFIG_NUM_TCP_ENDPOINTS +#error "If TCP is enabled, the maximum number of connections cannot exceed the number of tcp endpoints" +#endif + +/** + * @def CHIP_CONFIG_MAX_TCP_PENDING_PACKETS + * + * @brief Maximum Number of outstanding pending packets in the queue before a TCP connection + * needs to be established + */ +#ifndef CHIP_CONFIG_MAX_TCP_PENDING_PACKETS +#define CHIP_CONFIG_MAX_TCP_PENDING_PACKETS 4 +#endif + +/** + * @def CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES + * + * @brief Maximum payload size of a message over a TCP connection + */ +#ifndef CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES +#define CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES 1000000 +#endif + +/** + * @def CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS + * + * @brief + * This defines the default timeout for the TCP connect + * attempt to either succeed or notify the caller of an + * error. + * + */ +#ifndef CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS +#define CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS (10000) +#endif // CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS + +/** + * @def CHIP_CONFIG_KEEPALIVE_INTERVAL_SECS + * + * @brief + * This defines the default interval (in seconds) between + * keepalive probes for a TCP connection. + * This value also controls the time between last data + * packet sent and the transmission of the first keepalive + * probe. + * + */ +#ifndef CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS +#define CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS (25) +#endif // CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS + +/** + * @def CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES + * + * @brief + * This defines the default value for the maximum number of + * keepalive probes for a TCP connection. + * + */ +#ifndef CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES +#define CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES (5) +#endif // CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES + +/** + * @def CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS + * + * @brief + * This defines the default value for the maximum timeout + * of unacknowledged data for a TCP connection. + * + */ +#ifndef CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS +#define CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS (30) +#endif // CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS + +} // namespace chip diff --git a/src/transport/raw/Tuple.h b/src/transport/raw/Tuple.h index 743b9b9b0e09aa..e3e52a171abcda 100644 --- a/src/transport/raw/Tuple.h +++ b/src/transport/raw/Tuple.h @@ -93,7 +93,20 @@ class Tuple : public Base bool CanSendToPeer(const PeerAddress & address) override { return CanSendToPeerImpl<0>(address); } - void Disconnect(const PeerAddress & address) override { return DisconnectImpl<0>(address); } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) override + { + return TCPConnectImpl<0>(address, appState, peerConnState); + } + + void TCPDisconnect(const PeerAddress & address) override { return TCPDisconnectImpl<0>(address); } + + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) override + { + return TCPDisconnectImpl<0>(conn, shouldAbort); + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT void Close() override { return CloseImpl<0>(); } @@ -138,26 +151,78 @@ class Tuple : public Base return false; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + /** + * Recursive TCPConnect implementation iterating through transport members. + * + * @tparam N the index of the underlying transport to send disconnect to + * + * @param address what address to connect to. + */ + template ::type * = nullptr> + CHIP_ERROR TCPConnectImpl(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + Base * base = &std::get(mTransports); + if (base->CanSendToPeer(address)) + { + return base->TCPConnect(address, appState, peerConnState); + } + return TCPConnectImpl(address, appState, peerConnState); + } + + /** + * TCPConnectImpl template for out of range N. + */ + template = sizeof...(TransportTypes))>::type * = nullptr> + CHIP_ERROR TCPConnectImpl(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + return CHIP_ERROR_NO_MESSAGE_HANDLER; + } + /** * Recursive disconnect implementation iterating through transport members. * * @tparam N the index of the underlying transport to send disconnect to * - * @param address what address to check. + * @param address what address to disconnect from. + */ + template ::type * = nullptr> + void TCPDisconnectImpl(const PeerAddress & address) + { + std::get(mTransports).TCPDisconnect(address); + TCPDisconnectImpl(address); + } + + /** + * TCPDisconnectImpl template for out of range N. + */ + template = sizeof...(TransportTypes))>::type * = nullptr> + void TCPDisconnectImpl(const PeerAddress & address) + {} + + /** + * Recursive disconnect implementation iterating through transport members. + * + * @tparam N the index of the underlying transport to send disconnect to + * + * @param conn pointer to the connection to the peer. */ template ::type * = nullptr> - void DisconnectImpl(const PeerAddress & address) + void TCPDisconnectImpl(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) { - std::get(mTransports).Disconnect(address); - DisconnectImpl(address); + std::get(mTransports).TCPDisconnect(conn, shouldAbort); + TCPDisconnectImpl(conn, shouldAbort); } /** - * DisconnectImpl template for out of range N. + * TCPDisconnectImpl template for out of range N. */ template = sizeof...(TransportTypes))>::type * = nullptr> - void DisconnectImpl(const PeerAddress & address) + void TCPDisconnectImpl(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) {} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT /** * Recursive disconnect implementation iterating through transport members. diff --git a/src/transport/raw/tests/TestTCP.cpp b/src/transport/raw/tests/TestTCP.cpp index 620e97ca4325b3..93bbdcab3ecdcc 100644 --- a/src/transport/raw/tests/TestTCP.cpp +++ b/src/transport/raw/tests/TestTCP.cpp @@ -23,6 +23,7 @@ #include "NetworkTestHelpers.h" +#include #include #include #include @@ -32,7 +33,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include @@ -62,6 +65,9 @@ namespace { constexpr size_t kMaxTcpActiveConnectionCount = 4; constexpr size_t kMaxTcpPendingPackets = 4; constexpr uint16_t kPacketSizeBytes = static_cast(sizeof(uint16_t)); +uint16_t gChipTCPPort; +chip::Transport::AppTCPConnectionCallbackCtxt gAppTCPConnCbCtxt; +chip::Transport::ActiveTCPConnectionState * gActiveTCPConnState = nullptr; using TCPImpl = Transport::TCP; @@ -88,7 +94,8 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate mCallback = callback; mCallbackData = callback_data; } - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * transCtxt = nullptr) override { PacketHeader packetHeader; @@ -101,12 +108,56 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate NL_TEST_ASSERT(mSuite, err == 0); } + ChipLogProgress(Inet, "Message Receive Handler called"); + mReceiveHandlerCallCount++; } + void HandleConnectionAttemptComplete(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override + { + chip::Transport::AppTCPConnectionCallbackCtxt * appConnCbCtxt = nullptr; + VerifyOrReturn(conn != nullptr); + + appConnCbCtxt = conn->mAppState; + VerifyOrReturn(appConnCbCtxt != nullptr); + + mHandleConnectionCompleteCalled = true; + + if (appConnCbCtxt->connCompleteCb != nullptr) + { + appConnCbCtxt->connCompleteCb(conn, conErr); + } + else + { + ChipLogProgress(Inet, "Connection established. App callback missing."); + } + } + + void HandleConnectionClosed(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override + { + chip::Transport::AppTCPConnectionCallbackCtxt * appConnCbCtxt = nullptr; + VerifyOrReturn(conn != nullptr); + + mHandleConnectionCloseCalled = true; + + appConnCbCtxt = conn->mAppState; + VerifyOrReturn(appConnCbCtxt != nullptr); + + if (appConnCbCtxt->connClosedCb != nullptr) + { + appConnCbCtxt->connClosedCb(conn, conErr); + } + else + { + ChipLogProgress(Inet, "Connection Closed. App callback missing."); + } + } + void InitializeMessageTest(TCPImpl & tcp, const IPAddress & addr) { - CHIP_ERROR err = tcp.Init(Transport::TcpListenParameters(mContext.GetTCPEndPointManager()).SetAddressType(addr.Type())); + CHIP_ERROR err = tcp.Init(Transport::TcpListenParameters(mContext.GetTCPEndPointManager()) + .SetAddressType(addr.Type()) + .SetListenPort(gChipTCPPort)); // retry a few times in case the port is somehow in use. // this is a WORKAROUND for flaky testing if we run tests very fast after each other. @@ -125,7 +176,9 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate { ChipLogProgress(NotSpecified, "RETRYING tcp initialization"); chip::test_utils::SleepMillis(100); - err = tcp.Init(Transport::TcpListenParameters(mContext.GetTCPEndPointManager()).SetAddressType(addr.Type())); + err = tcp.Init(Transport::TcpListenParameters(mContext.GetTCPEndPointManager()) + .SetAddressType(addr.Type()) + .SetListenPort(gChipTCPPort)); } NL_TEST_ASSERT(mSuite, err == CHIP_NO_ERROR); @@ -133,7 +186,14 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate mTransportMgrBase.SetSessionManager(this); mTransportMgrBase.Init(&tcp); - mReceiveHandlerCallCount = 0; + mReceiveHandlerCallCount = 0; + mHandleConnectionCompleteCalled = false; + mHandleConnectionCloseCalled = false; + + gAppTCPConnCbCtxt.appContext = nullptr; + gAppTCPConnCbCtxt.connReceivedCb = nullptr; + gAppTCPConnCbCtxt.connCompleteCb = nullptr; + gAppTCPConnCbCtxt.connClosedCb = nullptr; } void SingleMessageTest(TCPImpl & tcp, const IPAddress & addr) @@ -151,7 +211,7 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate NL_TEST_ASSERT(mSuite, err == CHIP_NO_ERROR); // Should be able to send a message to itself by just calling send. - err = tcp.SendMessage(Transport::PeerAddress::TCP(addr), std::move(buffer)); + err = tcp.SendMessage(Transport::PeerAddress::TCP(addr, gChipTCPPort), std::move(buffer)); NL_TEST_ASSERT(mSuite, err == CHIP_NO_ERROR); mContext.DriveIOUntil(chip::System::Clock::Seconds16(5), [this]() { return mReceiveHandlerCallCount != 0; }); @@ -160,21 +220,80 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate SetCallback(nullptr); } - void FinalizeMessageTest(TCPImpl & tcp, const IPAddress & addr) + void ConnectTest(TCPImpl & tcp, const IPAddress & addr) + { + // Connect and wait for seeing active connection + tcp.TCPConnect(Transport::PeerAddress::TCP(addr, gChipTCPPort), &gAppTCPConnCbCtxt, &gActiveTCPConnState); + mContext.DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return tcp.HasActiveConnections(); }); + NL_TEST_ASSERT(mSuite, tcp.HasActiveConnections() == true); + } + + void HandleConnectCompleteCbCalledTest(TCPImpl & tcp, const IPAddress & addr) + { + // Connect and wait for seeing active connection and connection complete + // handler being called. + tcp.TCPConnect(Transport::PeerAddress::TCP(addr, gChipTCPPort), &gAppTCPConnCbCtxt, &gActiveTCPConnState); + mContext.DriveIOUntil(chip::System::Clock::Seconds16(5), [this]() { return mHandleConnectionCompleteCalled; }); + NL_TEST_ASSERT(mSuite, mHandleConnectionCompleteCalled == true); + } + + void HandleConnectCloseCbCalledTest(TCPImpl & tcp, const IPAddress & addr) + { + // Connect and wait for seeing active connection and connection complete + // handler being called. + tcp.TCPConnect(Transport::PeerAddress::TCP(addr, gChipTCPPort), &gAppTCPConnCbCtxt, &gActiveTCPConnState); + mContext.DriveIOUntil(chip::System::Clock::Seconds16(5), [this]() { return mHandleConnectionCompleteCalled; }); + NL_TEST_ASSERT(mSuite, mHandleConnectionCompleteCalled == true); + tcp.TCPDisconnect(Transport::PeerAddress::TCP(addr, gChipTCPPort)); + mContext.DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return !tcp.HasActiveConnections(); }); + NL_TEST_ASSERT(mSuite, mHandleConnectionCloseCalled == true); + } + + void DisconnectTest(TCPImpl & tcp, const IPAddress & addr) { // Disconnect and wait for seeing peer close - tcp.Disconnect(Transport::PeerAddress::TCP(addr)); + tcp.TCPDisconnect(Transport::PeerAddress::TCP(addr, gChipTCPPort)); mContext.DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return !tcp.HasActiveConnections(); }); + NL_TEST_ASSERT(mSuite, tcp.HasActiveConnections() == false); + } + + CHIP_ERROR TCPConnect(const Transport::PeerAddress & peerAddress, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + return mTransportMgrBase.TCPConnect(peerAddress, appState, peerConnState); + } + + using OnTCPConnectionReceivedCallback = void (*)(void * context, chip::Transport::ActiveTCPConnectionState * conn); + + using OnTCPConnectionCompleteCallback = void (*)(void * context, chip::Transport::ActiveTCPConnectionState * conn, + CHIP_ERROR conErr); + + using OnTCPConnectionClosedCallback = void (*)(void * context, chip::Transport::ActiveTCPConnectionState * conn, + CHIP_ERROR conErr); + + void SetConnectionCallbacks(OnTCPConnectionCompleteCallback connCompleteCb, OnTCPConnectionClosedCallback connClosedCb, + OnTCPConnectionReceivedCallback connReceivedCb) + { + mConnCompleteCb = connCompleteCb; + mConnClosedCb = connClosedCb; + mConnReceivedCb = connReceivedCb; } int mReceiveHandlerCallCount = 0; + bool mHandleConnectionCompleteCalled = false; + + bool mHandleConnectionCloseCalled = false; + private: nlTestSuite * mSuite; TestContext & mContext; MessageReceivedCallback mCallback; void * mCallbackData; TransportMgrBase mTransportMgrBase; + OnTCPConnectionCompleteCallback mConnCompleteCb = nullptr; + OnTCPConnectionClosedCallback mConnClosedCb = nullptr; + OnTCPConnectionReceivedCallback mConnReceivedCb = nullptr; }; /////////////////////////// Init test @@ -185,7 +304,8 @@ void CheckSimpleInitTest(nlTestSuite * inSuite, void * inContext, Inet::IPAddres TCPImpl tcp; - CHIP_ERROR err = tcp.Init(Transport::TcpListenParameters(ctx.GetTCPEndPointManager()).SetAddressType(type)); + CHIP_ERROR err = + tcp.Init(Transport::TcpListenParameters(ctx.GetTCPEndPointManager()).SetAddressType(type).SetListenPort(gChipTCPPort)); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); } @@ -212,7 +332,52 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext, const IPAddress & MockTransportMgrDelegate gMockTransportMgrDelegate(inSuite, ctx); gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); gMockTransportMgrDelegate.SingleMessageTest(tcp, addr); - gMockTransportMgrDelegate.FinalizeMessageTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); +} + +void ConnectToSelfTest(nlTestSuite * inSuite, void * inContext, const IPAddress & addr) +{ + TestContext & ctx = *reinterpret_cast(inContext); + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(inSuite, ctx); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.ConnectTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); +} + +void ConnectSendMessageThenCloseTest(nlTestSuite * inSuite, void * inContext, const IPAddress & addr) +{ + TestContext & ctx = *reinterpret_cast(inContext); + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(inSuite, ctx); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.ConnectTest(tcp, addr); + gMockTransportMgrDelegate.SingleMessageTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); +} + +void HandleConnCompleteTest(nlTestSuite * inSuite, void * inContext, const IPAddress & addr) +{ + TestContext & ctx = *reinterpret_cast(inContext); + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(inSuite, ctx); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.HandleConnectCompleteCbCalledTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); +} + +void HandleConnCloseTest(nlTestSuite * inSuite, void * inContext, const IPAddress & addr) +{ + TestContext & ctx = *reinterpret_cast(inContext); + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(inSuite, ctx); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.HandleConnectCloseCbCalledTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); } #if INET_CONFIG_ENABLE_IPV4 @@ -231,6 +396,57 @@ void CheckMessageTest6(nlTestSuite * inSuite, void * inContext) CheckMessageTest(inSuite, inContext, addr); } +#if INET_CONFIG_ENABLE_IPV4 +void ConnectToSelfTest4(nlTestSuite * inSuite, void * inContext) +{ + IPAddress addr; + IPAddress::FromString("127.0.0.1", addr); + ConnectToSelfTest(inSuite, inContext, addr); +} + +void ConnectSendMessageThenCloseTest4(nlTestSuite * inSuite, void * inContext) +{ + IPAddress addr; + IPAddress::FromString("127.0.0.1", addr); + ConnectSendMessageThenCloseTest(inSuite, inContext, addr); +} + +void HandleConnCompleteCalledTest4(nlTestSuite * inSuite, void * inContext) +{ + IPAddress addr; + IPAddress::FromString("127.0.0.1", addr); + HandleConnCompleteTest(inSuite, inContext, addr); +} +#endif // INET_CONFIG_ENABLE_IPV4 + +void ConnectToSelfTest6(nlTestSuite * inSuite, void * inContext) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + ConnectToSelfTest(inSuite, inContext, addr); +} + +void ConnectSendMessageThenCloseTest6(nlTestSuite * inSuite, void * inContext) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + ConnectSendMessageThenCloseTest(inSuite, inContext, addr); +} + +void HandleConnCompleteCalledTest6(nlTestSuite * inSuite, void * inContext) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + HandleConnCompleteTest(inSuite, inContext, addr); +} + +void HandleConnCloseCalledTest6(nlTestSuite * inSuite, void * inContext) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + HandleConnCloseTest(inSuite, inContext, addr); +} + // Generates a packet buffer or a chain of packet buffers for a single message. struct TestData { @@ -397,8 +613,8 @@ void chip::Transport::TCPTest::CheckProcessReceivedBuffer(nlTestSuite * inSuite, // (The current TCPEndPoint implementation is not effectively mockable.) gMockTransportMgrDelegate.SingleMessageTest(tcp, addr); - Transport::PeerAddress lPeerAddress = Transport::PeerAddress::TCP(addr); - TCPBase::ActiveConnectionState * state = tcp.FindActiveConnection(lPeerAddress); + Transport::PeerAddress lPeerAddress = Transport::PeerAddress::TCP(addr, gChipTCPPort); + ActiveTCPConnectionState * state = tcp.FindActiveConnection(lPeerAddress); NL_TEST_ASSERT(inSuite, state != nullptr); Inet::TCPEndPoint * lEndPoint = state->mEndPoint; NL_TEST_ASSERT(inSuite, lEndPoint != nullptr); @@ -449,7 +665,7 @@ void chip::Transport::TCPTest::CheckProcessReceivedBuffer(nlTestSuite * inSuite, NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_MESSAGE_TOO_LONG); NL_TEST_ASSERT(inSuite, gMockTransportMgrDelegate.mReceiveHandlerCallCount == 0); - gMockTransportMgrDelegate.FinalizeMessageTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); } // Test Suite @@ -460,12 +676,19 @@ void chip::Transport::TCPTest::CheckProcessReceivedBuffer(nlTestSuite * inSuite, static const nlTest sTests[] = { #if INET_CONFIG_ENABLE_IPV4 - NL_TEST_DEF("Simple Init Test IPV4", CheckSimpleInitTest4), - NL_TEST_DEF("Message Self Test IPV4", CheckMessageTest4), + NL_TEST_DEF("Simple Init Test IPV4", CheckSimpleInitTest4), + NL_TEST_DEF("Message Self Test IPV4", CheckMessageTest4), + NL_TEST_DEF("Connect To Self IPV4", ConnectToSelfTest4), + NL_TEST_DEF("Connect, Send Message Then Close IPV4", ConnectSendMessageThenCloseTest4), + NL_TEST_DEF("HandleConnectionAttemptComplete called IPV4", HandleConnCompleteCalledTest4), #endif - NL_TEST_DEF("Simple Init Test IPV6", CheckSimpleInitTest6), - NL_TEST_DEF("Message Self Test IPV6", CheckMessageTest6), + NL_TEST_DEF("Simple Init Test IPV6", CheckSimpleInitTest6), + NL_TEST_DEF("Message Self Test IPV6", CheckMessageTest6), + NL_TEST_DEF("Connect To Self IPV6", ConnectToSelfTest6), + NL_TEST_DEF("Connect, Send Message Then Close IPV6", ConnectSendMessageThenCloseTest6), + NL_TEST_DEF("HandleConnectionAttemptComplete called IPV6", HandleConnCompleteCalledTest6), + NL_TEST_DEF("HandleConnectionClose called IPV6", HandleConnCloseCalledTest6), NL_TEST_DEF("ProcessReceivedBuffer Test", chip::Transport::TCPTest::CheckProcessReceivedBuffer), NL_TEST_SENTINEL() @@ -487,6 +710,7 @@ static nlTestSuite sSuite = */ static int Initialize(void * aContext) { + gChipTCPPort = static_cast(CHIP_PORT + chip::Crypto::GetRandU16() % 100); CHIP_ERROR err = reinterpret_cast(aContext)->Init(); return (err == CHIP_NO_ERROR) ? SUCCESS : FAILURE; } diff --git a/src/transport/raw/tests/TestUDP.cpp b/src/transport/raw/tests/TestUDP.cpp index b26d23277d5844..ed416a57a6116f 100644 --- a/src/transport/raw/tests/TestUDP.cpp +++ b/src/transport/raw/tests/TestUDP.cpp @@ -57,7 +57,8 @@ class MockTransportMgrDelegate : public TransportMgrDelegate MockTransportMgrDelegate(nlTestSuite * inSuite) : mSuite(inSuite) {} ~MockTransportMgrDelegate() override {} - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * transCtxt = nullptr) override { PacketHeader packetHeader;