diff --git a/src/app/OperationalSessionSetup.cpp b/src/app/OperationalSessionSetup.cpp index 2996cf2dd891d4..179e2a3df5e120 100644 --- a/src/app/OperationalSessionSetup.cpp +++ b/src/app/OperationalSessionSetup.cpp @@ -92,7 +92,8 @@ bool OperationalSessionSetup::AttachToExistingSecureSession() } void OperationalSessionSetup::Connect(Callback::Callback * onConnection, - Callback::Callback * onFailure) + Callback::Callback * onFailure, + Callback::Callback * onSetupFailure) { CHIP_ERROR err = CHIP_NO_ERROR; bool isConnected = false; @@ -102,7 +103,7 @@ void OperationalSessionSetup::Connect(Callback::Callback * on // If anything goes wrong below, we'll trigger failures (including any queued from // a previous iteration which in theory shouldn't happen, but this is written to be more defensive) // - EnqueueConnectionCallbacks(onConnection, onFailure); + EnqueueConnectionCallbacks(onConnection, onFailure, onSetupFailure); switch (mState) { @@ -178,6 +179,18 @@ void OperationalSessionSetup::Connect(Callback::Callback * on } } +void OperationalSessionSetup::Connect(Callback::Callback * onConnection, + Callback::Callback * onFailure) +{ + Connect(onConnection, onFailure, nullptr); +} + +void OperationalSessionSetup::Connect(Callback::Callback * onConnection, + Callback::Callback * onSetupFailure) +{ + Connect(onConnection, nullptr, onSetupFailure); +} + void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & addr, const ReliableMessageProtocolConfig & config) { #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES @@ -291,7 +304,8 @@ CHIP_ERROR OperationalSessionSetup::EstablishConnection(const ReliableMessagePro } void OperationalSessionSetup::EnqueueConnectionCallbacks(Callback::Callback * onConnection, - Callback::Callback * onFailure) + Callback::Callback * onFailure, + Callback::Callback * onSetupFailure) { if (onConnection != nullptr) { @@ -302,11 +316,17 @@ void OperationalSessionSetup::EnqueueConnectionCallbacks(Callback::CallbackCancel()); } + + if (onSetupFailure != nullptr) + { + mSetupFailure.Enqueue(onSetupFailure->Cancel()); + } } -void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, ReleaseBehavior releaseBehavior) +void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, SessionEstablishmentStage stage, + ReleaseBehavior releaseBehavior) { - Cancelable failureReady, successReady; + Cancelable failureReady, setupFailureReady, successReady; // // Dequeue both failure and success callback lists into temporary stack args before invoking either of them. @@ -314,6 +334,7 @@ void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, Relea // since the callee may destroy this object as part of that callback. // mConnectionFailure.DequeueAll(failureReady); + mSetupFailure.DequeueAll(setupFailureReady); mConnectionSuccess.DequeueAll(successReady); #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES @@ -339,13 +360,14 @@ void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, Relea // DO NOT touch any members of this object after this point. It's dead. - NotifyConnectionCallbacks(failureReady, successReady, error, peerId, performingAddressUpdate, exchangeMgr, - optionalSessionHandle); + NotifyConnectionCallbacks(failureReady, setupFailureReady, successReady, error, stage, peerId, performingAddressUpdate, + exchangeMgr, optionalSessionHandle); } -void OperationalSessionSetup::NotifyConnectionCallbacks(Cancelable & failureReady, Cancelable & successReady, CHIP_ERROR error, - const ScopedNodeId & peerId, bool performingAddressUpdate, - Messaging::ExchangeManager * exchangeMgr, +void OperationalSessionSetup::NotifyConnectionCallbacks(Cancelable & failureReady, Cancelable & setupFailureReady, + Cancelable & successReady, CHIP_ERROR error, + SessionEstablishmentStage stage, const ScopedNodeId & peerId, + bool performingAddressUpdate, Messaging::ExchangeManager * exchangeMgr, const Optional & optionalSessionHandle) { // @@ -367,6 +389,22 @@ void OperationalSessionSetup::NotifyConnectionCallbacks(Cancelable & failureRead } } + while (setupFailureReady.mNext != &setupFailureReady) + { + // We expect that we only have callbacks if we are not performing just address update. + VerifyOrDie(!performingAddressUpdate); + Callback::Callback * cb = Callback::Callback::FromCancelable(setupFailureReady.mNext); + + cb->Cancel(); + + if (error != CHIP_NO_ERROR) + { + // Initialize the ConnnectionFailureInfo object + ConnnectionFailureInfo failureInfo(peerId, error, stage); + cb->mCall(cb->mContext, failureInfo); + } + } + while (successReady.mNext != &successReady) { // We expect that we only have callbacks if we are not performing just address update. @@ -383,7 +421,7 @@ void OperationalSessionSetup::NotifyConnectionCallbacks(Cancelable & failureRead } } -void OperationalSessionSetup::OnSessionEstablishmentError(CHIP_ERROR error) +void OperationalSessionSetup::OnSessionEstablishmentError(CHIP_ERROR error, SessionEstablishmentStage stage) { VerifyOrReturn(mState == State::Connecting, ChipLogError(Discovery, "OnSessionEstablishmentError was called while we were not connecting")); @@ -438,7 +476,7 @@ void OperationalSessionSetup::OnSessionEstablishmentError(CHIP_ERROR error) #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES } - DequeueConnectionCallbacks(error); + DequeueConnectionCallbacks(error, stage); // Do not touch `this` instance anymore; it has been destroyed in DequeueConnectionCallbacks. } diff --git a/src/app/OperationalSessionSetup.h b/src/app/OperationalSessionSetup.h index 2925259066e6db..45b571a08aa633 100644 --- a/src/app/OperationalSessionSetup.h +++ b/src/app/OperationalSessionSetup.h @@ -155,6 +155,19 @@ typedef void (*OnDeviceConnectionRetry)(void * context, const ScopedNodeId & pee class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, public AddressResolve::NodeListener { public: + struct ConnnectionFailureInfo + { + const ScopedNodeId peerId; + CHIP_ERROR error; + SessionEstablishmentStage sessionStage; + + ConnnectionFailureInfo(const ScopedNodeId & peer, CHIP_ERROR err, SessionEstablishmentStage stage) : + peerId(peer), error(err), sessionStage(stage) + {} + }; + + using OnSetupFailure = void (*)(void * context, const ConnnectionFailureInfo & failureInfo); + ~OperationalSessionSetup() override; OperationalSessionSetup(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool, ScopedNodeId peerId, @@ -180,8 +193,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * The device is expected to have been commissioned, A CASE session * setup will be triggered. * - * On establishing the session, the callback function `onConnection` will be called. If the - * session setup fails, `onFailure` will be called. + * If session setup succeeds, the callback function `onConnection` will be called. + * If session setup fails, `onFailure` will be called. * * If the session already exists, `onConnection` will be called immediately, * before the Connect call returns. @@ -192,11 +205,28 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, */ void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure); + /* + * This function can be called to establish a secure session with the device. + * + * The device is expected to have been commissioned, A CASE session + * setup will be triggered. + * + * If session setup succeeds, the callback function `onConnection` will be called. + * If session setup fails, `onSetupFailure` will be called. + * + * If the session already exists, `onConnection` will be called immediately, + * before the Connect call returns. + * + * `onSetupFailure` may be called before the Connect call returns, for error cases that are detected synchronously + * (e.g. inability to start an address lookup). + */ + void Connect(Callback::Callback * onConnection, Callback::Callback * onSetupFailure); + bool IsForAddressUpdate() const { return mPerformingAddressUpdate; } //////////// SessionEstablishmentDelegate Implementation /////////////// void OnSessionEstablished(const SessionHandle & session) override; - void OnSessionEstablishmentError(CHIP_ERROR error) override; + void OnSessionEstablishmentError(CHIP_ERROR error, SessionEstablishmentStage stage) override; ScopedNodeId GetPeerId() const { return mPeerId; } @@ -264,6 +294,7 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, Callback::CallbackDeque mConnectionSuccess; Callback::CallbackDeque mConnectionFailure; + Callback::CallbackDeque mSetupFailure; OperationalSessionReleaseDelegate * mReleaseDelegate; @@ -306,8 +337,12 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, void CleanupCASEClient(); + void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, + Callback::Callback * onSetupFailure); + void EnqueueConnectionCallbacks(Callback::Callback * onConnection, - Callback::Callback * onFailure); + Callback::Callback * onFailure, + Callback::Callback * onSetupFailure); enum class ReleaseBehavior { @@ -316,11 +351,13 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, }; /* - * This dequeues all failure and success callbacks and appropriately - * invokes either set depending on the value of error. + * This dequeues all failure and success callbacks and appropriately invokes either set depending + * on the value of error. + * + * If error == CHIP_NO_ERROR, only success callbacks are invoked. Otherwise, only failure callbacks are invoked. * - * If error == CHIP_NO_ERROR, only success callbacks are invoked. - * Otherwise, only failure callbacks are invoked. + * The state offers additional context regarding the failure, indicating the specific state in which + * the error occurs. It is only relayed through failure callbacks when the error is not equal to CHIP_NO_ERROR. * * If releaseBehavior is Release, this uses mReleaseDelegate to release * ourselves (aka `this`). As a result any caller should return right away @@ -328,15 +365,22 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * * Setting releaseBehavior to DoNotRelease is meant for use from the destructor */ - void DequeueConnectionCallbacks(CHIP_ERROR error, ReleaseBehavior releaseBehavior = ReleaseBehavior::Release); + void DequeueConnectionCallbacks(CHIP_ERROR error, SessionEstablishmentStage stage, + ReleaseBehavior releaseBehavior = ReleaseBehavior::Release); + + void DequeueConnectionCallbacks(CHIP_ERROR error, ReleaseBehavior releaseBehavior = ReleaseBehavior::Release) + { + this->DequeueConnectionCallbacks(error, SessionEstablishmentStage::kNotInKeyExchange, releaseBehavior); + } /** * Helper for DequeueConnectionCallbacks that handles the actual callback * notifications. This happens after the object has been released, if it's * being released. */ - static void NotifyConnectionCallbacks(Callback::Cancelable & failureReady, Callback::Cancelable & successReady, - CHIP_ERROR error, const ScopedNodeId & peerId, bool performingAddressUpdate, + static void NotifyConnectionCallbacks(Callback::Cancelable & failureReady, Callback::Cancelable & setupFailureReady, + Callback::Cancelable & successReady, CHIP_ERROR error, SessionEstablishmentStage stage, + const ScopedNodeId & peerId, bool performingAddressUpdate, Messaging::ExchangeManager * exchangeMgr, const Optional & optionalSessionHandle); diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 6e51d000e3822c..296ba0848150c7 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -556,9 +556,11 @@ void CASESession::OnResponseTimeout(ExchangeContext * ec) void CASESession::AbortPendingEstablish(CHIP_ERROR err) { + // This needs to come before Clear() which will reset mState. + SessionEstablishmentStage state = MapCASEStateToSessionEstablishmentStage(mState); Clear(); // Do this last in case the delegate frees us. - NotifySessionEstablishmentError(err); + NotifySessionEstablishmentError(err, state); } CHIP_ERROR CASESession::DeriveSecureSession(CryptoContext & session) const @@ -2255,4 +2257,29 @@ bool CASESession::InvokeBackgroundWorkWatchdog() return watchdogFired; } +// Helper function to map CASESession::State to SessionEstablishmentStage +SessionEstablishmentStage CASESession::MapCASEStateToSessionEstablishmentStage(State caseState) +{ + switch (caseState) + { + case State::kInitialized: + return SessionEstablishmentStage::kNotInKeyExchange; + case State::kSentSigma1: + case State::kSentSigma1Resume: + return SessionEstablishmentStage::kSentSigma1; + case State::kSentSigma2: + case State::kSentSigma2Resume: + return SessionEstablishmentStage::kSentSigma2; + case State::kSendSigma3Pending: + return SessionEstablishmentStage::kReceivedSigma2; + case State::kSentSigma3: + return SessionEstablishmentStage::kSentSigma3; + case State::kHandleSigma3Pending: + return SessionEstablishmentStage::kReceivedSigma3; + // Add more mappings here for other states + default: + return SessionEstablishmentStage::kUnknown; // Default mapping + } +} + } // namespace chip diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 7453b6b5002dc4..6fc58dfb90dc83 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -320,6 +320,8 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, #if CONFIG_BUILD_FOR_HOST_UNIT_TEST Optional mStopHandshakeAtState = Optional::Missing(); #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST + + SessionEstablishmentStage MapCASEStateToSessionEstablishmentStage(State caseState); }; } // namespace chip diff --git a/src/protocols/secure_channel/PairingSession.cpp b/src/protocols/secure_channel/PairingSession.cpp index 63a1701e66541f..23daea30f2800c 100644 --- a/src/protocols/secure_channel/PairingSession.cpp +++ b/src/protocols/secure_channel/PairingSession.cpp @@ -255,7 +255,7 @@ void PairingSession::Clear() mSessionManager = nullptr; } -void PairingSession::NotifySessionEstablishmentError(CHIP_ERROR error) +void PairingSession::NotifySessionEstablishmentError(CHIP_ERROR error, SessionEstablishmentStage stage) { if (mDelegate == nullptr) { @@ -265,7 +265,7 @@ void PairingSession::NotifySessionEstablishmentError(CHIP_ERROR error) auto * delegate = mDelegate; mDelegate = nullptr; - delegate->OnSessionEstablishmentError(error); + delegate->OnSessionEstablishmentError(error, stage); } void PairingSession::OnSessionReleased() diff --git a/src/protocols/secure_channel/PairingSession.h b/src/protocols/secure_channel/PairingSession.h index 844fa33a41ae68..ffbb6d9966485d 100644 --- a/src/protocols/secure_channel/PairingSession.h +++ b/src/protocols/secure_channel/PairingSession.h @@ -218,10 +218,14 @@ class DLL_EXPORT PairingSession : public SessionDelegate void Clear(); /** - * Notify our delegate about a session establishment error, if we have not - * notified it of an error or success before. + * Notify our delegate about a session establishment error and the stage when the error occurs + * if we have not already notified it of an error or success before. + * + * @param error The error code to report. + * @param stage The stage of the session when the error occurs, defaults to kNotInKeyExchange. */ - void NotifySessionEstablishmentError(CHIP_ERROR error); + void NotifySessionEstablishmentError(CHIP_ERROR error, + SessionEstablishmentStage stage = SessionEstablishmentStage::kNotInKeyExchange); protected: CryptoContext::SessionRole mRole; diff --git a/src/protocols/secure_channel/SessionEstablishmentDelegate.h b/src/protocols/secure_channel/SessionEstablishmentDelegate.h index dc73a0ffe6997d..640dfca745921f 100644 --- a/src/protocols/secure_channel/SessionEstablishmentDelegate.h +++ b/src/protocols/secure_channel/SessionEstablishmentDelegate.h @@ -32,6 +32,18 @@ namespace chip { +enum class SessionEstablishmentStage : uint8_t +{ + kUnknown = 0, + kNotInKeyExchange = 1, + kSentSigma1 = 2, + kReceivedSigma1 = 3, + kSentSigma2 = 4, + kReceivedSigma2 = 5, + kSentSigma3 = 6, + kReceivedSigma3 = 7, +}; + class DLL_EXPORT SessionEstablishmentDelegate { public: @@ -39,9 +51,22 @@ class DLL_EXPORT SessionEstablishmentDelegate * Called when session establishment fails with an error. This will be * called at most once per session establishment and will not be called if * OnSessionEstablished is called. + * + * This overload of OnSessionEstablishmentError is not called directly. It's only called from the default + *. implemetation of the two-argument overload. */ virtual void OnSessionEstablishmentError(CHIP_ERROR error) {} + /** + * Called when session establishment fails with an error and state at the + * failure. This will be called at most once per session establishment and + * will not be called if OnSessionEstablished is called. + */ + virtual void OnSessionEstablishmentError(CHIP_ERROR error, SessionEstablishmentStage stage) + { + OnSessionEstablishmentError(error); + } + /** * Called on start of session establishment process */