From 79bedcf92139ad8b6204d1440f315cae0eb4b23b Mon Sep 17 00:00:00 2001 From: Marc Lepage <67919234+mlepage-google@users.noreply.github.com> Date: Wed, 3 May 2023 14:58:34 -0400 Subject: [PATCH] Cleanup in aisle CASESession (#26339) * Cleanup in aisle CASESession * Reduce nesting in function --- src/credentials/FabricTable.cpp | 2 +- src/credentials/FabricTable.h | 4 +- src/protocols/secure_channel/CASESession.cpp | 118 ++++++++++--------- 3 files changed, 64 insertions(+), 60 deletions(-) diff --git a/src/credentials/FabricTable.cpp b/src/credentials/FabricTable.cpp index 63f7afebd7d0ba..6d9f4781fff576 100644 --- a/src/credentials/FabricTable.cpp +++ b/src/credentials/FabricTable.cpp @@ -799,7 +799,7 @@ FabricTable::AddOrUpdateInner(FabricIndex fabricIndex, bool isAddition, Crypto:: } else { - // Initialization for Upating fabric: setting up a shadow fabricInfo + // Initialization for Updating fabric: setting up a shadow fabricInfo const FabricInfo * existingFabric = FindFabricWithIndex(fabricIndex); VerifyOrReturnError(existingFabric != nullptr, CHIP_ERROR_INTERNAL); diff --git a/src/credentials/FabricTable.h b/src/credentials/FabricTable.h index 1b01e1f4fd4a4f..82a82521d58be2 100644 --- a/src/credentials/FabricTable.h +++ b/src/credentials/FabricTable.h @@ -114,7 +114,7 @@ class DLL_EXPORT FabricInfo friend class FabricTable; -protected: +private: struct InitParams { NodeId nodeId = kUndefinedNodeId; @@ -1098,7 +1098,7 @@ class DLL_EXPORT FabricTable */ const FabricInfo * GetShadowPendingFabricEntry() const { return HasPendingFabricUpdate() ? &mPendingFabric : nullptr; } - // Returns true if we have a shadow entry pending for a fabruc update. + // Returns true if we have a shadow entry pending for a fabric update. bool HasPendingFabricUpdate() const { return mPendingFabric.IsInitialized() && diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 2e2fcda957b391..a8b0ac4e459a2f 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -150,18 +150,19 @@ class CASESession::WorkHelper // The `status` value is the result of the work callback (called beforehand). typedef CHIP_ERROR (CASESession::*AfterWorkCallback)(DATA & data, CHIP_ERROR status); - // Create a work helper using the specified session, work callback, after work callback, and data (template arg). - // Lifetime is not managed, see `Create` for that option. - WorkHelper(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) : - mSession(&session), mWorkCallback(workCallback), mAfterWorkCallback(afterWorkCallback) - {} - +public: // Create a work helper using the specified session, work callback, after work callback, and data (template arg). // Lifetime is managed by sharing between the caller (typically the session) and the helper itself (while work is scheduled). static Platform::SharedPtr Create(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) { - auto ptr = Platform::MakeShared(session, workCallback, afterWorkCallback); + struct EnableShared : public WorkHelper + { + EnableShared(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) : + WorkHelper(session, workCallback, afterWorkCallback) + {} + }; + auto ptr = Platform::MakeShared(session, workCallback, afterWorkCallback); if (ptr) { ptr->mWeakPtr = ptr; // used by `ScheduleWork` @@ -173,10 +174,7 @@ class CASESession::WorkHelper // No scheduling, no outstanding work, no shared lifetime management. CHIP_ERROR DoWork() { - if (!mSession || !mWorkCallback || !mAfterWorkCallback) - { - return CHIP_ERROR_INCORRECT_STATE; - } + VerifyOrReturnError(mSession && mWorkCallback && mAfterWorkCallback, CHIP_ERROR_INCORRECT_STATE); auto * helper = this; bool cancel = false; helper->mStatus = helper->mWorkCallback(helper->mData, cancel); @@ -187,18 +185,17 @@ class CASESession::WorkHelper return helper->mStatus; } - // Schedule the work after configuring the data. + // Schedule the work for later execution. // If lifetime is managed, the helper shares management while work is outstanding. CHIP_ERROR ScheduleWork() { - if (!mSession || !mWorkCallback || !mAfterWorkCallback) - { - return CHIP_ERROR_INCORRECT_STATE; - } + VerifyOrReturnError(mSession && mWorkCallback && mAfterWorkCallback, CHIP_ERROR_INCORRECT_STATE); + // Hold strong ptr while work is outstanding mStrongPtr = mWeakPtr.lock(); // set in `Create` auto status = DeviceLayer::PlatformMgr().ScheduleBackgroundWork(WorkHandler, reinterpret_cast(this)); if (status != CHIP_NO_ERROR) { + // Release strong ptr since scheduling failed mStrongPtr.reset(); } return status; @@ -207,32 +204,47 @@ class CASESession::WorkHelper // Cancel the work, by clearing the associated session. void CancelWork() { mSession.store(nullptr); } + bool IsCancelled() const { return mSession.load() == nullptr; } + private: + // Create a work helper using the specified session, work callback, after work callback, and data (template arg). + // Lifetime is not managed, see `Create` for that option. + WorkHelper(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) : + mSession(&session), mWorkCallback(workCallback), mAfterWorkCallback(afterWorkCallback) + {} + // Handler for the work callback. static void WorkHandler(intptr_t arg) { auto * helper = reinterpret_cast(arg); - bool cancel = false; - VerifyOrExit(helper->mSession.load(), ;); // cancelled by `CancelWork`? + // Hold strong ptr while work is handled + auto strongPtr(std::move(helper->mStrongPtr)); + VerifyOrReturn(!helper->IsCancelled()); + bool cancel = false; + // Execute callback in background thread; data must be OK with this helper->mStatus = helper->mWorkCallback(helper->mData, cancel); - VerifyOrExit(!cancel, ;); // canceled by `mWorkCallback`? - VerifyOrExit(helper->mSession.load(), ;); // cancelled by `CancelWork`? - SuccessOrExit(DeviceLayer::PlatformMgr().ScheduleWork(AfterWorkHandler, reinterpret_cast(helper))); - return; - exit: - helper->mStrongPtr.reset(); + VerifyOrReturn(!cancel && !helper->IsCancelled()); + // Hold strong ptr while work is outstanding + helper->mStrongPtr.swap(strongPtr); + auto status = DeviceLayer::PlatformMgr().ScheduleWork(AfterWorkHandler, reinterpret_cast(helper)); + if (status != CHIP_NO_ERROR) + { + // Release strong ptr since scheduling failed + helper->mStrongPtr.reset(); + } } // Handler for the after work callback. static void AfterWorkHandler(intptr_t arg) { - // Since this runs in the main Matter thread, the session shouldn't be otherwise used (messages, timers, etc.) auto * helper = reinterpret_cast(arg); + // Hold strong ptr while work is handled + auto strongPtr(std::move(helper->mStrongPtr)); if (auto * session = helper->mSession.load()) { + // Execute callback in Matter thread; session should be OK with this (session->*(helper->mAfterWorkCallback))(helper->mData, helper->mStatus); } - helper->mStrongPtr.reset(); } private: @@ -261,7 +273,7 @@ class CASESession::WorkHelper struct CASESession::SendSigma3Data { - std::atomic fabricIndex; + FabricIndex fabricIndex; // Use one or the other const FabricTable * fabricTable; @@ -319,7 +331,6 @@ void CASESession::Clear() // Cancel any outstanding work. if (mSendSigma3Helper) { - mSendSigma3Helper->mData.fabricIndex = kUndefinedFabricIndex; mSendSigma3Helper->CancelWork(); mSendSigma3Helper.reset(); } @@ -1359,40 +1370,37 @@ CHIP_ERROR CASESession::SendSigma3a() CHIP_ERROR CASESession::SendSigma3b(SendSigma3Data & data, bool & cancel) { - CHIP_ERROR err = CHIP_NO_ERROR; - // Generate a signature if (data.keystore != nullptr) { // Recommended case: delegate to operational keystore - err = data.keystore->SignWithOpKeypair(data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len }, - data.tbsData3Signature); + ReturnErrorOnFailure(data.keystore->SignWithOpKeypair( + data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len }, data.tbsData3Signature)); } else { // Legacy case: delegate to fabric table fabric info - err = data.fabricTable->SignWithOpKeypair(data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len }, - data.tbsData3Signature); + ReturnErrorOnFailure(data.fabricTable->SignWithOpKeypair( + data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len }, data.tbsData3Signature)); } - SuccessOrExit(err); // Prepare Sigma3 TBE Data Blob data.msg_r3_encrypted_len = TLV::EstimateStructOverhead(data.nocCert.size(), data.icaCert.size(), data.tbsData3Signature.Length()); - VerifyOrExit(data.msg_R3_Encrypted.Alloc(data.msg_r3_encrypted_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES), - err = CHIP_ERROR_NO_MEMORY); + VerifyOrReturnError(data.msg_R3_Encrypted.Alloc(data.msg_r3_encrypted_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES), + CHIP_ERROR_NO_MEMORY); { TLV::TLVWriter tlvWriter; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; tlvWriter.Init(data.msg_R3_Encrypted.Get(), data.msg_r3_encrypted_len); - SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); - SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderNOC), data.nocCert)); + ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderNOC), data.nocCert)); if (!data.icaCert.empty()) { - SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderICAC), data.icaCert)); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderICAC), data.icaCert)); } // We are now done with ICAC and NOC certs so we can release the memory. @@ -1404,15 +1412,14 @@ CHIP_ERROR CASESession::SendSigma3b(SendSigma3Data & data, bool & cancel) data.nocCert = MutableByteSpan{}; } - SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_Signature), data.tbsData3Signature.ConstBytes(), - static_cast(data.tbsData3Signature.Length()))); - SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); - SuccessOrExit(err = tlvWriter.Finalize()); + ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_Signature), data.tbsData3Signature.ConstBytes(), + static_cast(data.tbsData3Signature.Length()))); + ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); + ReturnErrorOnFailure(tlvWriter.Finalize()); data.msg_r3_encrypted_len = static_cast(tlvWriter.GetLengthWritten()); } -exit: - return err; + return CHIP_NO_ERROR; } CHIP_ERROR CASESession::SendSigma3c(SendSigma3Data & data, CHIP_ERROR status) @@ -1650,17 +1657,15 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) CHIP_ERROR CASESession::HandleSigma3b(HandleSigma3Data & data, bool & cancel) { - CHIP_ERROR err = CHIP_NO_ERROR; - // Step 5/6 // Validate initiator identity located in msg->Start() // Constructing responder identity CompressedFabricId unused; FabricId initiatorFabricId; P256PublicKey initiatorPublicKey; - SuccessOrExit(err = FabricTable::VerifyCredentials(data.initiatorNOC, data.initiatorICAC, data.fabricRCAC, data.validContext, - unused, initiatorFabricId, data.initiatorNodeId, initiatorPublicKey)); - VerifyOrExit(data.fabricId == initiatorFabricId, err = CHIP_ERROR_INVALID_CASE_PARAMETER); + ReturnErrorOnFailure(FabricTable::VerifyCredentials(data.initiatorNOC, data.initiatorICAC, data.fabricRCAC, data.validContext, + unused, initiatorFabricId, data.initiatorNodeId, initiatorPublicKey)); + VerifyOrReturnError(data.fabricId == initiatorFabricId, CHIP_ERROR_INVALID_CASE_PARAMETER); // TODO - Validate message signature prior to validating the received operational credentials. // The op cert check requires traversal of cert chain, that is a more expensive operation. @@ -1672,16 +1677,15 @@ CHIP_ERROR CASESession::HandleSigma3b(HandleSigma3Data & data, bool & cancel) { P256PublicKeyHSM initiatorPublicKeyHSM; memcpy(Uint8::to_uchar(initiatorPublicKeyHSM), initiatorPublicKey.Bytes(), initiatorPublicKey.Length()); - SuccessOrExit(err = initiatorPublicKeyHSM.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len, - data.tbsData3Signature)); + ReturnErrorOnFailure(initiatorPublicKeyHSM.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len, + data.tbsData3Signature)); } #else - SuccessOrExit(err = initiatorPublicKey.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len, - data.tbsData3Signature)); + ReturnErrorOnFailure( + initiatorPublicKey.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len, data.tbsData3Signature)); #endif -exit: - return err; + return CHIP_NO_ERROR; } CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Data & data, CHIP_ERROR status)