From 92493554ea677eb26e9ddfada96f7b54f913be86 Mon Sep 17 00:00:00 2001 From: Zang MingJie Date: Sat, 8 May 2021 17:50:05 +0800 Subject: [PATCH] Remove ExchangeMessageDispatch.h (Fix #6456) Major changes: * ExchangeManager and ExchangeContext will handle both secure and unsecure messages * ExchangeContext will be explicitly flagged as secure or unsecure * Secure context will only handle secure messages * Unsecure context will only handle Unsecure messages * Add a simple Variant (src/lib/support/Variant.h) * Secure and unsecure context have diffent member fields. Use Variant to manage these fields. * Use RAII for ExchangeContext contruction and destruction. * Use contructor and destructor of ExchangeContext instread of Alloc/Free. * Remove ExchangeMessageDispatch base class and all its derived classes * Remove ApplicationExchangeDispatch * Remove SessionEstablishmentExchangeDispatch Future works: * Enforce secure or unsecure exchange creation by solicited message handler. * Enable CRMP for unsecure exchanges. --- examples/shell/shell_common/cmd_send.cpp | 2 +- src/app/CommandSender.cpp | 2 +- src/app/ReadClient.cpp | 2 +- src/app/server/RendezvousServer.cpp | 3 - src/app/tests/TestCommandInteraction.cpp | 4 +- src/app/tests/TestReportingEngine.cpp | 2 +- src/app/util/chip-message-send.cpp | 2 +- src/channel/ChannelContext.cpp | 9 +- src/controller/CHIPDevice.cpp | 2 +- src/controller/CHIPDeviceController.cpp | 6 +- src/messaging/ApplicationExchangeDispatch.cpp | 74 ------ src/messaging/ApplicationExchangeDispatch.h | 64 ----- src/messaging/BUILD.gn | 4 - src/messaging/ExchangeContext.cpp | 245 ++++++++++++------ src/messaging/ExchangeContext.h | 126 ++++++--- src/messaging/ExchangeDelegate.h | 23 +- src/messaging/ExchangeMessageDispatch.cpp | 134 ---------- src/messaging/ExchangeMessageDispatch.h | 73 ------ src/messaging/ExchangeMgr.cpp | 230 ++++++++++------ src/messaging/ExchangeMgr.h | 62 ++--- src/messaging/ReliableMessageContext.h | 4 + src/messaging/ReliableMessageMgr.cpp | 9 +- src/messaging/ReliableMessageMgr.h | 13 +- src/messaging/tests/MessagingContext.cpp | 23 +- src/messaging/tests/MessagingContext.h | 7 +- src/messaging/tests/TestExchangeMgr.cpp | 6 +- .../tests/TestReliableMessageProtocol.cpp | 8 +- src/protocols/echo/EchoClient.cpp | 2 +- src/protocols/secure_channel/BUILD.gn | 2 - src/protocols/secure_channel/CASESession.cpp | 2 - src/protocols/secure_channel/CASESession.h | 11 +- .../secure_channel/MessageCounterManager.cpp | 2 +- src/protocols/secure_channel/PASESession.cpp | 2 - src/protocols/secure_channel/PASESession.h | 14 +- .../SessionEstablishmentExchangeDispatch.cpp | 86 ------ .../SessionEstablishmentExchangeDispatch.h | 71 ----- .../secure_channel/tests/TestCASESession.cpp | 10 +- .../tests/TestMessageCounterManager.cpp | 2 +- .../secure_channel/tests/TestPASESession.cpp | 11 +- src/transport/SecureSessionMgr.cpp | 4 +- src/transport/SecureSessionMgr.h | 23 +- src/transport/tests/TestSecureSessionMgr.cpp | 6 +- 42 files changed, 526 insertions(+), 861 deletions(-) delete mode 100644 src/messaging/ApplicationExchangeDispatch.cpp delete mode 100644 src/messaging/ApplicationExchangeDispatch.h delete mode 100644 src/messaging/ExchangeMessageDispatch.cpp delete mode 100644 src/messaging/ExchangeMessageDispatch.h delete mode 100644 src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp delete mode 100644 src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h diff --git a/examples/shell/shell_common/cmd_send.cpp b/examples/shell/shell_common/cmd_send.cpp index 42ab92dab69998..9280eee186c276 100644 --- a/examples/shell/shell_common/cmd_send.cpp +++ b/examples/shell/shell_common/cmd_send.cpp @@ -142,7 +142,7 @@ CHIP_ERROR SendMessage(streamer_t * stream) } // Create a new exchange context. - gExchangeCtx = gExchangeManager.NewContext({ kTestDeviceNodeId, 0, gAdminId }, &gMockAppDelegate); + gExchangeCtx = gExchangeManager.NewSecureContext({ kTestDeviceNodeId, 0, gAdminId }, &gMockAppDelegate); VerifyOrExit(gExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); size = gSendArguments.GetPayloadSize(); diff --git a/src/app/CommandSender.cpp b/src/app/CommandSender.cpp index 4a87e0fa7bc103..af246b83e095bd 100644 --- a/src/app/CommandSender.cpp +++ b/src/app/CommandSender.cpp @@ -46,7 +46,7 @@ CHIP_ERROR CommandSender::SendCommandRequest(NodeId aNodeId, Transport::AdminId // Create a new exchange context. // TODO: temprary create a SecureSessionHandle from node id, will be fix in PR 3602 // TODO: Hard code keyID to 0 to unblock IM end-to-end test. Complete solution is tracked in issue:4451 - mpExchangeCtx = mpExchangeMgr->NewContext({ aNodeId, 0, aAdminId }, this); + mpExchangeCtx = mpExchangeMgr->NewSecureContext({ aNodeId, 0, aAdminId }, this); VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); mpExchangeCtx->SetResponseTimeout(kImMessageTimeoutMsec); diff --git a/src/app/ReadClient.cpp b/src/app/ReadClient.cpp index 3ff2d1c01226c9..a8233f199a012d 100644 --- a/src/app/ReadClient.cpp +++ b/src/app/ReadClient.cpp @@ -164,7 +164,7 @@ CHIP_ERROR ReadClient::SendReadRequest(NodeId aNodeId, Transport::AdminId aAdmin SuccessOrExit(err); } - mpExchangeCtx = mpExchangeMgr->NewContext({ aNodeId, 0, aAdminId }, this); + mpExchangeCtx = mpExchangeMgr->NewSecureContext({ aNodeId, 0, aAdminId }, this); VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); mpExchangeCtx->SetResponseTimeout(kImMessageTimeoutMsec); diff --git a/src/app/server/RendezvousServer.cpp b/src/app/server/RendezvousServer.cpp index 3058287bc8d07d..640f7d9069507b 100644 --- a/src/app/server/RendezvousServer.cpp +++ b/src/app/server/RendezvousServer.cpp @@ -81,9 +81,6 @@ CHIP_ERROR RendezvousServer::WaitForPairing(const RendezvousParameters & params, strlen(kSpake2pKeyExchangeSalt), mNextKeyId++, this)); } - ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr)); - mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress()); - return CHIP_NO_ERROR; } diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp index fe5ac7dd5c90f9..a99245d6522c16 100644 --- a/src/app/tests/TestCommandInteraction.cpp +++ b/src/app/tests/TestCommandInteraction.cpp @@ -205,7 +205,7 @@ void TestCommandInteraction::TestCommandHandlerWithSendEmptyCommand(nlTestSuite err = commandHandler.Init(&chip::gExchangeManager, nullptr); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - commandHandler.mpExchangeCtx = gExchangeManager.NewContext({ 0, 0, 0 }, nullptr); + commandHandler.mpExchangeCtx = gExchangeManager.NewSecureContext({ 0, 0, 0 }, nullptr); TestExchangeDelegate delegate; commandHandler.mpExchangeCtx->SetDelegate(&delegate); @@ -242,7 +242,7 @@ void TestCommandInteraction::ValidateCommandHandlerWithSendCommand(nlTestSuite * err = commandHandler.Init(&chip::gExchangeManager, nullptr); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - commandHandler.mpExchangeCtx = gExchangeManager.NewContext({ 0, 0, 0 }, nullptr); + commandHandler.mpExchangeCtx = gExchangeManager.NewSecureContext({ 0, 0, 0 }, nullptr); TestExchangeDelegate delegate; commandHandler.mpExchangeCtx->SetDelegate(&delegate); diff --git a/src/app/tests/TestReportingEngine.cpp b/src/app/tests/TestReportingEngine.cpp index 2a0542018078ad..ba176da7210591 100644 --- a/src/app/tests/TestReportingEngine.cpp +++ b/src/app/tests/TestReportingEngine.cpp @@ -109,7 +109,7 @@ void TestReportingEngine::TestBuildAndSendSingleReportData(nlTestSuite * apSuite err = InteractionModelEngine::GetInstance()->Init(&gExchangeManager, nullptr); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - Messaging::ExchangeContext * exchangeCtx = gExchangeManager.NewContext({ 0, 0, 0 }, nullptr); + Messaging::ExchangeContext * exchangeCtx = gExchangeManager.NewSecureContext({ 0, 0, 0 }, nullptr); TestExchangeDelegate delegate; exchangeCtx->SetDelegate(&delegate); diff --git a/src/app/util/chip-message-send.cpp b/src/app/util/chip-message-send.cpp index c5fab58b4e6068..083aac9f1e1c8f 100644 --- a/src/app/util/chip-message-send.cpp +++ b/src/app/util/chip-message-send.cpp @@ -108,7 +108,7 @@ EmberStatus chipSendUnicast(NodeId destination, EmberApsFrame * apsFrame, uint16 return EMBER_DELIVERY_FAILED; } - Messaging::ExchangeContext * exchange = exchangeMgr->NewContext({ destination, Transport::kAnyKeyId, 0 }, nullptr); + Messaging::ExchangeContext * exchange = exchangeMgr->NewSecureContext({ destination, Transport::kAnyKeyId, 0 }, nullptr); if (exchange == nullptr) { return EMBER_DELIVERY_FAILED; diff --git a/src/channel/ChannelContext.cpp b/src/channel/ChannelContext.cpp index 97798f604b72c7..49b71a0c211223 100644 --- a/src/channel/ChannelContext.cpp +++ b/src/channel/ChannelContext.cpp @@ -38,7 +38,7 @@ void ChannelContext::Start(const ChannelBuilder & builder) ExchangeContext * ChannelContext::NewExchange(ExchangeDelegate * delegate) { assert(GetState() == ChannelState::kReady); - return mExchangeManager->NewContext(GetReadyVars().mSession, delegate); + return mExchangeManager->NewSecureContext(GetReadyVars().mSession, delegate); } bool ChannelContext::MatchNodeId(NodeId nodeId) @@ -258,12 +258,13 @@ void ChannelContext::EnterCasePairingState() auto & prepare = GetPrepareVars(); prepare.mCasePairingSession = Platform::New(); - ExchangeContext * ctxt = mExchangeManager->NewContext(SecureSessionHandle(), prepare.mCasePairingSession); - VerifyOrReturn(ctxt != nullptr); - // TODO: currently only supports IP/UDP paring Transport::PeerAddress addr; addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(prepare.mAddress); + + ExchangeContext * ctxt = mExchangeManager->NewUnsecureContext(addr, prepare.mCasePairingSession); + VerifyOrReturn(ctxt != nullptr); + CHIP_ERROR err = prepare.mCasePairingSession->EstablishSession(addr, &prepare.mBuilder.GetOperationalCredentialSet(), prepare.mBuilder.GetPeerNodeId(), mExchangeManager->GetNextKeyId(), ctxt, this); diff --git a/src/controller/CHIPDevice.cpp b/src/controller/CHIPDevice.cpp index dbe17799d4adc0..50954ae58226f0 100644 --- a/src/controller/CHIPDevice.cpp +++ b/src/controller/CHIPDevice.cpp @@ -66,7 +66,7 @@ CHIP_ERROR Device::SendMessage(Protocols::Id protocolId, uint8_t msgType, System ReturnErrorOnFailure(LoadSecureSessionParametersIfNeeded(loadedSecureSession)); - Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(mSecureSession, nullptr); + Messaging::ExchangeContext * exchange = mExchangeMgr->NewSecureContext(mSecureSession, nullptr); VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_NO_MEMORY); if (!loadedSecureSession) diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 4a4d0096bb1701..7b857791cb2c67 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -790,10 +790,6 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam mIsIPRendezvous = (params.GetPeerAddress().GetTransportType() != Transport::Type::kBle); - err = mPairingSession.MessageDispatch().Init(mTransportMgr); - SuccessOrExit(err); - mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress()); - device->Init(GetControllerDeviceInitParams(), mListenPort, remoteDeviceId, peerAddress, admin->GetAdminId()); mSystemLayer->StartTimer(kSessionEstablishmentTimeout, OnSessionEstablishmentTimeoutCallback, this); @@ -818,7 +814,7 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam } } #endif - exchangeCtxt = mExchangeMgr->NewContext(SecureSessionHandle(), &mPairingSession); + exchangeCtxt = mExchangeMgr->NewUnsecureContext(params.GetPeerAddress(), &mPairingSession); VerifyOrExit(exchangeCtxt != nullptr, err = CHIP_ERROR_INTERNAL); err = mPairingSession.Pair(params.GetPeerAddress(), params.GetSetupPINCode(), mNextKeyId++, exchangeCtxt, this); diff --git a/src/messaging/ApplicationExchangeDispatch.cpp b/src/messaging/ApplicationExchangeDispatch.cpp deleted file mode 100644 index 34ef91f7c1e1f9..00000000000000 --- a/src/messaging/ApplicationExchangeDispatch.cpp +++ /dev/null @@ -1,74 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file - * This file provides implementation of Application Channel class. - */ - -#include -#include - -namespace chip { -namespace Messaging { - -CHIP_ERROR ApplicationExchangeDispatch::SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) -{ - return mSessionMgr->SendMessage(session, payloadHeader, std::move(message), retainedMessage); -} - -CHIP_ERROR ApplicationExchangeDispatch::ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle message, - EncryptedPacketBufferHandle * retainedMessage) const -{ - return mSessionMgr->SendEncryptedMessage(session, std::move(message), retainedMessage); -} - -bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type) -{ - // TODO: Change this check to only include the protocol and message types that are allowed - switch (protocol) - { - case Protocols::SecureChannel::Id.GetProtocolId(): - switch (type) - { - case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamRequest): - case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamResponse): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p1): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p2): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p3): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2pError): - case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR1): - case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR2): - case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR3): - case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaErr): - return false; - - default: - break; - } - break; - - default: - break; - } - return true; -} - -} // namespace Messaging -} // namespace chip diff --git a/src/messaging/ApplicationExchangeDispatch.h b/src/messaging/ApplicationExchangeDispatch.h deleted file mode 100644 index a195ce25423608..00000000000000 --- a/src/messaging/ApplicationExchangeDispatch.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file - * This file defines Application Channel class. The object of this - * class can be used by CHIP data model cluster applications to send - * and receive messages. The messages are encrypted using session keys. - */ - -#pragma once - -#include -#include -#include - -namespace chip { -namespace Messaging { - -class ApplicationExchangeDispatch : public ExchangeMessageDispatch -{ -public: - ApplicationExchangeDispatch() {} - - virtual ~ApplicationExchangeDispatch() {} - - CHIP_ERROR Init(ReliableMessageMgr * reliableMessageMgr, SecureSessionMgr * sessionMgr) - { - ReturnErrorCodeIf(sessionMgr == nullptr, CHIP_ERROR_INVALID_ARGUMENT); - mSessionMgr = sessionMgr; - return ExchangeMessageDispatch::Init(reliableMessageMgr); - } - - CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle message, - EncryptedPacketBufferHandle * retainedMessage) const override; - - SecureSessionMgr * GetSessionMgr() const { return mSessionMgr; } - -protected: - CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) override; - - bool MessagePermitted(uint16_t protocol, uint8_t type) override; - -private: - SecureSessionMgr * mSessionMgr = nullptr; -}; - -} // namespace Messaging -} // namespace chip diff --git a/src/messaging/BUILD.gn b/src/messaging/BUILD.gn index 76966ef817eb4b..f146acc49787e0 100644 --- a/src/messaging/BUILD.gn +++ b/src/messaging/BUILD.gn @@ -18,16 +18,12 @@ static_library("messaging") { output_name = "libMessagingLayer" sources = [ - "ApplicationExchangeDispatch.cpp", - "ApplicationExchangeDispatch.h", "ErrorCategory.cpp", "ErrorCategory.h", "ExchangeACL.h", "ExchangeContext.cpp", "ExchangeContext.h", "ExchangeDelegate.h", - "ExchangeMessageDispatch.cpp", - "ExchangeMessageDispatch.h", "ExchangeMgr.cpp", "ExchangeMgr.h", "ExchangeMgrDelegate.h", diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index 281834f5d42f82..89838147dbd210 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -81,6 +81,9 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp { CHIP_ERROR err = CHIP_NO_ERROR; Transport::PeerConnectionState * state = nullptr; + bool reliableTransmissionRequested = true; + PayloadHeader payloadHeader; + ReliableMessageContext * reliableMessageContext = GetReliableMessageContext(); VerifyOrReturnError(mExchangeMgr != nullptr, CHIP_ERROR_INTERNAL); @@ -92,26 +95,28 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp // an error arising below. at the end, we have to close it. Retain(); - bool reliableTransmissionRequested = true; - - state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(mSecureSession); - // If sending via UDP and NoAutoRequestAck send flag is not specificed, request reliable transmission. - if (state != nullptr && state->GetPeerAddress().GetTransportType() != Transport::Type::kUdp) - { - reliableTransmissionRequested = false; - } - else { - reliableTransmissionRequested = !sendFlags.Has(SendMessageFlags::kNoAutoRequestAck); - } - - ExchangeMessageDispatch * dispatch = GetMessageDispatch(); - ApplicationExchangeDispatch defaultDispatch; + Transport::Type transportType; + if (IsSecure()) + { + state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(GetVariantSecure().mSecureSession); + VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED); + transportType = state->GetPeerAddress().GetTransportType(); + } + else + { + transportType = GetVariantUnsecure().mPeerAddress.GetTransportType(); + } - if (dispatch == nullptr) - { - defaultDispatch.Init(mExchangeMgr->GetReliableMessageMgr(), mExchangeMgr->GetSessionMgr()); - dispatch = &defaultDispatch; + // If sending via UDP and NoAutoRequestAck send flag is not specificed, request reliable transmission. + if (transportType != Transport::Type::kUdp) + { + reliableTransmissionRequested = false; + } + else + { + reliableTransmissionRequested = !sendFlags.Has(SendMessageFlags::kNoAutoRequestAck); + } } // If a response message is expected... @@ -130,8 +135,71 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp } } - err = dispatch->SendMessage(mSecureSession, mExchangeId, IsInitiator(), GetReliableMessageContext(), - reliableTransmissionRequested, protocolId, msgType, std::move(msgBuf)); + payloadHeader.SetExchangeID(mExchangeId).SetMessageType(protocolId, msgType).SetInitiator(IsInitiator()); + + // If there is a pending acknowledgment piggyback it on this message. + if (reliableMessageContext->HasPeerRequestedAck()) + { + payloadHeader.SetAckId(reliableMessageContext->GetPendingPeerAckId()); + + // Set AckPending flag to false since current outgoing message is going to serve as the ack on this exchange. + reliableMessageContext->SetAckPending(false); + +#if !defined(NDEBUG) + if (!payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::StandaloneAck)) + { + ChipLogDetail(ExchangeManager, "Piggybacking Ack for MsgId:%08" PRIX32 " with msg", + reliableMessageContext->GetPendingPeerAckId()); + } +#endif + } + + if (IsSecure()) + { + // TODO: CRMP is only enabled for secure messages, it should be enabled for all messages in the future. + if (reliableMessageContext->AutoRequestAck() && reliableTransmissionRequested) + { + payloadHeader.SetNeedsAck(true); + + ReliableMessageMgr::RetransTableEntry * entry = nullptr; + + // Add to Table for subsequent sending + ReturnErrorOnFailure(mExchangeMgr->GetReliableMessageMgr()->AddToRetransTable(reliableMessageContext, &entry)); + + err = mExchangeMgr->GetSessionMgr()->SendMessage(GetVariantSecure().mSecureSession, payloadHeader, std::move(msgBuf), + &entry->retainedBuf); + + if (err != CHIP_NO_ERROR) + { + // Remove from table + ChipLogError(ExchangeManager, "Failed to send message with err %s", ::chip::ErrorStr(err)); + mExchangeMgr->GetReliableMessageMgr()->ClearRetransTable(*entry); + ReturnErrorOnFailure(err); + } + else + { + VerifyOrDie(!entry->retainedBuf.IsNull()); // if send success, then retainedBuf mustn't be empty + mExchangeMgr->GetReliableMessageMgr()->StartRetransmision(entry); + } + } + else + { + // If the channel itself is providing reliability, let's not request CRMP acks + payloadHeader.SetNeedsAck(false); + err = mExchangeMgr->GetSessionMgr()->SendMessage(GetVariantSecure().mSecureSession, payloadHeader, std::move(msgBuf), + nullptr); + } + } + else + { + PacketHeader packetHeader; + + ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(msgBuf)); + ReturnErrorOnFailure(packetHeader.EncodeBeforeData(msgBuf)); + + err = + mExchangeMgr->GetSessionMgr()->GetTransportManager()->SendMessage(GetVariantUnsecure().mPeerAddress, std::move(msgBuf)); + } exit: if (err != CHIP_NO_ERROR && IsResponseExpected()) @@ -150,6 +218,19 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp return err; } +CHIP_ERROR ExchangeContext::ResendMessage(EncryptedPacketBufferHandle message, EncryptedPacketBufferHandle * retainedMessage) +{ + if (IsSecure()) + { + return mExchangeMgr->GetSessionMgr()->SendEncryptedMessage(GetVariantSecure().mSecureSession, std::move(message), + retainedMessage); + } + else + { + return CHIP_ERROR_NOT_IMPLEMENTED; + } +} + void ExchangeContext::DoClose(bool clearRetransTable) { // Clear protocol callbacks @@ -212,32 +293,17 @@ void ExchangeContext::Abort() Release(); } -void ExchangeContext::Reset() +void ExchangeContextDeletor::Release(ExchangeContext * ec) { - *this = ExchangeContext(); + ec->mExchangeMgr->ReleaseContext(ec); } -ExchangeMessageDispatch * ExchangeContext::GetMessageDispatch() +ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, bool Initiator, ExchangeDelegate * delegate) { - if (mDelegate != nullptr) - { - return mDelegate->GetMessageDispatch(mExchangeMgr->GetReliableMessageMgr(), mExchangeMgr->GetSessionMgr()); - } + VerifyOrDie(mExchangeMgr == nullptr); - return nullptr; -} - -ExchangeContext * ExchangeContext::Alloc(ExchangeManager * em, uint16_t ExchangeId, SecureSessionHandle session, bool Initiator, - ExchangeDelegateBase * delegate) -{ - VerifyOrDie(mExchangeMgr == nullptr && GetReferenceCount() == 0); - - Reset(); - Retain(); mExchangeMgr = em; - em->IncrementContextsInUse(); - mExchangeId = ExchangeId; - mSecureSession = session; + mExchangeId = ExchangeId; mFlags.Set(Flags::kFlagInitiator, Initiator); mDelegate = delegate; @@ -248,43 +314,45 @@ ExchangeContext * ExchangeContext::Alloc(ExchangeManager * em, uint16_t Exchange SetAutoRequestAck(true); #if defined(CHIP_EXCHANGE_CONTEXT_DETAIL_LOGGING) - ChipLogDetail(ExchangeManager, "ec++ id: %d, inUse: %d, addr: 0x%x", (this - em->mContextPool.begin()), em->GetContextsInUse(), - this); + ChipLogDetail(ExchangeManager, "ec++ id: %d", ExchangeId); #endif SYSTEM_STATS_INCREMENT(chip::System::Stats::kExchangeMgr_NumContexts); +} + +ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, SecureSessionHandle session, bool Initiator, + ExchangeDelegate * delegate) : + ExchangeContext(em, ExchangeId, Initiator, delegate) +{ + mFlags.Set(Flags::kIsSecure, true); + mSession.Set(session); +} - return this; +ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, const Transport::PeerAddress & peerAddress, + bool Initiator, ExchangeDelegate * delegate) : + ExchangeContext(em, ExchangeId, Initiator, delegate) +{ + mFlags.Set(Flags::kIsSecure, false); + mSession.Set(peerAddress); } -void ExchangeContext::Free() +ExchangeContext::~ExchangeContext() { VerifyOrDie(mExchangeMgr != nullptr && GetReferenceCount() == 0); // Ideally, in this scenario, the retransmit table should // be clear of any outstanding messages for this context and // the boolean parameter passed to DoClose() should not matter. - ExchangeManager * em = mExchangeMgr; DoClose(false); mExchangeMgr = nullptr; - em->DecrementContextsInUse(); - - if (mExchangeACL != nullptr) - { - chip::Platform::Delete(mExchangeACL); - mExchangeACL = nullptr; - } - #if defined(CHIP_EXCHANGE_CONTEXT_DETAIL_LOGGING) - ChipLogDetail(ExchangeManager, "ec-- id: %d [%04" PRIX16 "], inUse: %d, addr: 0x%x", (this - em->mContextPool.begin()), - mExchangeId, em->GetContextsInUse(), this); + ChipLogDetail(ExchangeManager, "ec-- id: %d", mExchangeId); #endif SYSTEM_STATS_DECREMENT(chip::System::Stats::kExchangeMgr_NumContexts); } -bool ExchangeContext::MatchExchange(SecureSessionHandle session, const PacketHeader & packetHeader, - const PayloadHeader & payloadHeader) +bool ExchangeContext::MatchExchange(const PayloadHeader & payloadHeader) { // A given message is part of a particular exchange if... return @@ -292,19 +360,39 @@ bool ExchangeContext::MatchExchange(SecureSessionHandle session, const PacketHea // The exchange identifier of the message matches the exchange identifier of the context. (mExchangeId == payloadHeader.GetExchangeID()) + // AND The message was sent by an initiator and the exchange context is a responder (IsInitiator==false) + // OR The message was sent by a responder and the exchange context is an initiator (IsInitiator==true) (for the broadcast + // case, the initiator is ill defined) + + && (payloadHeader.IsInitiator() != IsInitiator()); +} + +bool ExchangeContext::MatchSecureExchange(SecureSessionHandle session, const PacketHeader & packetHeader, + const PayloadHeader & payloadHeader) +{ + // A given message is part of a particular exchange if... + return MatchExchange(payloadHeader) + + && IsSecure() + // AND The message was received from the peer node associated with the exchange - && (mSecureSession == session) + && (GetSecureSession() == session) // AND The message's source Node ID matches the peer Node ID associated with the exchange, or the peer Node ID of the // exchange is 'any'. - && ((mSecureSession.GetPeerNodeId() == kAnyNodeId) || - (packetHeader.GetSourceNodeId().HasValue() && mSecureSession.GetPeerNodeId() == packetHeader.GetSourceNodeId().Value())) + && ((GetSecureSession().GetPeerNodeId() == kAnyNodeId) || + (packetHeader.GetSourceNodeId().HasValue() && + GetSecureSession().GetPeerNodeId() == packetHeader.GetSourceNodeId().Value())); +} - // AND The message was sent by an initiator and the exchange context is a responder (IsInitiator==false) - // OR The message was sent by a responder and the exchange context is an initiator (IsInitiator==true) (for the broadcast - // case, the initiator is ill defined) +bool ExchangeContext::MatchUnsecureExchange(const Transport::PeerAddress & peerAddress, const PayloadHeader & payloadHeader) +{ + // A given message is part of a particular exchange if... + return MatchExchange(payloadHeader) - && (payloadHeader.IsInitiator() != IsInitiator()); + && !IsSecure() + + && peerAddress == GetVariantUnsecure().mPeerAddress; } CHIP_ERROR ExchangeContext::StartResponseTimer() @@ -341,7 +429,7 @@ void ExchangeContext::HandleResponseTimeout(System::Layer * aSystemLayer, void * // NOTE: we don't set mResponseExpected to false here because the response could still arrive. If the user // wants to never receive the response, they must close the exchange context. - ExchangeDelegateBase * delegate = ec->GetDelegate(); + ExchangeDelegate * delegate = ec->GetDelegate(); // Call the user's timeout handler. if (delegate != nullptr) @@ -351,24 +439,33 @@ void ExchangeContext::HandleResponseTimeout(System::Layer * aSystemLayer, void * CHIP_ERROR ExchangeContext::HandleMessage(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const Transport::PeerAddress & peerAddress, PacketBufferHandle msgBuf) { + CHIP_ERROR err = CHIP_NO_ERROR; + // We hold a reference to the ExchangeContext here to // guard against Close() calls(decrementing the reference // count) by the protocol before the CHIP Exchange // layer has completed its work on the ExchangeContext. Retain(); - ExchangeMessageDispatch * dispatch = GetMessageDispatch(); - ApplicationExchangeDispatch defaultDispatch; - - if (dispatch == nullptr) + ReliableMessageContext * reliableMessageContext = GetReliableMessageContext(); + if (payloadHeader.IsAckMsg() && payloadHeader.GetAckId().HasValue()) { - defaultDispatch.Init(mExchangeMgr->GetReliableMessageMgr(), mExchangeMgr->GetSessionMgr()); - dispatch = &defaultDispatch; + SuccessOrExit(err = reliableMessageContext->HandleRcvdAck(payloadHeader.GetAckId().Value())); } - CHIP_ERROR err = - dispatch->OnMessageReceived(payloadHeader, packetHeader.GetMessageId(), peerAddress, GetReliableMessageContext()); - SuccessOrExit(err); + if (payloadHeader.NeedsAck()) + { + MessageFlags msgFlags; + + // An acknowledgment needs to be sent back to the peer for this message on this exchange, + // Set the flag in message header indicating an ack requested by peer; + msgFlags.Set(MessageFlagValues::kPeerRequestedAck); + + // Also set the flag in the exchange context indicating an ack requested; + reliableMessageContext->SetPeerRequestedAck(true); + + SuccessOrExit(err = reliableMessageContext->HandleNeedsAck(packetHeader.GetMessageId(), msgFlags)); + } // The SecureChannel::StandaloneAck message type is only used for CRMP; do not pass such messages to the application layer. if (payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::StandaloneAck)) diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index d58e140c6528fd..bee0550a8bb90a 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -54,8 +55,7 @@ class ExchangeContextDeletor * It defines methods for encoding and communicating CHIP messages within an ExchangeContext * over various transport mechanisms, for example, TCP, UDP, or CHIP Reliable Messaging. */ -class DLL_EXPORT ExchangeContext : public ReliableMessageContext, - public ReferenceCounted +class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public ReferenceCounted { friend class ExchangeManager; friend class ExchangeContextDeletor; @@ -63,6 +63,16 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public: typedef uint32_t Timeout; // Type used to express the timeout in this ExchangeContext, in milliseconds + // Create a secure ExchangeContext + ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, SecureSessionHandle session, bool Initiator, + ExchangeDelegate * delegate); + + // Create an unsecure ExchangeContext + ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, const Transport::PeerAddress & peerAddress, bool Initiator, + ExchangeDelegate * delegate); + + ~ExchangeContext(); + /** * Determine whether the context is the initiator of the exchange. * @@ -70,6 +80,11 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, */ bool IsInitiator() const; + /** + * Determine whether the context is secure. + */ + bool IsSecure() const { return mFlags.Has(Flags::kIsSecure); } + /** * Send a CHIP message on this exchange. * @@ -105,6 +120,11 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, std::move(msgPayload), sendFlags); } + /** + * Resend a message which is already encrypted, used by CRMP retrans + */ + CHIP_ERROR ResendMessage(EncryptedPacketBufferHandle message, EncryptedPacketBufferHandle * retainedMessage); + /** * Handle a received CHIP message on this exchange. * @@ -124,30 +144,35 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, CHIP_ERROR HandleMessage(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const Transport::PeerAddress & peerAddress, System::PacketBufferHandle msgBuf); - ExchangeDelegateBase * GetDelegate() const { return mDelegate; } - void SetDelegate(ExchangeDelegateBase * delegate) { mDelegate = delegate; } + ExchangeDelegate * GetDelegate() const { return mDelegate; } + void SetDelegate(ExchangeDelegate * delegate) { mDelegate = delegate; } ExchangeManager * GetExchangeMgr() const { return mExchangeMgr; } ReliableMessageContext * GetReliableMessageContext() { return static_cast(this); }; - ExchangeMessageDispatch * GetMessageDispatch(); - ExchangeACL * GetExchangeACL(Transport::AdminPairingTable & table) { - if (mExchangeACL == nullptr) + if (IsSecure()) { - Transport::AdminPairingInfo * admin = table.FindAdminWithId(mSecureSession.GetAdminId()); - if (admin != nullptr) + if (GetVariantSecure().mExchangeACL == nullptr) { - mExchangeACL = chip::Platform::New(admin); + Transport::AdminPairingInfo * admin = table.FindAdminWithId(GetVariantSecure().mSecureSession.GetAdminId()); + if (admin != nullptr) + { + GetVariantSecure().mExchangeACL = chip::Platform::New(admin); + } } - } - return mExchangeACL; + return GetVariantSecure().mExchangeACL; + } + else + { + return nullptr; + } } - SecureSessionHandle GetSecureSession() { return mSecureSession; } + SecureSessionHandle GetSecureSession() { return GetVariantSecure().mSecureSession; } uint16_t GetExchangeId() const { return mExchangeId; } @@ -163,17 +188,53 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, private: Timeout mResponseTimeout; // Maximum time to wait for response (in milliseconds); 0 disables response timeout. - ExchangeDelegateBase * mDelegate = nullptr; - ExchangeManager * mExchangeMgr = nullptr; - ExchangeACL * mExchangeACL = nullptr; + ExchangeDelegate * mDelegate = nullptr; + ExchangeManager * mExchangeMgr = nullptr; + + uint16_t mExchangeId; // Assigned exchange ID. + + struct SecureSession + { + static constexpr const size_t VariantId = 1; + + SecureSession(SecureSessionHandle session) : mExchangeACL(nullptr), mSecureSession(session) {} + + ~SecureSession() + { + if (mExchangeACL != nullptr) + { + chip::Platform::Delete(mExchangeACL); + mExchangeACL = nullptr; + } + } + + ExchangeACL * mExchangeACL; + const SecureSessionHandle mSecureSession; + }; + + struct UnsecureSession + { + static constexpr const size_t VariantId = 2; + UnsecureSession(const Transport::PeerAddress & peerAddress) : mPeerAddress(peerAddress) {} + const Transport::PeerAddress mPeerAddress; + }; + + Variant mSession; - SecureSessionHandle mSecureSession; // The connection state - uint16_t mExchangeId; // Assigned exchange ID. + SecureSession & GetVariantSecure() + { + assert(IsSecure()); + return mSession.Get(); + } + + UnsecureSession & GetVariantUnsecure() + { + assert(!IsSecure()); + return mSession.Get(); + } - ExchangeContext * Alloc(ExchangeManager * em, uint16_t ExchangeId, SecureSessionHandle session, bool Initiator, - ExchangeDelegateBase * delegate); - void Free(); - void Reset(); + // Base constructor + ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, bool Initiator, ExchangeDelegate * delegate); /** * Determine whether a response is currently expected for a message that was sent over @@ -194,19 +255,9 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, */ void SetResponseExpected(bool inResponseExpected); - /** - * Search for an existing exchange that the message applies to. - * - * @param[in] session The secure session of the received message. - * - * @param[in] packetHeader A reference to the PacketHeader object. - * - * @param[in] payloadHeader A reference to the PayloadHeader object. - * - * @retval true If a match is found. - * @retval false If a match is not found. - */ - bool MatchExchange(SecureSessionHandle session, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader); + bool MatchExchange(const PayloadHeader & payloadHeader); + bool MatchSecureExchange(SecureSessionHandle session, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader); + bool MatchUnsecureExchange(const Transport::PeerAddress & peerAddress, const PayloadHeader & payloadHeader); CHIP_ERROR StartResponseTimer(); @@ -216,10 +267,5 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, void DoClose(bool clearRetransTable); }; -inline void ExchangeContextDeletor::Release(ExchangeContext * obj) -{ - obj->Free(); -} - } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeDelegate.h b/src/messaging/ExchangeDelegate.h index dd317234af0eed..4db530ad7dc71f 100644 --- a/src/messaging/ExchangeDelegate.h +++ b/src/messaging/ExchangeDelegate.h @@ -23,8 +23,6 @@ #pragma once -#include -#include #include #include #include @@ -42,10 +40,10 @@ class ExchangeContext; * is interested in receiving these callbacks, they can specialize this class and handle * each trigger in their implementation of this class. */ -class DLL_EXPORT ExchangeDelegateBase +class DLL_EXPORT ExchangeDelegate { public: - virtual ~ExchangeDelegateBase() {} + virtual ~ExchangeDelegate() {} /** * @brief @@ -76,23 +74,6 @@ class DLL_EXPORT ExchangeDelegateBase * @param[in] ec A pointer to the ExchangeContext object. */ virtual void OnExchangeClosing(ExchangeContext * ec) {} - - virtual ExchangeMessageDispatch * GetMessageDispatch(ReliableMessageMgr * rmMgr, SecureSessionMgr * sessionMgr) = 0; -}; - -class DLL_EXPORT ExchangeDelegate : public ExchangeDelegateBase -{ -public: - virtual ~ExchangeDelegate() {} - - virtual ExchangeMessageDispatch * GetMessageDispatch(ReliableMessageMgr * rmMgr, SecureSessionMgr * sessionMgr) - { - mMessageDispatch.Init(rmMgr, sessionMgr); - return &mMessageDispatch; - } - -private: - ApplicationExchangeDispatch mMessageDispatch; }; } // namespace Messaging diff --git a/src/messaging/ExchangeMessageDispatch.cpp b/src/messaging/ExchangeMessageDispatch.cpp deleted file mode 100644 index 5afc11d71cd525..00000000000000 --- a/src/messaging/ExchangeMessageDispatch.cpp +++ /dev/null @@ -1,134 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file - * This file provides implementation of ExchangeMessageDispatch class. - */ - -#ifndef __STDC_FORMAT_MACROS -#define __STDC_FORMAT_MACROS -#endif - -#ifndef __STDC_LIMIT_MACROS -#define __STDC_LIMIT_MACROS -#endif - -#include - -#include -#include -#include -#include -#include - -namespace chip { -namespace Messaging { - -CHIP_ERROR ExchangeMessageDispatch::SendMessage(SecureSessionHandle session, uint16_t exchangeId, bool isInitiator, - ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, - Protocols::Id protocol, uint8_t type, System::PacketBufferHandle message) -{ - ReturnErrorCodeIf(!MessagePermitted(protocol.GetProtocolId(), type), CHIP_ERROR_INVALID_ARGUMENT); - - PayloadHeader payloadHeader; - payloadHeader.SetExchangeID(exchangeId).SetMessageType(protocol, type).SetInitiator(isInitiator); - - // If there is a pending acknowledgment piggyback it on this message. - if (reliableMessageContext->HasPeerRequestedAck()) - { - payloadHeader.SetAckId(reliableMessageContext->GetPendingPeerAckId()); - - // Set AckPending flag to false since current outgoing message is going to serve as the ack on this exchange. - reliableMessageContext->SetAckPending(false); - -#if !defined(NDEBUG) - if (!payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::StandaloneAck)) - { - ChipLogDetail(ExchangeManager, "Piggybacking Ack for MsgId:%08" PRIX32 " with msg", - reliableMessageContext->GetPendingPeerAckId()); - } -#endif - } - - if (IsReliableTransmissionAllowed() && reliableMessageContext->AutoRequestAck() && mReliableMessageMgr != nullptr && - isReliableTransmission) - { - payloadHeader.SetNeedsAck(true); - - ReliableMessageMgr::RetransTableEntry * entry = nullptr; - - // Add to Table for subsequent sending - ReturnErrorOnFailure(mReliableMessageMgr->AddToRetransTable(reliableMessageContext, &entry)); - - CHIP_ERROR err = SendMessageImpl(session, payloadHeader, std::move(message), &entry->retainedBuf); - if (err != CHIP_NO_ERROR) - { - // Remove from table - ChipLogError(ExchangeManager, "Failed to send message with err %s", ::chip::ErrorStr(err)); - mReliableMessageMgr->ClearRetransTable(*entry); - ReturnErrorOnFailure(err); - } - else - { - mReliableMessageMgr->StartRetransmision(entry); - } - } - else - { - // If the channel itself is providing reliability, let's not request CRMP acks - payloadHeader.SetNeedsAck(false); - ReturnErrorOnFailure(SendMessageImpl(session, payloadHeader, std::move(message), nullptr)); - } - - return CHIP_NO_ERROR; -} - -CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, - const Transport::PeerAddress & peerAddress, - ReliableMessageContext * reliableMessageContext) -{ - ReturnErrorCodeIf(!MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()), - CHIP_ERROR_INVALID_ARGUMENT); - - if (IsReliableTransmissionAllowed()) - { - if (payloadHeader.IsAckMsg() && payloadHeader.GetAckId().HasValue()) - { - ReturnErrorOnFailure(reliableMessageContext->HandleRcvdAck(payloadHeader.GetAckId().Value())); - } - - if (payloadHeader.NeedsAck()) - { - MessageFlags msgFlags; - - // An acknowledgment needs to be sent back to the peer for this message on this exchange, - // Set the flag in message header indicating an ack requested by peer; - msgFlags.Set(MessageFlagValues::kPeerRequestedAck); - - // Also set the flag in the exchange context indicating an ack requested; - reliableMessageContext->SetPeerRequestedAck(true); - - ReturnErrorOnFailure(reliableMessageContext->HandleNeedsAck(messageId, msgFlags)); - } - } - - return CHIP_NO_ERROR; -} - -} // namespace Messaging -} // namespace chip diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h deleted file mode 100644 index ef2cb596b28445..00000000000000 --- a/src/messaging/ExchangeMessageDispatch.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file - * This file defines ExchangeMessageDispatch class. The object of this - * class can be used by CHIP protocols to send and receive messages. - */ - -#pragma once - -#include - -namespace chip { -namespace Messaging { - -class ReliableMessageMgr; -class ReliableMessageContext; - -class ExchangeMessageDispatch -{ -public: - ExchangeMessageDispatch() {} - virtual ~ExchangeMessageDispatch() {} - - CHIP_ERROR Init(ReliableMessageMgr * reliableMessageMgr) - { - mReliableMessageMgr = reliableMessageMgr; - return CHIP_NO_ERROR; - } - - CHIP_ERROR SendMessage(SecureSessionHandle session, uint16_t exchangeId, bool isInitiator, - ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, Protocols::Id protocol, - uint8_t type, System::PacketBufferHandle message); - - virtual CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle message, - EncryptedPacketBufferHandle * retainedMessage) const - { - return CHIP_ERROR_NOT_IMPLEMENTED; - } - - virtual CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, - const Transport::PeerAddress & peerAddress, - ReliableMessageContext * reliableMessageContext); - -protected: - virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0; - - virtual CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && message, EncryptedPacketBufferHandle * retainedMessage) = 0; - - virtual bool IsReliableTransmissionAllowed() { return true; } - -private: - ReliableMessageMgr * mReliableMessageMgr = nullptr; -}; - -} // namespace Messaging -} // namespace chip diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index 0c5fe415f0edb6..4284c6bf912a70 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -75,8 +75,6 @@ CHIP_ERROR ExchangeManager::Init(SecureSessionMgr * sessionMgr) mNextExchangeId = GetRandU16(); mNextKeyId = 0; - mContextsInUse = 0; - for (auto & handler : UMHandlerPool) { // Mark all handlers as unallocated. This handles both initial @@ -98,11 +96,11 @@ CHIP_ERROR ExchangeManager::Shutdown() { mReliableMessageMgr.Shutdown(); - for (auto & ec : mContextPool) - { - // ExchangeContext leaked - assert(ec.GetReferenceCount() == 0); - } + mContextPool.ForEachActiveObject([](auto * ec) { + // There should be no active object in the pool + assert(false); + return true; + }); if (mSessionMgr != nullptr) { @@ -115,18 +113,23 @@ CHIP_ERROR ExchangeManager::Shutdown() return CHIP_NO_ERROR; } -ExchangeContext * ExchangeManager::NewContext(SecureSessionHandle session, ExchangeDelegateBase * delegate) +ExchangeContext * ExchangeManager::NewSecureContext(SecureSessionHandle session, ExchangeDelegate * delegate) { - return AllocContext(mNextExchangeId++, session, true, delegate); + return mContextPool.CreateObject(this, mNextExchangeId++, session, true, delegate); } -CHIP_ERROR ExchangeManager::RegisterUnsolicitedMessageHandlerForProtocol(Protocols::Id protocolId, ExchangeDelegateBase * delegate) +ExchangeContext * ExchangeManager::NewUnsecureContext(Transport::PeerAddress peerAddress, ExchangeDelegate * delegate) +{ + return mContextPool.CreateObject(this, mNextExchangeId++, peerAddress, true, delegate); +} + +CHIP_ERROR ExchangeManager::RegisterUnsolicitedMessageHandlerForProtocol(Protocols::Id protocolId, ExchangeDelegate * delegate) { return RegisterUMH(protocolId, kAnyMessageType, delegate); } CHIP_ERROR ExchangeManager::RegisterUnsolicitedMessageHandlerForType(Protocols::Id protocolId, uint8_t msgType, - ExchangeDelegateBase * delegate) + ExchangeDelegate * delegate) { return RegisterUMH(protocolId, static_cast(msgType), delegate); } @@ -146,24 +149,7 @@ void ExchangeManager::OnReceiveError(CHIP_ERROR error, const Transport::PeerAddr ChipLogError(ExchangeManager, "Accept FAILED, err = %s", ErrorStr(error)); } -ExchangeContext * ExchangeManager::AllocContext(uint16_t ExchangeId, SecureSessionHandle session, bool Initiator, - ExchangeDelegateBase * delegate) -{ - CHIP_FAULT_INJECT(FaultInjection::kFault_AllocExchangeContext, return nullptr); - - for (auto & ec : mContextPool) - { - if (ec.GetReferenceCount() == 0) - { - return ec.Alloc(this, ExchangeId, session, Initiator, delegate); - } - } - - ChipLogError(ExchangeManager, "Alloc ctxt FAILED"); - return nullptr; -} - -CHIP_ERROR ExchangeManager::RegisterUMH(Protocols::Id protocolId, int16_t msgType, ExchangeDelegateBase * delegate) +CHIP_ERROR ExchangeManager::RegisterUMH(Protocols::Id protocolId, int16_t msgType, ExchangeDelegate * delegate) { UnsolicitedMessageHandler * umh = UMHandlerPool; UnsolicitedMessageHandler * selected = nullptr; @@ -211,9 +197,9 @@ CHIP_ERROR ExchangeManager::UnregisterUMH(Protocols::Id protocolId, int16_t msgT return CHIP_ERROR_NO_UNSOLICITED_MESSAGE_HANDLER; } -void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, - SecureSessionHandle session, const Transport::PeerAddress & source, - System::PacketBufferHandle msgBuf, SecureSessionMgr * msgLayer) +void ExchangeManager::OnSecureMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, + SecureSessionHandle session, const Transport::PeerAddress & source, + System::PacketBufferHandle msgBuf, SecureSessionMgr * msgLayer) { CHIP_ERROR err = CHIP_NO_ERROR; UnsolicitedMessageHandler * umh = nullptr; @@ -224,22 +210,28 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const payloadHeader.GetProtocolID()); // Search for an existing exchange that the message applies to. If a match is found... - for (auto & ec : mContextPool) - { - if (ec.GetReferenceCount() > 0 && ec.MatchExchange(session, packetHeader, payloadHeader)) + bool found = false; + mContextPool.ForEachActiveObject([&](auto * ec) { + if (ec->MatchSecureExchange(session, packetHeader, payloadHeader)) { // Found a matching exchange. Set flag for correct subsequent CRMP // retransmission timeout selection. - if (!ec.HasRcvdMsgFromPeer()) + if (!ec->HasRcvdMsgFromPeer()) { - ec.SetMsgRcvdFromPeer(true); + ec->SetMsgRcvdFromPeer(true); } // Matched ExchangeContext; send to message handler. - ec.HandleMessage(packetHeader, payloadHeader, source, std::move(msgBuf)); - - ExitNow(err = CHIP_NO_ERROR); + ec->HandleMessage(packetHeader, payloadHeader, source, std::move(msgBuf)); + found = true; + return false; } + return true; + }); + + if (found) + { + ExitNow(err = CHIP_NO_ERROR); } // Search for an unsolicited message handler if it marked as being sent by an initiator. Since we didn't @@ -289,17 +281,16 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const // If rcvd msg is from initiator then this exchange is created as not Initiator. // If rcvd msg is not from initiator then this exchange is created as Initiator. // TODO: Figure out which channel to use for the received message - ec = AllocContext(payloadHeader.GetExchangeID(), session, !payloadHeader.IsInitiator(), nullptr); + ec = mContextPool.CreateObject(this, payloadHeader.GetExchangeID(), session, !payloadHeader.IsInitiator(), nullptr); } else { - ec = AllocContext(payloadHeader.GetExchangeID(), session, false, matchingUMH->Delegate); + ec = mContextPool.CreateObject(this, payloadHeader.GetExchangeID(), session, false, matchingUMH->Delegate); } VerifyOrExit(ec != nullptr, err = CHIP_ERROR_NO_MEMORY); - ChipLogDetail(ExchangeManager, "ec pos: %d, id: %d, Delegate: 0x%x", ec - mContextPool.begin(), ec->GetExchangeId(), - ec->GetDelegate()); + ChipLogDetail(ExchangeManager, "ec id: %d, Delegate: 0x%x", ec->GetExchangeId(), ec->GetDelegate()); ec->HandleMessage(packetHeader, payloadHeader, source, std::move(msgBuf)); @@ -330,67 +321,140 @@ void ExchangeManager::OnConnectionExpired(SecureSessionHandle session, SecureSes mDelegate->OnConnectionExpired(session, this); } - for (auto & ec : mContextPool) - { - if (ec.GetReferenceCount() > 0 && ec.mSecureSession == session) + mContextPool.ForEachActiveObject([&](auto * ec) { + if (ec->IsSecure() && ec->GetSecureSession() == session) { - ec.Close(); + ec->Close(); // Continue iterate because there can be multiple contexts associated with the connection. } - } + return true; + }); } -void ExchangeManager::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf) +void ExchangeManager::OnUnsecureMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, + const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf, + SecureSessionMgr * mgr) { - PacketHeader header; + CHIP_ERROR err = CHIP_NO_ERROR; + UnsolicitedMessageHandler * umh = nullptr; + UnsolicitedMessageHandler * matchingUMH = nullptr; + bool sendAckAndCloseExchange = false; - ReturnOnFailure(header.DecodeAndConsume(msgBuf)); + ChipLogProgress(ExchangeManager, "Received unsecure message of type %d and protocolId %d", payloadHeader.GetMessageType(), + payloadHeader.GetProtocolID()); - Optional peer = header.GetSourceNodeId(); - if (!peer.HasValue()) + // Search for an existing exchange that the message applies to. If a match is found... + bool found = false; + mContextPool.ForEachActiveObject([&](auto * ec) { + if (ec->MatchUnsecureExchange(source, payloadHeader)) + { + // Found a matching exchange. Set flag for correct subsequent CRMP + // retransmission timeout selection. + if (!ec->HasRcvdMsgFromPeer()) + { + ec->SetMsgRcvdFromPeer(true); + } + + // Matched ExchangeContext; send to message handler. + ec->HandleMessage(packetHeader, payloadHeader, source, std::move(msgBuf)); + found = true; + return false; + } + return true; + }); + + if (found) { - char addrBuffer[Transport::PeerAddress::kMaxToStringSize]; - source.ToString(addrBuffer, sizeof(addrBuffer)); - ChipLogError(ExchangeManager, "Unencrypted message from %s is dropped since no source node id in packet header.", - addrBuffer); - return; + ExitNow(err = CHIP_NO_ERROR); } -} -void ExchangeManager::CloseAllContextsForDelegate(const ExchangeDelegateBase * delegate) -{ - for (auto & ec : mContextPool) + // Search for an unsolicited message handler if it marked as being sent by an initiator. Since we didn't + // find an existing exchange that matches the message, it must be an unsolicited message. However all + // unsolicited messages must be marked as being from an initiator. + if (payloadHeader.IsInitiator()) { - if (ec.GetReferenceCount() == 0 || ec.GetDelegate() != delegate) + // Search for an unsolicited message handler that can handle the message. Prefer handlers that can explicitly + // handle the message type over handlers that handle all messages for a profile. + umh = (UnsolicitedMessageHandler *) UMHandlerPool; + + matchingUMH = nullptr; + + for (int i = 0; i < CHIP_CONFIG_MAX_UNSOLICITED_MESSAGE_HANDLERS; i++, umh++) { - continue; - } + if (umh->IsInUse() && payloadHeader.HasProtocol(umh->ProtocolId)) + { + if (umh->MessageType == payloadHeader.GetMessageType()) + { + matchingUMH = umh; + break; + } - // Make sure to null out the delegate before closing the context, so - // we don't notify the delegate that the context is closing. We - // have to do this, because the delegate might be partially - // destroyed by this point. - ec.SetDelegate(nullptr); - ec.Close(); + if (umh->MessageType == kAnyMessageType) + matchingUMH = umh; + } + } + } + // Discard the message if it isn't marked as being sent by an initiator and the message does not need to send + // an ack to the peer. + else if (!payloadHeader.NeedsAck()) + { + ExitNow(err = CHIP_ERROR_UNSOLICITED_MSG_NO_ORIGINATOR); } -} -void ExchangeManager::IncrementContextsInUse() -{ - mContextsInUse++; -} + // If we didn't find an existing exchange that matches the message, and no unsolicited message handler registered + // to hand this message, we need to create a temporary exchange to send an ack for this message and then close this exchange. + sendAckAndCloseExchange = payloadHeader.NeedsAck() && (matchingUMH == nullptr); -void ExchangeManager::DecrementContextsInUse() -{ - if (mContextsInUse >= 1) + // If we found a handler or we need to create a new exchange context (EC). + if (matchingUMH != nullptr || sendAckAndCloseExchange) { - mContextsInUse--; + ExchangeContext * ec = nullptr; + + if (sendAckAndCloseExchange) + { + // If rcvd msg is from initiator then this exchange is created as not Initiator. + // If rcvd msg is not from initiator then this exchange is created as Initiator. + // TODO: Figure out which channel to use for the received message + ec = mContextPool.CreateObject(this, payloadHeader.GetExchangeID(), source, !payloadHeader.IsInitiator(), nullptr); + } + else + { + ec = mContextPool.CreateObject(this, payloadHeader.GetExchangeID(), source, false, matchingUMH->Delegate); + } + + VerifyOrExit(ec != nullptr, err = CHIP_ERROR_NO_MEMORY); + + ChipLogDetail(ExchangeManager, "ec id: %d, Delegate: 0x%x", ec->GetExchangeId(), ec->GetDelegate()); + + ec->HandleMessage(packetHeader, payloadHeader, source, std::move(msgBuf)); + + // Close exchange if it was created only to send ack for a duplicate message. + if (sendAckAndCloseExchange) + ec->Close(); } - else + +exit: + if (err != CHIP_NO_ERROR) { - ChipLogError(ExchangeManager, "No context in use, decrement failed"); + ChipLogError(ExchangeManager, "OnUnsecureMessageReceived failed, err = %s", ErrorStr(err)); } } +void ExchangeManager::CloseAllContextsForDelegate(const ExchangeDelegate * delegate) +{ + mContextPool.ForEachActiveObject([&](auto * ec) { + if (ec->GetDelegate() == delegate) + { + // Make sure to null out the delegate before closing the context, so + // we don't notify the delegate that the context is closing. We + // have to do this, because the delegate might be partially + // destroyed by this point. + ec->SetDelegate(nullptr); + ec->Close(); + } + return true; + }); +} + } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h index 930d233e643469..4a7cf9340d6cde 100644 --- a/src/messaging/ExchangeMgr.h +++ b/src/messaging/ExchangeMgr.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -38,7 +39,7 @@ namespace chip { namespace Messaging { class ExchangeContext; -class ExchangeDelegateBase; +class ExchangeDelegate; static constexpr int16_t kAnyMessageType = -1; @@ -48,7 +49,7 @@ static constexpr int16_t kAnyMessageType = -1; * It works on be behalf of higher layers, creating ExchangeContexts and * handling the registration/unregistration of unsolicited message handlers. */ -class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public TransportMgrDelegate +class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate { friend class ExchangeContext; @@ -87,18 +88,29 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public Trans CHIP_ERROR Shutdown(); /** - * Creates a new ExchangeContext with a given peer CHIP node specified by the peer node identifier. - * - * @param[in] session The identifier of the secure session (possibly - * the empty session for a non-secure exchange) - * for which the ExchangeContext is being set up. + * Creates a new secure ExchangeContext with a given secure session * + * @param[in] session The identifier of the secure session for which the ExchangeContext is being set up. * @param[in] delegate A pointer to ExchangeDelegate. * * @return A pointer to the created ExchangeContext object On success. Otherwise NULL if no object * can be allocated or is available. */ - ExchangeContext * NewContext(SecureSessionHandle session, ExchangeDelegateBase * delegate); + ExchangeContext * NewSecureContext(SecureSessionHandle session, ExchangeDelegate * delegate); + + /** + * Creates a new unsecure ExchangeContext with a given address + * + * @param[in] peerAddress The peer address with which the ExchangeContext is being set up. + * + * @param[in] delegate A pointer to ExchangeDelegate. + * + * @return A pointer to the created ExchangeContext object On success. Otherwise NULL if no object + * can be allocated or is available. + */ + ExchangeContext * NewUnsecureContext(Transport::PeerAddress peerAddress, ExchangeDelegate * delegate); + + void ReleaseContext(ExchangeContext * ec) { mContextPool.ReleaseObject(ec); } /** * Register an unsolicited message handler for a given protocol identifier. This handler would be @@ -112,7 +124,7 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public Trans * is full and a new one cannot be allocated. * @retval #CHIP_NO_ERROR On success. */ - CHIP_ERROR RegisterUnsolicitedMessageHandlerForProtocol(Protocols::Id protocolId, ExchangeDelegateBase * delegate); + CHIP_ERROR RegisterUnsolicitedMessageHandlerForProtocol(Protocols::Id protocolId, ExchangeDelegate * delegate); /** * Register an unsolicited message handler for a given protocol identifier and message type. @@ -127,13 +139,13 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public Trans * is full and a new one cannot be allocated. * @retval #CHIP_NO_ERROR On success. */ - CHIP_ERROR RegisterUnsolicitedMessageHandlerForType(Protocols::Id protocolId, uint8_t msgType, ExchangeDelegateBase * delegate); + CHIP_ERROR RegisterUnsolicitedMessageHandlerForType(Protocols::Id protocolId, uint8_t msgType, ExchangeDelegate * delegate); /** * A strongly-message-typed version of RegisterUnsolicitedMessageHandlerForType. */ template ::value>> - CHIP_ERROR RegisterUnsolicitedMessageHandlerForType(MessageType msgType, ExchangeDelegateBase * delegate) + CHIP_ERROR RegisterUnsolicitedMessageHandlerForType(MessageType msgType, ExchangeDelegate * delegate) { static_assert(std::is_same, uint8_t>::value, "Enum is wrong size; cast is not safe"); return RegisterUnsolicitedMessageHandlerForType(Protocols::MessageTypeTraits::ProtocolId(), @@ -180,10 +192,7 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public Trans * their delegate. To be used if the delegate is being destroyed. This * method will guarantee that it does not call into the delegate. */ - void CloseAllContextsForDelegate(const ExchangeDelegateBase * delegate); - - void IncrementContextsInUse(); - void DecrementContextsInUse(); + void CloseAllContextsForDelegate(const ExchangeDelegate * delegate); void SetDelegate(ExchangeMgrDelegate * delegate) { mDelegate = delegate; } @@ -194,7 +203,6 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public Trans Transport::AdminId GetAdminId() { return mAdminId; } uint16_t GetNextKeyId() { return ++mNextKeyId; } - size_t GetContextsInUse() const { return mContextsInUse; } private: enum class State @@ -215,7 +223,7 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public Trans return ProtocolId == aProtocolId && MessageType == aMessageType; } - ExchangeDelegateBase * Delegate; + ExchangeDelegate * Delegate; Protocols::Id ProtocolId; // Message types are normally 8-bit unsigned ints, but we use // kAnyMessageType, which is negative, to represent a wildcard handler, @@ -234,28 +242,24 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public Trans Transport::AdminId mAdminId = 0; - std::array mContextPool; - size_t mContextsInUse; + BitMapObjectPool mContextPool; UnsolicitedMessageHandler UMHandlerPool[CHIP_CONFIG_MAX_UNSOLICITED_MESSAGE_HANDLERS]; - ExchangeContext * AllocContext(uint16_t ExchangeId, SecureSessionHandle session, bool Initiator, - ExchangeDelegateBase * delegate); - - CHIP_ERROR RegisterUMH(Protocols::Id protocolId, int16_t msgType, ExchangeDelegateBase * delegate); + CHIP_ERROR RegisterUMH(Protocols::Id protocolId, int16_t msgType, ExchangeDelegate * delegate); CHIP_ERROR UnregisterUMH(Protocols::Id protocolId, int16_t msgType); void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgr * msgLayer) override; - void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, SecureSessionHandle session, - const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf, - SecureSessionMgr * msgLayer) override; + void OnSecureMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, + SecureSessionHandle session, const Transport::PeerAddress & source, + System::PacketBufferHandle msgBuf, SecureSessionMgr * msgLayer) override; + void OnUnsecureMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, + const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf, + SecureSessionMgr * mgr) override; void OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) override; void OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) override; - - // TransportMgrDelegate interface for rendezvous sessions - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf) override; }; } // namespace Messaging diff --git a/src/messaging/ReliableMessageContext.h b/src/messaging/ReliableMessageContext.h index f8fe36d454e8df..8a7a5fc1461b3d 100644 --- a/src/messaging/ReliableMessageContext.h +++ b/src/messaging/ReliableMessageContext.h @@ -219,6 +219,10 @@ class ReliableMessageContext /// When set, signifies that at least one message has been received from peer on this exchange context. kFlagMsgRcvdFromPeer = 0x0080, + + /// When set, the exchange context is handling secure messages, otherwise it is handling unsecure messages. + /// do not allow mixing secure and unsecure messaging in a single exchange + kIsSecure = 0x0100, }; BitFlags mFlags; // Internal state flags diff --git a/src/messaging/ReliableMessageMgr.cpp b/src/messaging/ReliableMessageMgr.cpp index 40c2cdb28bae53..6c4916547b7ff1 100644 --- a/src/messaging/ReliableMessageMgr.cpp +++ b/src/messaging/ReliableMessageMgr.cpp @@ -26,7 +26,6 @@ #include #include -#include #include #include #include @@ -39,7 +38,7 @@ namespace Messaging { ReliableMessageMgr::RetransTableEntry::RetransTableEntry() : rc(nullptr), nextRetransTimeTick(0), sendCount(0) {} -ReliableMessageMgr::ReliableMessageMgr(std::array & contextPool) : +ReliableMessageMgr::ReliableMessageMgr(BitMapObjectPool & contextPool) : mContextPool(contextPool), mSystemLayer(nullptr), mSessionMgr(nullptr), mCurrentTimerExpiry(0), mTimerIntervalShift(CHIP_CONFIG_RMP_TIMER_DEFAULT_PERIOD_SHIFT) {} @@ -351,11 +350,7 @@ CHIP_ERROR ReliableMessageMgr::SendFromRetransTable(RetransTableEntry * entry) // over to someone else, and on failure it will no longer be available. msgId = entry->retainedBuf.GetMsgId(); - const ExchangeMessageDispatch * dispatcher = rc->GetExchangeContext()->GetMessageDispatch(); - VerifyOrExit(dispatcher != nullptr, err = CHIP_ERROR_INCORRECT_STATE); - - err = - dispatcher->ResendMessage(rc->GetExchangeContext()->GetSecureSession(), std::move(entry->retainedBuf), &entry->retainedBuf); + err = rc->GetExchangeContext()->ResendMessage(std::move(entry->retainedBuf), &entry->retainedBuf); SuccessOrExit(err); // Update the counters diff --git a/src/messaging/ReliableMessageMgr.h b/src/messaging/ReliableMessageMgr.h index b66a2ba2a955b0..10480e6f0a7c42 100644 --- a/src/messaging/ReliableMessageMgr.h +++ b/src/messaging/ReliableMessageMgr.h @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -66,7 +67,7 @@ class ReliableMessageMgr }; public: - ReliableMessageMgr(std::array & contextPool); + ReliableMessageMgr(BitMapObjectPool & contextPool); ~ReliableMessageMgr(); void Init(chip::System::Layer * systemLayer, SecureSessionMgr * sessionMgr); @@ -223,7 +224,7 @@ class ReliableMessageMgr void TestSetIntervalShift(uint16_t value) { mTimerIntervalShift = value; } private: - std::array & mContextPool; + BitMapObjectPool & mContextPool; chip::System::Layer * mSystemLayer; SecureSessionMgr * mSessionMgr; uint64_t mTimeStampBase; // ReliableMessageProtocol timer base value to add offsets to evaluate timeouts @@ -234,10 +235,10 @@ class ReliableMessageMgr template void ExecuteForAllContext(Function function) { - for (auto & ec : mContextPool) - { - function(ec.GetReliableMessageContext()); - } + mContextPool.ForEachActiveObject([&](auto * ec) { + function(ec->GetReliableMessageContext()); + return true; + }); } void TicklessDebugDumpRetransTable(const char * log); diff --git a/src/messaging/tests/MessagingContext.cpp b/src/messaging/tests/MessagingContext.cpp index 3e67f870ce42ad..5c777aa03fb9d5 100644 --- a/src/messaging/tests/MessagingContext.cpp +++ b/src/messaging/tests/MessagingContext.cpp @@ -41,11 +41,11 @@ CHIP_ERROR MessagingContext::Init(nlTestSuite * suite, TransportMgrBase * transp ReturnErrorOnFailure(mExchangeManager.Init(&mSecureSessionMgr)); ReturnErrorOnFailure(mMessageCounterManager.Init(&mExchangeManager)); - ReturnErrorOnFailure(mSecureSessionMgr.NewPairing(mPeer, GetDestinationNodeId(), &mPairingLocalToPeer, - SecureSession::SessionRole::kInitiator, mSrcAdminId)); + ReturnErrorOnFailure(mSecureSessionMgr.NewPairing(Optional(mPeer), GetDestinationNodeId(), + &mPairingLocalToPeer, SecureSession::SessionRole::kInitiator, mSrcAdminId)); - return mSecureSessionMgr.NewPairing(mPeer, GetSourceNodeId(), &mPairingPeerToLocal, SecureSession::SessionRole::kResponder, - mDestAdminId); + return mSecureSessionMgr.NewPairing(Optional(mPeer), GetSourceNodeId(), &mPairingPeerToLocal, + SecureSession::SessionRole::kResponder, mDestAdminId); } // Shutdown all layers, finalize operations @@ -67,16 +67,19 @@ SecureSessionHandle MessagingContext::GetSessionPeerToLocal() return { GetSourceNodeId(), GetLocalKeyId(), GetAdminId() }; } -Messaging::ExchangeContext * MessagingContext::NewExchangeToPeer(Messaging::ExchangeDelegateBase * delegate) +Messaging::ExchangeContext * MessagingContext::NewSecureExchangeToPeer(Messaging::ExchangeDelegate * delegate) { - // TODO: temprary create a SecureSessionHandle from node id, will be fix in PR 3602 - return mExchangeManager.NewContext(GetSessionLocalToPeer(), delegate); + return mExchangeManager.NewSecureContext(GetSessionLocalToPeer(), delegate); } -Messaging::ExchangeContext * MessagingContext::NewExchangeToLocal(Messaging::ExchangeDelegateBase * delegate) +Messaging::ExchangeContext * MessagingContext::NewSecureExchangeToLocal(Messaging::ExchangeDelegate * delegate) { - // TODO: temprary create a SecureSessionHandle from node id, will be fix in PR 3602 - return mExchangeManager.NewContext(GetSessionPeerToLocal(), delegate); + return mExchangeManager.NewSecureContext(GetSessionPeerToLocal(), delegate); +} + +Messaging::ExchangeContext * MessagingContext::NewUnsecureExchange(Messaging::ExchangeDelegate * delegate) +{ + return mExchangeManager.NewUnsecureContext(mPeer, delegate); } } // namespace Test diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h index 13acc1185c36fc..ac256382e8eda7 100644 --- a/src/messaging/tests/MessagingContext.h +++ b/src/messaging/tests/MessagingContext.h @@ -79,8 +79,9 @@ class MessagingContext : public IOContext SecureSessionHandle GetSessionLocalToPeer(); SecureSessionHandle GetSessionPeerToLocal(); - Messaging::ExchangeContext * NewExchangeToPeer(Messaging::ExchangeDelegateBase * delegate); - Messaging::ExchangeContext * NewExchangeToLocal(Messaging::ExchangeDelegateBase * delegate); + Messaging::ExchangeContext * NewSecureExchangeToPeer(Messaging::ExchangeDelegate * delegate); + Messaging::ExchangeContext * NewSecureExchangeToLocal(Messaging::ExchangeDelegate * delegate); + Messaging::ExchangeContext * NewUnsecureExchange(Messaging::ExchangeDelegate * delegate); Credentials::OperationalCredentialSet & GetOperationalCredentialSet() { return mOperationalCredentialSet; } @@ -93,7 +94,7 @@ class MessagingContext : public IOContext NodeId mDestinationNodeId = 111222333; uint16_t mLocalKeyId = 1; uint16_t mPeerKeyId = 2; - Optional mPeer; + Transport::PeerAddress mPeer; SecurePairingUsingTestSecret mPairingPeerToLocal; SecurePairingUsingTestSecret mPairingLocalToPeer; Transport::AdminPairingTable mAdmins; diff --git a/src/messaging/tests/TestExchangeMgr.cpp b/src/messaging/tests/TestExchangeMgr.cpp index 6e34e3902d7411..7f63f7e19c1596 100644 --- a/src/messaging/tests/TestExchangeMgr.cpp +++ b/src/messaging/tests/TestExchangeMgr.cpp @@ -94,7 +94,7 @@ void CheckNewContextTest(nlTestSuite * inSuite, void * inContext) TestContext & ctx = *reinterpret_cast(inContext); MockAppDelegate mockAppDelegate; - ExchangeContext * ec1 = ctx.NewExchangeToLocal(&mockAppDelegate); + ExchangeContext * ec1 = ctx.NewSecureExchangeToLocal(&mockAppDelegate); NL_TEST_ASSERT(inSuite, ec1 != nullptr); NL_TEST_ASSERT(inSuite, ec1->IsInitiator() == true); NL_TEST_ASSERT(inSuite, ec1->GetExchangeId() != 0); @@ -103,7 +103,7 @@ void CheckNewContextTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, sessionPeerToLocal->GetPeerKeyID() == ctx.GetLocalKeyId()); NL_TEST_ASSERT(inSuite, ec1->GetDelegate() == &mockAppDelegate); - ExchangeContext * ec2 = ctx.NewExchangeToPeer(&mockAppDelegate); + ExchangeContext * ec2 = ctx.NewSecureExchangeToPeer(&mockAppDelegate); NL_TEST_ASSERT(inSuite, ec2 != nullptr); NL_TEST_ASSERT(inSuite, ec2->GetExchangeId() > ec1->GetExchangeId()); auto sessionLocalToPeer = ctx.GetSecureSessionManager().GetPeerConnectionState(ec2->GetSecureSession()); @@ -148,7 +148,7 @@ void CheckExchangeMessages(nlTestSuite * inSuite, void * inContext) // create solicited exchange MockAppDelegate mockSolicitedAppDelegate; - ExchangeContext * ec1 = ctx.NewExchangeToPeer(&mockSolicitedAppDelegate); + ExchangeContext * ec1 = ctx.NewSecureExchangeToPeer(&mockSolicitedAppDelegate); // create unsolicited exchange MockAppDelegate mockUnsolicitedAppDelegate; diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 0ccca4c76e8ede..327381df6cfda0 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -108,7 +108,7 @@ void CheckAddClearRetrans(nlTestSuite * inSuite, void * inContext) TestContext & ctx = *reinterpret_cast(inContext); MockAppDelegate mockAppDelegate; - ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockAppDelegate); + ExchangeContext * exchange = ctx.NewSecureExchangeToPeer(&mockAppDelegate); NL_TEST_ASSERT(inSuite, exchange != nullptr); ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); @@ -133,7 +133,7 @@ void CheckFailRetrans(nlTestSuite * inSuite, void * inContext) ctx.GetInetLayer().SystemLayer()->Init(nullptr); MockAppDelegate mockAppDelegate; - ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockAppDelegate); + ExchangeContext * exchange = ctx.NewSecureExchangeToPeer(&mockAppDelegate); NL_TEST_ASSERT(inSuite, exchange != nullptr); ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); @@ -166,7 +166,7 @@ void CheckResendMessage(nlTestSuite * inSuite, void * inContext) MockAppDelegate mockSender; // TODO: temprary create a SecureSessionHandle from node id, will be fix in PR 3602 - ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender); + ExchangeContext * exchange = ctx.NewSecureExchangeToPeer(&mockSender); NL_TEST_ASSERT(inSuite, exchange != nullptr); ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); @@ -205,7 +205,7 @@ void CheckSendStandaloneAckMessage(nlTestSuite * inSuite, void * inContext) ctx.GetInetLayer().SystemLayer()->Init(nullptr); MockAppDelegate mockAppDelegate; - ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockAppDelegate); + ExchangeContext * exchange = ctx.NewSecureExchangeToPeer(&mockAppDelegate); NL_TEST_ASSERT(inSuite, exchange != nullptr); ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); diff --git a/src/protocols/echo/EchoClient.cpp b/src/protocols/echo/EchoClient.cpp index 6feeafeca94e80..fe7c8c17bc76cb 100644 --- a/src/protocols/echo/EchoClient.cpp +++ b/src/protocols/echo/EchoClient.cpp @@ -68,7 +68,7 @@ CHIP_ERROR EchoClient::SendEchoRequest(System::PacketBufferHandle && payload, co } // Create a new exchange context. - mExchangeCtx = mExchangeMgr->NewContext(mSecureSession, this); + mExchangeCtx = mExchangeMgr->NewSecureContext(mSecureSession, this); if (mExchangeCtx == nullptr) { return CHIP_ERROR_NO_MEMORY; diff --git a/src/protocols/secure_channel/BUILD.gn b/src/protocols/secure_channel/BUILD.gn index 79be1cbd38accc..8e096fbf8ee3af 100644 --- a/src/protocols/secure_channel/BUILD.gn +++ b/src/protocols/secure_channel/BUILD.gn @@ -9,8 +9,6 @@ static_library("secure_channel") { "PASESession.cpp", "PASESession.h", "RendezvousParameters.h", - "SessionEstablishmentExchangeDispatch.cpp", - "SessionEstablishmentExchangeDispatch.h", "StatusReport.cpp", "StatusReport.h", ] diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 043add497812f4..a992a4857dd6d2 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -1155,8 +1155,6 @@ void CASESession::OnMessageReceived(ExchangeContext * ec, const PacketHeader & p SuccessOrExit(err); } - mConnectionState.SetPeerAddress(mMessageDispatch.GetPeerAddress()); - switch (static_cast(payloadHeader.GetMessageType())) { case Protocols::SecureChannel::MsgType::CASE_SigmaR1: diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 0f4d8979ca5901..19ca50fe19efba 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -34,7 +34,6 @@ #include #include #include -#include #include #include #include @@ -71,7 +70,7 @@ struct CASESessionSerializable uint16_t mPeerKeyId; }; -class DLL_EXPORT CASESession : public Messaging::ExchangeDelegateBase, public PairingSession +class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public PairingSession { public: CASESession(); @@ -176,17 +175,10 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegateBase, public Pa **/ CHIP_ERROR FromSerializable(const CASESessionSerializable & output); - SessionEstablishmentExchangeDispatch & MessageDispatch() { return mMessageDispatch; } - //// ExchangeDelegate Implementation //// void OnMessageReceived(Messaging::ExchangeContext * ec, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, System::PacketBufferHandle payload) override; void OnResponseTimeout(Messaging::ExchangeContext * ec) override; - Messaging::ExchangeMessageDispatch * GetMessageDispatch(Messaging::ReliableMessageMgr * rmMgr, - SecureSessionMgr * sessionMgr) override - { - return &mMessageDispatch; - } private: enum SigmaErrorType : uint8_t @@ -253,7 +245,6 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegateBase, public Pa uint8_t mRemoteIPK[kIPKSize]; Messaging::ExchangeContext * mExchangeCtxt = nullptr; - SessionEstablishmentExchangeDispatch mMessageDispatch; struct SigmaErrorMsg { diff --git a/src/protocols/secure_channel/MessageCounterManager.cpp b/src/protocols/secure_channel/MessageCounterManager.cpp index ca8195153fa4da..930e9ddd168f7e 100644 --- a/src/protocols/secure_channel/MessageCounterManager.cpp +++ b/src/protocols/secure_channel/MessageCounterManager.cpp @@ -178,7 +178,7 @@ CHIP_ERROR MessageCounterManager::SendMsgCounterSyncReq(SecureSessionHandle sess System::PacketBufferHandle msgBuf; Messaging::SendFlags sendFlags; - exchangeContext = mExchangeMgr->NewContext(session, this); + exchangeContext = mExchangeMgr->NewSecureContext(session, this); VerifyOrExit(exchangeContext != nullptr, err = CHIP_ERROR_NO_MEMORY); msgBuf = MessagePacketBuffer::New(kChallengeSize); diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 65706c3bd555b9..9c658ece19b456 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -765,8 +765,6 @@ void PASESession::OnMessageReceived(ExchangeContext * ec, const PacketHeader & p mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout); } - mConnectionState.SetPeerAddress(mMessageDispatch.GetPeerAddress()); - switch (static_cast(payloadHeader.GetMessageType())) { case Protocols::SecureChannel::MsgType::PBKDFParamRequest: diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index f65df90186a99c..f75aa2109053fa 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -32,9 +32,7 @@ #endif #include #include -#include #include -#include #include #include #include @@ -68,7 +66,7 @@ struct PASESessionSerializable typedef uint8_t PASEVerifier[2][kSpake2p_WS_Length]; -class DLL_EXPORT PASESession : public Messaging::ExchangeDelegateBase, public PairingSession +class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public PairingSession { public: PASESession(); @@ -199,8 +197,6 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegateBase, public Pa **/ void Clear(); - SessionEstablishmentExchangeDispatch & MessageDispatch() { return mMessageDispatch; } - //// ExchangeDelegate Implementation //// /** * @brief @@ -228,12 +224,6 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegateBase, public Pa */ void OnResponseTimeout(Messaging::ExchangeContext * ec) override; - Messaging::ExchangeMessageDispatch * GetMessageDispatch(Messaging::ReliableMessageMgr * rmMgr, - SecureSessionMgr * sessionMgr) override - { - return &mMessageDispatch; - } - private: enum Spake2pErrorType : uint8_t { @@ -288,8 +278,6 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegateBase, public Pa Messaging::ExchangeContext * mExchangeCtxt = nullptr; - SessionEstablishmentExchangeDispatch mMessageDispatch; - struct Spake2pErrorMsg { Spake2pErrorType error; diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp deleted file mode 100644 index e0b38323325c2b..00000000000000 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ /dev/null @@ -1,86 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file - * This file provides implementation of Application Channel class. - */ - -#include -#include -#include - -namespace chip { - -using namespace Messaging; - -CHIP_ERROR SessionEstablishmentExchangeDispatch::SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) -{ - PacketHeader packetHeader; - - ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); - ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); - - if (mTransportMgr != nullptr) - { - return mTransportMgr->SendMessage(mPeerAddress, std::move(message)); - } - - return CHIP_ERROR_INCORRECT_STATE; -} - -CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, - const Transport::PeerAddress & peerAddress, - ReliableMessageContext * reliableMessageContext) -{ - mPeerAddress = peerAddress; - return ExchangeMessageDispatch::OnMessageReceived(payloadHeader, messageId, peerAddress, reliableMessageContext); -} - -bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type) -{ - switch (protocol) - { - case Protocols::SecureChannel::Id.GetProtocolId(): - switch (type) - { - case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamRequest): - case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamResponse): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p1): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p2): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p3): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2pError): - case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR1): - case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR2): - case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR3): - case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaErr): - return true; - - default: - break; - } - break; - - default: - break; - } - return false; -} - -} // namespace chip diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h deleted file mode 100644 index b222d7b318ab7b..00000000000000 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ /dev/null @@ -1,71 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file - * This file defines Application Channel class. The object of this - * class can be used by CHIP data model cluster applications to send - * and receive messages. The messages are encrypted using session keys. - */ - -#pragma once - -#include -#include - -namespace chip { - -class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDispatch -{ -public: - SessionEstablishmentExchangeDispatch() {} - - virtual ~SessionEstablishmentExchangeDispatch() {} - - CHIP_ERROR Init(TransportMgrBase * transportMgr) - { - ReturnErrorCodeIf(transportMgr == nullptr, CHIP_ERROR_INVALID_ARGUMENT); - mTransportMgr = transportMgr; - return CHIP_NO_ERROR; - } - - CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, - const Transport::PeerAddress & peerAddress, - Messaging::ReliableMessageContext * reliableMessageContext) override; - - const Transport::PeerAddress & GetPeerAddress() const { return mPeerAddress; } - - void SetPeerAddress(const Transport::PeerAddress & address) { mPeerAddress = address; } - -protected: - CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) override; - - bool MessagePermitted(uint16_t protocol, uint8_t type) override; - - bool IsReliableTransmissionAllowed() override - { - // If the underlying transport is UDP. - return (mPeerAddress.GetTransportType() == Transport::Type::kUdp); - } - -private: - TransportMgrBase * mTransportMgr = nullptr; - Transport::PeerAddress mPeerAddress; -}; - -} // namespace chip diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 4c7dc4096f83b1..793f2f78e11178 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -120,8 +120,7 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) TestCASESecurePairingDelegate delegate; CASESession pairing; - NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - ExchangeContext * context = ctx.NewExchangeToLocal(&pairing); + ExchangeContext * context = ctx.NewUnsecureExchange(&pairing); NL_TEST_ASSERT(inSuite, pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &commissionerDevOpCred, 2, 0, nullptr, @@ -135,11 +134,10 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; CASESession pairing1; - NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); gLoopback.mSentMessageCount = 0; gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; - ExchangeContext * context1 = ctx.NewExchangeToLocal(&pairing1); + ExchangeContext * context1 = ctx.NewUnsecureExchange(&pairing1); NL_TEST_ASSERT(inSuite, pairing1.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &commissionerDevOpCred, 2, 0, context1, @@ -159,14 +157,12 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte CASESessionSerializable serializableAccessory; gLoopback.mSentMessageCount = 0; - NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType( Protocols::SecureChannel::MsgType::CASE_SigmaR1, &pairingAccessory) == CHIP_NO_ERROR); - ExchangeContext * contextCommissioner = ctx.NewExchangeToLocal(&pairingCommissioner); + ExchangeContext * contextCommissioner = ctx.NewUnsecureExchange(&pairingCommissioner); NL_TEST_ASSERT(inSuite, pairingAccessory.WaitForSessionEstablishment(&accessoryDevOpCred, 0, &delegateAccessory) == CHIP_NO_ERROR); diff --git a/src/protocols/secure_channel/tests/TestMessageCounterManager.cpp b/src/protocols/secure_channel/tests/TestMessageCounterManager.cpp index b50a293e7d6c89..9f9bd76b9f0660 100644 --- a/src/protocols/secure_channel/tests/TestMessageCounterManager.cpp +++ b/src/protocols/secure_channel/tests/TestMessageCounterManager.cpp @@ -123,7 +123,7 @@ void CheckReceiveMessage(nlTestSuite * inSuite, void * inContext) System::PacketBufferHandle msgBuf = MessagePacketBuffer::NewWithData(PAYLOAD, payload_len); NL_TEST_ASSERT(inSuite, !msgBuf.IsNull()); - Messaging::ExchangeContext * ec = ctx.NewExchangeToPeer(nullptr); + Messaging::ExchangeContext * ec = ctx.NewSecureExchangeToPeer(nullptr); NL_TEST_ASSERT(inSuite, ec != nullptr); err = ec->SendMessage(chip::Protocols::Echo::MsgType::EchoRequest, std::move(msgBuf), diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index faa2698efc43f7..1539a96dca82e1 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -106,8 +106,7 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) PASESession pairing; - NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - ExchangeContext * context = ctx.NewExchangeToLocal(&pairing); + ExchangeContext * context = ctx.NewUnsecureExchange(&pairing); NL_TEST_ASSERT(inSuite, pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, nullptr, nullptr) != CHIP_NO_ERROR); @@ -120,8 +119,7 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; PASESession pairing1; - NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - ExchangeContext * context1 = ctx.NewExchangeToLocal(&pairing1); + ExchangeContext * context1 = ctx.NewUnsecureExchange(&pairing1); NL_TEST_ASSERT(inSuite, pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, context1, &delegate) == CHIP_ERROR_BAD_REQUEST); @@ -138,14 +136,11 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P gLoopback.mSentMessageCount = 0; - NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType( Protocols::SecureChannel::MsgType::PBKDFParamRequest, &pairingAccessory) == CHIP_NO_ERROR); - ExchangeContext * contextCommissioner = ctx.NewExchangeToLocal(&pairingCommissioner); + ExchangeContext * contextCommissioner = ctx.NewUnsecureExchange(&pairingCommissioner); NL_TEST_ASSERT(inSuite, pairingAccessory.WaitForPairing(1234, 500, (const uint8_t *) "saltSALT", 8, 0, &delegateAccessory) == diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 689489c4533014..f637570c6d7814 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -313,7 +313,7 @@ void SecureSessionMgr::MessageDispatch(const PacketHeader & packetHeader, const { PayloadHeader payloadHeader; ReturnOnFailure(payloadHeader.DecodeAndConsume(msg)); - mCB->OnMessageReceived(packetHeader, payloadHeader, SecureSessionHandle(), peerAddress, std::move(msg), this); + mCB->OnUnsecureMessageReceived(packetHeader, payloadHeader, peerAddress, std::move(msg), this); } } @@ -462,7 +462,7 @@ void SecureSessionMgr::SecureMessageDispatch(const PacketHeader & packetHeader, if (mCB != nullptr) { SecureSessionHandle session(state->GetPeerNodeId(), state->GetPeerKeyID(), state->GetAdminId()); - mCB->OnMessageReceived(packetHeader, payloadHeader, session, peerAddress, std::move(msg), this); + mCB->OnSecureMessageReceived(packetHeader, payloadHeader, session, peerAddress, std::move(msg), this); } exit: diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index e111d3e50695a0..1ffe4b337396f8 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -62,6 +62,7 @@ class EncryptedPacketBufferHandle final : private System::PacketBufferHandle void operator=(EncryptedPacketBufferHandle && aBuffer) { PacketBufferHandle::operator=(std::move(aBuffer)); } uint32_t GetMsgId() const; + bool IsNull() const { return PacketBufferHandle::IsNull(); } /** * Creates a copy of the data in this packet. @@ -123,9 +124,25 @@ class DLL_EXPORT SecureSessionMgrDelegate * @param msgBuf The received message * @param mgr A pointer to the SecureSessionMgr */ - virtual void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, - SecureSessionHandle session, const Transport::PeerAddress & source, - System::PacketBufferHandle msgBuf, SecureSessionMgr * mgr) + virtual void OnSecureMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, + SecureSessionHandle session, const Transport::PeerAddress & source, + System::PacketBufferHandle msgBuf, SecureSessionMgr * mgr) + {} + + /** + * @brief + * Called when a new message is received. The function must internally release the + * msgBuf after processing it. + * + * @param packetHeader The message header + * @param payloadHeader The payload header + * @param source The sender's address + * @param msgBuf The received message + * @param mgr A pointer to the SecureSessionMgr + */ + virtual void OnUnsecureMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, + const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf, + SecureSessionMgr * mgr) {} /** diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp index 821d32c8a82d0c..2931d06d227a1b 100644 --- a/src/transport/tests/TestSecureSessionMgr.cpp +++ b/src/transport/tests/TestSecureSessionMgr.cpp @@ -93,9 +93,9 @@ class OutgoingTransport : public Transport::Base class TestSessMgrCallback : public SecureSessionMgrDelegate { public: - void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, SecureSessionHandle session, - const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf, - SecureSessionMgr * mgr) override + void OnSecureMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, SecureSessionHandle session, + const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf, + SecureSessionMgr * mgr) override { NL_TEST_ASSERT(mSuite, header.GetSourceNodeId() == Optional::Value(kSourceNodeId)); NL_TEST_ASSERT(mSuite, header.GetDestinationNodeId() == Optional::Value(kDestinationNodeId));