From 7e883fafa6720d0baa4118a90dbc461effa1338b Mon Sep 17 00:00:00 2001 From: Zang MingJie Date: Wed, 6 Apr 2022 11:54:50 +0800 Subject: [PATCH] Implement CASE session resumption (#16741) * Implement CASE session resumption * Fix span usage * Introduce SimpleSessionResumptionStorage as an example implementation. * Allow inject session resumption storage for server --- examples/pump-app/cc13x2x7_26x2x7/args.gni | 3 + examples/tv-casting-app/linux/main.cpp | 9 +- src/app/BUILD.gn | 6 +- src/app/CASEClient.cpp | 3 +- src/app/CASEClient.h | 9 +- src/app/OperationalDeviceProxy.cpp | 6 +- src/app/OperationalDeviceProxy.h | 12 +- src/app/server/Server.cpp | 6 +- src/app/server/Server.h | 21 ++ src/app/tests/TestOperationalDeviceProxy.cpp | 11 +- .../CHIPDeviceControllerFactory.cpp | 29 +- .../CHIPDeviceControllerSystemState.h | 1 + src/credentials/FabricTable.h | 3 + src/crypto/CHIPCryptoPAL.h | 12 + src/lib/core/CHIPConfig.h | 2 +- src/lib/support/DefaultStorageKeyAllocator.h | 8 + src/lib/support/Span.h | 19 +- src/protocols/secure_channel/BUILD.gn | 6 +- src/protocols/secure_channel/CASEServer.cpp | 15 +- src/protocols/secure_channel/CASEServer.h | 4 +- src/protocols/secure_channel/CASESession.cpp | 214 +++++++------- src/protocols/secure_channel/CASESession.h | 52 +--- .../secure_channel/CASESessionCache.cpp | 105 ------- .../secure_channel/CASESessionCache.h | 45 --- .../SessionResumptionStorage.cpp | 151 ++++++++++ .../secure_channel/SessionResumptionStorage.h | 82 ++++++ .../SimpleSessionResumptionStorage.cpp | 273 ++++++++++++++++++ .../SimpleSessionResumptionStorage.h | 88 ++++++ src/protocols/secure_channel/tests/BUILD.gn | 2 +- .../secure_channel/tests/TestCASESession.cpp | 36 +-- .../tests/TestCASESessionCache.cpp | 233 --------------- .../secure_channel/tests/TestPASESession.cpp | 3 - .../TestSimpleSessionResumptionStorage.cpp | 145 ++++++++++ src/transport/PairingSession.h | 4 +- 34 files changed, 1014 insertions(+), 604 deletions(-) delete mode 100644 src/protocols/secure_channel/CASESessionCache.cpp delete mode 100644 src/protocols/secure_channel/CASESessionCache.h create mode 100644 src/protocols/secure_channel/SessionResumptionStorage.cpp create mode 100644 src/protocols/secure_channel/SessionResumptionStorage.h create mode 100644 src/protocols/secure_channel/SimpleSessionResumptionStorage.cpp create mode 100644 src/protocols/secure_channel/SimpleSessionResumptionStorage.h delete mode 100644 src/protocols/secure_channel/tests/TestCASESessionCache.cpp create mode 100644 src/protocols/secure_channel/tests/TestSimpleSessionResumptionStorage.cpp diff --git a/examples/pump-app/cc13x2x7_26x2x7/args.gni b/examples/pump-app/cc13x2x7_26x2x7/args.gni index 2b1db21d690d61..7c22fa4d2f50b7 100644 --- a/examples/pump-app/cc13x2x7_26x2x7/args.gni +++ b/examples/pump-app/cc13x2x7_26x2x7/args.gni @@ -35,6 +35,9 @@ chip_enable_ota_requestor = true # BLE options chip_config_network_layer_ble = true +# Disable session resumption due to lack of code space +chip_enable_session_resumption = false + # Disable lock tracking, since our FreeRTOS configuration does not set # INCLUDE_xSemaphoreGetMutexHolder chip_stack_lock_tracking = "none" diff --git a/examples/tv-casting-app/linux/main.cpp b/examples/tv-casting-app/linux/main.cpp index 3050f8ef40f5e5..d6491ade593911 100644 --- a/examples/tv-casting-app/linux/main.cpp +++ b/examples/tv-casting-app/linux/main.cpp @@ -322,10 +322,11 @@ class TargetVideoPlayerInfo } chip::DeviceProxyInitParams initParams = { - .sessionManager = &(server->GetSecureSessionManager()), - .exchangeMgr = &(server->GetExchangeManager()), - .fabricTable = &(server->GetFabricTable()), - .clientPool = &gCASEClientPool, + .sessionManager = &(server->GetSecureSessionManager()), + .sessionResumptionStorage = server->GetSessionResumptionStorage(), + .exchangeMgr = &(server->GetExchangeManager()), + .fabricTable = &(server->GetFabricTable()), + .clientPool = &gCASEClientPool, }; PeerId peerID = fabric->GetPeerIdForNode(nodeId); diff --git a/src/app/BUILD.gn b/src/app/BUILD.gn index 8f7f06011d427a..a7023daade0609 100644 --- a/src/app/BUILD.gn +++ b/src/app/BUILD.gn @@ -22,13 +22,17 @@ declare_args() { # Enable strict schema checks. chip_enable_schema_check = is_debug && (current_os == "linux" || current_os == "mac") + chip_enable_session_resumption = true } buildconfig_header("app_buildconfig") { header = "AppBuildConfig.h" header_dir = "app" - defines = [ "CHIP_CONFIG_IM_ENABLE_SCHEMA_CHECK=${chip_enable_schema_check}" ] + defines = [ + "CHIP_CONFIG_IM_ENABLE_SCHEMA_CHECK=${chip_enable_schema_check}", + "CHIP_CONFIG_ENABLE_SESSION_RESUMPTION=${chip_enable_session_resumption}", + ] } static_library("app") { diff --git a/src/app/CASEClient.cpp b/src/app/CASEClient.cpp index 94666eb4a8d503..2ec8ce9780646e 100644 --- a/src/app/CASEClient.cpp +++ b/src/app/CASEClient.cpp @@ -46,7 +46,8 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres mCASESession.SetGroupDataProvider(mInitParams.groupDataProvider); ReturnErrorOnFailure(mCASESession.EstablishSession(*mInitParams.sessionManager, peerAddress, mInitParams.fabricInfo, - peer.GetNodeId(), exchange, this, mInitParams.mrpLocalConfig)); + peer.GetNodeId(), exchange, mInitParams.sessionResumptionStorage, this, + mInitParams.mrpLocalConfig)); mConnectionSuccessCallback = onConnection; mConnectionFailureCallback = onFailure; mConectionContext = context; diff --git a/src/app/CASEClient.h b/src/app/CASEClient.h index fba48029a0ecca..220e920904252f 100644 --- a/src/app/CASEClient.h +++ b/src/app/CASEClient.h @@ -31,10 +31,11 @@ typedef void (*OnCASEConnectionFailure)(void * context, CASEClient * client, CHI struct CASEClientInitParams { - SessionManager * sessionManager = nullptr; - Messaging::ExchangeManager * exchangeMgr = nullptr; - FabricInfo * fabricInfo = nullptr; - Credentials::GroupDataProvider * groupDataProvider = nullptr; + SessionManager * sessionManager = nullptr; + SessionResumptionStorage * sessionResumptionStorage = nullptr; + Messaging::ExchangeManager * exchangeMgr = nullptr; + FabricInfo * fabricInfo = nullptr; + Credentials::GroupDataProvider * groupDataProvider = nullptr; Optional mrpLocalConfig = Optional::Missing(); }; diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index 0ce22499730342..ef7ea1ce0738b6 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -167,9 +167,9 @@ bool OperationalDeviceProxy::GetAddress(Inet::IPAddress & addr, uint16_t & port) CHIP_ERROR OperationalDeviceProxy::EstablishConnection() { - mCASEClient = - mInitParams.clientPool->Allocate(CASEClientInitParams{ mInitParams.sessionManager, mInitParams.exchangeMgr, mFabricInfo, - mInitParams.groupDataProvider, mInitParams.mrpLocalConfig }); + mCASEClient = mInitParams.clientPool->Allocate( + CASEClientInitParams{ mInitParams.sessionManager, mInitParams.sessionResumptionStorage, mInitParams.exchangeMgr, + mFabricInfo, mInitParams.groupDataProvider, mInitParams.mrpLocalConfig }); ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY); CHIP_ERROR err = mCASEClient->EstablishSession(mPeerId, mDeviceAddress, mMRPConfig, HandleCASEConnected, HandleCASEConnectionFailure, this); diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index 16e4f6e63b6594..7f1b9d715b23f6 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -48,17 +48,19 @@ namespace chip { struct DeviceProxyInitParams { - SessionManager * sessionManager = nullptr; - Messaging::ExchangeManager * exchangeMgr = nullptr; - FabricTable * fabricTable = nullptr; - CASEClientPoolDelegate * clientPool = nullptr; - Credentials::GroupDataProvider * groupDataProvider = nullptr; + SessionManager * sessionManager = nullptr; + SessionResumptionStorage * sessionResumptionStorage = nullptr; + Messaging::ExchangeManager * exchangeMgr = nullptr; + FabricTable * fabricTable = nullptr; + CASEClientPoolDelegate * clientPool = nullptr; + Credentials::GroupDataProvider * groupDataProvider = nullptr; Optional mrpLocalConfig = Optional::Missing(); CHIP_ERROR Validate() const { ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE); + // sessionResumptionStorage can be nullptr when resumption is disabled ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE); diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index 3ce73a3970b30c..d51cbc3a9ee400 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -115,7 +115,8 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) mCommissioningWindowManager.SetAppDelegate(initParams.appDelegate); // Initialize PersistentStorageDelegate-based storage - mDeviceStorage = initParams.persistentStorageDelegate; + mDeviceStorage = initParams.persistentStorageDelegate; + mSessionResumptionStorage = initParams.sessionResumptionStorage; // Set up attribute persistence before we try to bring up the data model // handler. @@ -237,6 +238,7 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) caseSessionManagerConfig = { .sessionInitParams = { .sessionManager = &mSessions, + .sessionResumptionStorage = mSessionResumptionStorage, .exchangeMgr = &mExchangeMgr, .fabricTable = &mFabrics, .clientPool = &mCASEClientPool, @@ -255,7 +257,7 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) #if CONFIG_NETWORK_LAYER_BLE chip::DeviceLayer::ConnectivityMgr().GetBleLayer(), #endif - &mSessions, &mFabrics, mGroupsProvider); + &mSessions, &mFabrics, mSessionResumptionStorage, mGroupsProvider); SuccessOrExit(err); // This code is necessary to restart listening to existing groups after a reboot diff --git a/src/app/server/Server.h b/src/app/server/Server.h index b97c4cce95e472..406f0e27ba95e3 100644 --- a/src/app/server/Server.h +++ b/src/app/server/Server.h @@ -17,6 +17,8 @@ #pragma once +#include + #include #include #include @@ -38,6 +40,9 @@ #include #include #include +#if CHIP_CONFIG_ENABLE_SESSION_RESUMPTION +#include +#endif #include #include #include @@ -83,6 +88,9 @@ struct ServerInitParams // Persistent storage delegate: MUST be injected. Used to maintain storage by much common code. // Must be initialized before being provided. PersistentStorageDelegate * persistentStorageDelegate = nullptr; + // Session resumption storage: Optional. Support session resumption when provided. + // Must be initialized before being provided. + SessionResumptionStorage * sessionResumptionStorage = nullptr; // Group data provider: MUST be injected. Used to maintain critical keys such as the Identity // Protection Key (IPK) for CASE. Must be initialized before being provided. Credentials::GroupDataProvider * groupDataProvider = nullptr; @@ -139,6 +147,9 @@ struct CommonCaseDeviceServerInitParams : public ServerInitParams { static chip::KvsPersistentStorageDelegate sKvsPersistenStorageDelegate; static chip::Credentials::GroupDataProviderImpl sGroupDataProvider; +#if CHIP_CONFIG_ENABLE_SESSION_RESUMPTION + static chip::SimpleSessionResumptionStorage sSessionResumptionStorage; +#endif // KVS-based persistent storage delegate injection chip::DeviceLayer::PersistedStorage::KeyValueStoreManager & kvsManager = DeviceLayer::PersistedStorage::KeyValueStoreMgr(); @@ -150,6 +161,13 @@ struct CommonCaseDeviceServerInitParams : public ServerInitParams ReturnErrorOnFailure(sGroupDataProvider.Init()); this->groupDataProvider = &sGroupDataProvider; +#if CHIP_CONFIG_ENABLE_SESSION_RESUMPTION + ReturnErrorOnFailure(sSessionResumptionStorage.Init(&sKvsPersistenStorageDelegate)); + this->sessionResumptionStorage = &sSessionResumptionStorage; +#else + this->sessionResumptionStorage = nullptr; +#endif + // Inject access control delegate this->accessDelegate = Access::Examples::GetAccessControlDelegate(&sKvsPersistenStorageDelegate); @@ -198,6 +216,8 @@ class Server SessionManager & GetSecureSessionManager() { return mSessions; } + SessionResumptionStorage * GetSessionResumptionStorage() { return mSessionResumptionStorage; } + TransportMgrBase & GetTransportManager() { return mTransports; } Credentials::GroupDataProvider * GetGroupDataProvider() { return mGroupsProvider; } @@ -313,6 +333,7 @@ class Server CommissioningWindowManager mCommissioningWindowManager; PersistentStorageDelegate * mDeviceStorage; + SessionResumptionStorage * mSessionResumptionStorage; Credentials::GroupDataProvider * mGroupsProvider; app::DefaultAttributePersistenceProvider mAttributePersister; GroupDataProviderListener mListener; diff --git a/src/app/tests/TestOperationalDeviceProxy.cpp b/src/app/tests/TestOperationalDeviceProxy.cpp index 0a32373f57c443..2aa1699fef54e5 100644 --- a/src/app/tests/TestOperationalDeviceProxy.cpp +++ b/src/app/tests/TestOperationalDeviceProxy.cpp @@ -45,6 +45,7 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite, Platform::MemoryInit(); TestTransportMgr transportMgr; SessionManager sessionManager; + SessionResumptionStorage sessionResumptionStorage; ExchangeManager exchangeMgr; Inet::UDPEndPointManagerImpl udpEndPointManager; System::LayerImpl systemLayer; @@ -61,6 +62,7 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite, udpEndPointManager.Init(systemLayer); transportMgr.Init(UdpListenParameters(udpEndPointManager).SetAddressType(Inet::IPAddressType::kIPv4).SetListenPort(CHIP_PORT)); sessionManager.Init(&systemLayer, &transportMgr, &messageCounterManager, &deviceStorage); + sessionResumptionStorage.Init(&deviceStorage); exchangeMgr.Init(&sessionManager); messageCounterManager.Init(&exchangeMgr); groupDataProvider.SetPersistentStorage(&deviceStorage); @@ -68,10 +70,11 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite, // TODO: Set IPK in groupDataProvider DeviceProxyInitParams params = { - .sessionManager = &sessionManager, - .exchangeMgr = &exchangeMgr, - .fabricInfo = fabric, - .groupDataProvider = &groupDataProvider, + .sessionManager = &sessionManager, + .sessionResumptionStorage = &sessionResumptionStorage, + .exchangeMgr = &exchangeMgr, + .fabricInfo = fabric, + .groupDataProvider = &groupDataProvider, }; NodeId mockNodeId = 1; OperationalDeviceProxy device(params, PeerId().SetNodeId(mockNodeId)); diff --git a/src/controller/CHIPDeviceControllerFactory.cpp b/src/controller/CHIPDeviceControllerFactory.cpp index f5e5aab66bb278..3d14de5de624e2 100644 --- a/src/controller/CHIPDeviceControllerFactory.cpp +++ b/src/controller/CHIPDeviceControllerFactory.cpp @@ -36,6 +36,7 @@ #include #include +#include using namespace chip::Inet; using namespace chip::System; @@ -146,13 +147,16 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params) #endif )); - stateParams.fabricTable = chip::Platform::New(); - stateParams.sessionMgr = chip::Platform::New(); - stateParams.exchangeMgr = chip::Platform::New(); - stateParams.messageCounterManager = chip::Platform::New(); - stateParams.groupDataProvider = params.groupDataProvider; + stateParams.fabricTable = chip::Platform::New(); + stateParams.sessionMgr = chip::Platform::New(); + SimpleSessionResumptionStorage * sessionResumptionStorage = chip::Platform::New(); + stateParams.sessionResumptionStorage = sessionResumptionStorage; + stateParams.exchangeMgr = chip::Platform::New(); + stateParams.messageCounterManager = chip::Platform::New(); + stateParams.groupDataProvider = params.groupDataProvider; ReturnErrorOnFailure(stateParams.fabricTable->Init(params.fabricIndependentStorage)); + ReturnErrorOnFailure(sessionResumptionStorage->Init(params.fabricIndependentStorage)); auto delegate = chip::Platform::MakeUnique(); ReturnErrorOnFailure(delegate->Init(stateParams.sessionMgr, stateParams.groupDataProvider)); @@ -186,7 +190,7 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params) #if CONFIG_NETWORK_LAYER_BLE nullptr, #endif - stateParams.sessionMgr, stateParams.fabricTable, stateParams.groupDataProvider)); + stateParams.sessionMgr, stateParams.fabricTable, stateParams.sessionResumptionStorage, stateParams.groupDataProvider)); // // We need to advertise the port that we're listening to for unsolicited messages over UDP. However, we have both a IPv4 @@ -218,12 +222,13 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params) stateParams.caseClientPool = Platform::New(); DeviceProxyInitParams deviceInitParams = { - .sessionManager = stateParams.sessionMgr, - .exchangeMgr = stateParams.exchangeMgr, - .fabricTable = stateParams.fabricTable, - .clientPool = stateParams.caseClientPool, - .groupDataProvider = stateParams.groupDataProvider, - .mrpLocalConfig = Optional::Value(GetLocalMRPConfig()), + .sessionManager = stateParams.sessionMgr, + .sessionResumptionStorage = stateParams.sessionResumptionStorage, + .exchangeMgr = stateParams.exchangeMgr, + .fabricTable = stateParams.fabricTable, + .clientPool = stateParams.caseClientPool, + .groupDataProvider = stateParams.groupDataProvider, + .mrpLocalConfig = Optional::Value(GetLocalMRPConfig()), }; CASESessionManagerConfig sessionManagerConfig = { diff --git a/src/controller/CHIPDeviceControllerSystemState.h b/src/controller/CHIPDeviceControllerSystemState.h index acd00705bed6da..faf30f0f105c0e 100644 --- a/src/controller/CHIPDeviceControllerSystemState.h +++ b/src/controller/CHIPDeviceControllerSystemState.h @@ -81,6 +81,7 @@ struct DeviceControllerSystemStateParams // Params that will be deallocated via Platform::Delete in // DeviceControllerSystemState::Shutdown. DeviceTransportMgr * transportMgr = nullptr; + SessionResumptionStorage * sessionResumptionStorage = nullptr; SessionManager * sessionMgr = nullptr; Messaging::ExchangeManager * exchangeMgr = nullptr; secure_channel::MessageCounterManager * messageCounterManager = nullptr; diff --git a/src/credentials/FabricTable.h b/src/credentials/FabricTable.h index 0c375311f28f53..6960c242bd3284 100644 --- a/src/credentials/FabricTable.h +++ b/src/credentials/FabricTable.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -85,6 +86,8 @@ class DLL_EXPORT FabricInfo } NodeId GetNodeId() const { return mOperationalId.GetNodeId(); } + ScopedNodeId GetScopedNodeId() const { return ScopedNodeId(mOperationalId.GetNodeId(), mFabricIndex); } + ScopedNodeId GetScopedNodeIdForNode(const NodeId node) const { return ScopedNodeId(node, mFabricIndex); } // TODO(#15049): Refactor/rename PeerId to OperationalId or OpId throughout source PeerId GetPeerId() const { return mOperationalId; } PeerId GetPeerIdForNode(const NodeId node) const diff --git a/src/crypto/CHIPCryptoPAL.h b/src/crypto/CHIPCryptoPAL.h index 1dce7388056155..6ff6e6b5606cfe 100644 --- a/src/crypto/CHIPCryptoPAL.h +++ b/src/crypto/CHIPCryptoPAL.h @@ -206,6 +206,18 @@ class CapacityBoundBuffer ClearSecretData(&bytes[0], Cap); } + CapacityBoundBuffer & operator=(const CapacityBoundBuffer & other) + { + // Guard self assignment + if (this == &other) + return *this; + + ClearSecretData(&bytes[0], Cap); + SetLength(other.Length()); + ::memcpy(Bytes(), other.Bytes(), other.Length()); + return *this; + } + /** @brief Set current length of the buffer that's being used * @return Returns error if new length is > capacity **/ diff --git a/src/lib/core/CHIPConfig.h b/src/lib/core/CHIPConfig.h index 50c4f975cbf238..b5a7d0ca1b5377 100644 --- a/src/lib/core/CHIPConfig.h +++ b/src/lib/core/CHIPConfig.h @@ -1669,7 +1669,7 @@ extern const char CHIP_NON_PRODUCTION_MARKER[]; * Maximum number of CASE sessions that a device caches, that can be resumed */ #ifndef CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE -#define CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE 4 +#define CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE 64 #endif /** diff --git a/src/lib/support/DefaultStorageKeyAllocator.h b/src/lib/support/DefaultStorageKeyAllocator.h index 48b97e766b496e..160b79a560971f 100644 --- a/src/lib/support/DefaultStorageKeyAllocator.h +++ b/src/lib/support/DefaultStorageKeyAllocator.h @@ -52,6 +52,14 @@ class DefaultStorageKeyAllocator // FailSafeContext const char * FailSafeContextKey() { return Format("g/fsc"); } + // Session resumption + const char * FabricSession(FabricIndex fabric, NodeId nodeId) + { + return Format("f/%x/s/%08" PRIX32 "%08" PRIX32, fabric, static_cast(nodeId >> 32), static_cast(nodeId)); + } + const char * SessionResumptionIndex() { return Format("f/sri"); } + const char * SessionResumption(const char * resumptionIdBase64) { return Format("s/%s", resumptionIdBase64); } + // Access Control const char * AccessControlExtensionEntry(FabricIndex fabric) { return Format("f/%x/ac/1", fabric); } diff --git a/src/lib/support/Span.h b/src/lib/support/Span.h index a5dabe9084c74d..4d8f521eb35c0c 100644 --- a/src/lib/support/Span.h +++ b/src/lib/support/Span.h @@ -38,8 +38,7 @@ template class Span { public: - using pointer = T *; - using const_pointer = const T *; + using pointer = T *; constexpr Span() : mDataBuf(nullptr), mDataLen(0) {} constexpr Span(pointer databuf, size_t datalen) : mDataBuf(databuf), mDataLen(datalen) {} @@ -73,10 +72,8 @@ class Span constexpr pointer data() const { return mDataBuf; } constexpr size_t size() const { return mDataLen; } constexpr bool empty() const { return size() == 0; } - constexpr const_pointer begin() const { return data(); } - constexpr const_pointer end() const { return data() + size(); } - constexpr pointer begin() { return data(); } - constexpr pointer end() { return data() + size(); } + constexpr pointer begin() const { return data(); } + constexpr pointer end() const { return data() + size(); } template , std::remove_const_t>::value>> bool data_equal(const Span & other) const @@ -145,8 +142,7 @@ template class FixedSpan { public: - using pointer = T *; - using const_pointer = const T *; + using pointer = T *; constexpr FixedSpan() : mDataBuf(nullptr) {} @@ -189,11 +185,8 @@ class FixedSpan constexpr pointer data() const { return mDataBuf; } constexpr size_t size() const { return N; } constexpr bool empty() const { return data() == nullptr; } - - constexpr pointer begin() { return mDataBuf; } - constexpr pointer end() { return mDataBuf + N; } - constexpr const_pointer begin() const { return mDataBuf; } - constexpr const_pointer end() const { return mDataBuf + N; } + constexpr pointer begin() const { return mDataBuf; } + constexpr pointer end() const { return mDataBuf + N; } // Allow data_equal for spans that are over the same type up to const-ness. template , std::remove_const_t>::value>> diff --git a/src/protocols/secure_channel/BUILD.gn b/src/protocols/secure_channel/BUILD.gn index 4cc84703343fcd..e088b6e5a23c1c 100644 --- a/src/protocols/secure_channel/BUILD.gn +++ b/src/protocols/secure_channel/BUILD.gn @@ -10,14 +10,16 @@ static_library("secure_channel") { "CASEServer.h", "CASESession.cpp", "CASESession.h", - "CASESessionCache.cpp", - "CASESessionCache.h", "PASESession.cpp", "PASESession.h", "RendezvousParameters.h", "SessionEstablishmentDelegate.h", "SessionEstablishmentExchangeDispatch.cpp", "SessionEstablishmentExchangeDispatch.h", + "SessionResumptionStorage.cpp", + "SessionResumptionStorage.h", + "SimpleSessionResumptionStorage.cpp", + "SimpleSessionResumptionStorage.h", "StatusReport.cpp", "StatusReport.h", ] diff --git a/src/protocols/secure_channel/CASEServer.cpp b/src/protocols/secure_channel/CASEServer.cpp index 470f6cae3aaf45..19ec85d724ca0e 100644 --- a/src/protocols/secure_channel/CASEServer.cpp +++ b/src/protocols/secure_channel/CASEServer.cpp @@ -34,6 +34,7 @@ CHIP_ERROR CASEServer::ListenForSessionEstablishment(Messaging::ExchangeManager Ble::BleLayer * bleLayer, #endif SessionManager * sessionManager, FabricTable * fabrics, + SessionResumptionStorage * sessionResumptionStorage, Credentials::GroupDataProvider * responderGroupDataProvider) { VerifyOrReturnError(transportMgr != nullptr, CHIP_ERROR_INVALID_ARGUMENT); @@ -45,10 +46,11 @@ CHIP_ERROR CASEServer::ListenForSessionEstablishment(Messaging::ExchangeManager #if CONFIG_NETWORK_LAYER_BLE mBleLayer = bleLayer; #endif - mSessionManager = sessionManager; - mFabrics = fabrics; - mExchangeManager = exchangeManager; - mGroupDataProvider = responderGroupDataProvider; + mSessionManager = sessionManager; + mSessionResumptionStorage = sessionResumptionStorage; + mFabrics = fabrics; + mExchangeManager = exchangeManager; + mGroupDataProvider = responderGroupDataProvider; Cleanup(); return CHIP_NO_ERROR; @@ -76,8 +78,9 @@ CHIP_ERROR CASEServer::InitCASEHandshake(Messaging::ExchangeContext * ec) // Setup CASE state machine using the credentials for the current fabric. GetSession().SetGroupDataProvider(mGroupDataProvider); - ReturnErrorOnFailure(GetSession().ListenForSessionEstablishment( - *mSessionManager, mFabrics, this, Optional::Value(GetLocalMRPConfig()))); + ReturnErrorOnFailure( + GetSession().ListenForSessionEstablishment(*mSessionManager, mFabrics, mSessionResumptionStorage, this, + Optional::Value(GetLocalMRPConfig()))); // Hand over the exchange context to the CASE session. ec->SetDelegate(&GetSession()); diff --git a/src/protocols/secure_channel/CASEServer.h b/src/protocols/secure_channel/CASEServer.h index 6e93f558a88ee3..231cc0cc85ae7f 100644 --- a/src/protocols/secure_channel/CASEServer.h +++ b/src/protocols/secure_channel/CASEServer.h @@ -44,6 +44,7 @@ class CASEServer : public SessionEstablishmentDelegate, public Messaging::Exchan Ble::BleLayer * bleLayer, #endif SessionManager * sessionManager, FabricTable * fabrics, + SessionResumptionStorage * sessionResumptionStorage, Credentials::GroupDataProvider * responderGroupDataProvider); //////////// SessionEstablishmentDelegate Implementation /////////////// @@ -59,7 +60,8 @@ class CASEServer : public SessionEstablishmentDelegate, public Messaging::Exchan virtual CASESession & GetSession() { return mPairingSession; } private: - Messaging::ExchangeManager * mExchangeManager = nullptr; + Messaging::ExchangeManager * mExchangeManager = nullptr; + SessionResumptionStorage * mSessionResumptionStorage = nullptr; CASESession mPairingSession; SessionManager * mSessionManager = nullptr; diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index a3783a977be8b9..72dd67f8e46483 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -138,55 +139,6 @@ void CASESession::DiscardExchange() } } -CHIP_ERROR CASESession::ToCachable(CASESessionCachable & cachableSession) -{ - const NodeId peerNodeId = GetPeerNodeId(); - VerifyOrReturnError(CanCastTo(mSharedSecret.Length()), CHIP_ERROR_INTERNAL); - VerifyOrReturnError(CanCastTo(peerNodeId), CHIP_ERROR_INTERNAL); - - cachableSession.mSharedSecretLen = LittleEndian::HostSwap16(static_cast(mSharedSecret.Length())); - cachableSession.mPeerNodeId = LittleEndian::HostSwap64(peerNodeId); - for (size_t i = 0; i < cachableSession.mPeerCATs.size(); i++) - { - cachableSession.mPeerCATs.values[i] = LittleEndian::HostSwap32(GetPeerCATs().values[i]); - } - cachableSession.mLocalFabricIndex = (mFabricInfo != nullptr) ? mFabricInfo->GetFabricIndex() : kUndefinedFabricIndex; - cachableSession.mSessionSetupTimeStamp = LittleEndian::HostSwap64(mSessionSetupTimeStamp); - - memcpy(cachableSession.mResumptionId, mResumptionId, sizeof(mResumptionId)); - memcpy(cachableSession.mSharedSecret, mSharedSecret, mSharedSecret.Length()); - memcpy(cachableSession.mIPK, mIPK, sizeof(mIPK)); - - return CHIP_NO_ERROR; -} - -CHIP_ERROR CASESession::FromCachable(const CASESessionCachable & cachableSession) -{ - uint16_t length = LittleEndian::HostSwap16(cachableSession.mSharedSecretLen); - ReturnErrorOnFailure(mSharedSecret.SetLength(static_cast(length))); - memset(mSharedSecret, 0, sizeof(mSharedSecret.Capacity())); - memcpy(mSharedSecret, cachableSession.mSharedSecret, length); - - SetPeerNodeId(LittleEndian::HostSwap64(cachableSession.mPeerNodeId)); - CATValues peerCATs; - for (size_t i = 0; i < cachableSession.mPeerCATs.size(); i++) - { - peerCATs.values[i] = LittleEndian::HostSwap32(cachableSession.mPeerCATs.values[i]); - } - SetPeerCATs(peerCATs); - SetSessionTimeStamp(LittleEndian::HostSwap64(cachableSession.mSessionSetupTimeStamp)); - mLocalFabricIndex = cachableSession.mLocalFabricIndex; - - memcpy(mResumptionId, cachableSession.mResumptionId, sizeof(mResumptionId)); - - // TODO: Handle data dependency between IPK caching and the possible underlying changes of that IPK - memcpy(mIPK, cachableSession.mIPK, sizeof(mIPK)); - - mCASESessionEstablished = true; - - return CHIP_NO_ERROR; -} - CHIP_ERROR CASESession::Init(SessionManager & sessionManager, SessionEstablishmentDelegate * delegate) { VerifyOrReturnError(delegate != nullptr, CHIP_ERROR_INVALID_ARGUMENT); @@ -208,14 +160,16 @@ CHIP_ERROR CASESession::Init(SessionManager & sessionManager, SessionEstablishme CHIP_ERROR CASESession::ListenForSessionEstablishment(SessionManager & sessionManager, FabricTable * fabrics, + SessionResumptionStorage * sessionResumptionStorage, SessionEstablishmentDelegate * delegate, Optional mrpConfig) { VerifyOrReturnError(fabrics != nullptr, CHIP_ERROR_INVALID_ARGUMENT); ReturnErrorOnFailure(Init(sessionManager, delegate)); - mFabricsTable = fabrics; - mLocalMRPConfig = mrpConfig; + mFabricsTable = fabrics; + mSessionResumptionStorage = sessionResumptionStorage; + mLocalMRPConfig = mrpConfig; mCASESessionEstablished = false; @@ -226,6 +180,7 @@ CASESession::ListenForSessionEstablishment(SessionManager & sessionManager, Fabr CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, const Transport::PeerAddress peerAddress, FabricInfo * fabric, NodeId peerNodeId, ExchangeContext * exchangeCtxt, + SessionResumptionStorage * sessionResumptionStorage, SessionEstablishmentDelegate * delegate, Optional mrpConfig) { MATTER_TRACE_EVENT_SCOPE("EstablishSession", "CASESession"); @@ -351,7 +306,7 @@ CHIP_ERROR CASESession::SendSigma1() kSHA256_Hash_Length, // destinationId kP256_PublicKey_Length, // InitiatorEphPubKey, mrpParamsSize, // initiatorMRPParams - kCASEResumptionIDSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES); + SessionResumptionStorage::kResumptionIdSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES); System::PacketBufferTLVWriter tlvWriter; System::PacketBufferHandle msg_R1; @@ -404,19 +359,26 @@ CHIP_ERROR CASESession::SendSigma1() ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter)); } - // If CASE session was previously established using the current state information, let's fill in the session resumption - // information in the the Sigma1 request. It'll speed up the session establishment process if the peer can resume the old - // session, since no certificate chains will have to be verified. - if (mCASESessionEstablished) + // Try to find persistent session, and resume it. + bool resuming = false; + if (mSessionResumptionStorage != nullptr) { - ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(6), mResumptionId, kCASEResumptionIDSize)); + SessionResumptionStorage::ResumptionIdStorage resumptionId; + CHIP_ERROR err = mSessionResumptionStorage->FindByScopedNodeId(mFabricInfo->GetScopedNodeIdForNode(GetPeerNodeId()), + resumptionId, mSharedSecret, mPeerCATs); + if (err == CHIP_NO_ERROR) + { + // Found valid resumption state, try to resume the session. + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(6), resumptionId)); - uint8_t initiatorResume1MIC[CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES]; - MutableByteSpan resumeMICSpan(initiatorResume1MIC); - ReturnErrorOnFailure(GenerateSigmaResumeMIC(ByteSpan(mInitiatorRandom), ByteSpan(mResumptionId), ByteSpan(kKDFS1RKeyInfo), - ByteSpan(kResume1MIC_Nonce), resumeMICSpan)); + uint8_t initiatorResume1MIC[CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES]; + MutableByteSpan resumeMICSpan(initiatorResume1MIC); + ReturnErrorOnFailure(GenerateSigmaResumeMIC(ByteSpan(mInitiatorRandom), ByteSpan(resumptionId), + ByteSpan(kKDFS1RKeyInfo), ByteSpan(kResume1MIC_Nonce), resumeMICSpan)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(7), resumeMICSpan)); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(7), resumeMICSpan)); + resuming = true; + } } ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); @@ -428,7 +390,7 @@ CHIP_ERROR CASESession::SendSigma1() ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma1, std::move(msg_R1), SendFlags(SendMessageFlags::kExpectResponse))); - mState = kSentSigma1; + mState = resuming ? kSentSigma1Resume : kSentSigma1; ChipLogProgress(SecureChannel, "Sent Sigma1 msg"); @@ -497,6 +459,29 @@ CHIP_ERROR CASESession::FindLocalNodeFromDestionationId(const ByteSpan & destina return found ? CHIP_NO_ERROR : CHIP_ERROR_KEY_NOT_FOUND; } +CHIP_ERROR CASESession::TryResumeSession(SessionResumptionStorage::ConstResumptionIdView resumptionId, ByteSpan resume1MIC, + ByteSpan initiatorRandom) +{ + if (mSessionResumptionStorage == nullptr) + return CHIP_ERROR_INCORRECT_STATE; + + SessionResumptionStorage::ConstResumptionIdView resumptionIdSpan(resumptionId); + ScopedNodeId node; + ReturnErrorOnFailure(mSessionResumptionStorage->FindByResumptionId(resumptionIdSpan, node, mSharedSecret, mPeerCATs)); + + // Cross check resume1MIC with the shared secret + ReturnErrorOnFailure( + ValidateSigmaResumeMIC(resume1MIC, initiatorRandom, resumptionId, ByteSpan(kKDFS1RKeyInfo), ByteSpan(kResume1MIC_Nonce))); + + mFabricInfo = mFabricsTable->FindFabricWithIndex(node.GetFabricIndex()); + if (mFabricInfo == nullptr) + return CHIP_ERROR_INTERNAL; + + mPeerNodeId = node.GetNodeId(); + + return CHIP_NO_ERROR; +} + CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg) { MATTER_TRACE_EVENT_SCOPE("HandleSigma1", "CASESession"); @@ -523,20 +508,19 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg) ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", initiatorSessionId); SetPeerSessionId(initiatorSessionId); - if (sessionResumptionRequested && resumptionId.data_equal(ByteSpan(mResumptionId))) + VerifyOrExit(mFabricsTable != nullptr, err = CHIP_ERROR_INCORRECT_STATE); + + if (sessionResumptionRequested && resumptionId.size() == SessionResumptionStorage::kResumptionIdSize && + CHIP_NO_ERROR == + TryResumeSession(SessionResumptionStorage::ConstResumptionIdView(resumptionId.data()), resume1MIC, initiatorRandom)) { - // Cross check resume1MIC with the shared secret - if (ValidateSigmaResumeMIC(resume1MIC, initiatorRandom, resumptionId, ByteSpan(kKDFS1RKeyInfo), - ByteSpan(kResume1MIC_Nonce)) == CHIP_NO_ERROR) - { - // Send Sigma2Resume message to the initiator - SuccessOrExit(err = SendSigma2Resume(initiatorRandom)); + // Send Sigma2Resume message to the initiator + SuccessOrExit(err = SendSigma2Resume(initiatorRandom)); - mDelegate->OnSessionEstablishmentStarted(); + mDelegate->OnSessionEstablishmentStarted(); - // Early returning here, since we have sent Sigma2Resume, and no further processing is needed for the Sigma1 message - return CHIP_NO_ERROR; - } + // Early returning here, since we have sent Sigma2Resume, and no further processing is needed for the Sigma1 message + return CHIP_NO_ERROR; } err = FindLocalNodeFromDestionationId(destinationIdentifier, initiatorRandom); @@ -578,8 +562,8 @@ CHIP_ERROR CASESession::SendSigma2Resume(const ByteSpan & initiatorRandom) { MATTER_TRACE_EVENT_SCOPE("SendSigma2Resume", "CASESession"); const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; - size_t max_sigma2_resume_data_len = - TLV::EstimateStructOverhead(kCASEResumptionIDSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, sizeof(uint16_t), mrpParamsSize); + size_t max_sigma2_resume_data_len = TLV::EstimateStructOverhead( + SessionResumptionStorage::kResumptionIdSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, sizeof(uint16_t), mrpParamsSize); System::PacketBufferTLVWriter tlvWriter; System::PacketBufferHandle msg_R2_resume; @@ -594,14 +578,14 @@ CHIP_ERROR CASESession::SendSigma2Resume(const ByteSpan & initiatorRandom) tlvWriter.Init(std::move(msg_R2_resume)); // Generate a new resumption ID - ReturnErrorOnFailure(DRBG_get_bytes(mResumptionId, sizeof(mResumptionId))); + ReturnErrorOnFailure(DRBG_get_bytes(mResumptionId.data(), mResumptionId.size())); ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), ByteSpan(mResumptionId))); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), mResumptionId)); uint8_t sigma2ResumeMIC[CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES]; MutableByteSpan resumeMICSpan(sigma2ResumeMIC); - ReturnErrorOnFailure(GenerateSigmaResumeMIC(initiatorRandom, ByteSpan(mResumptionId), ByteSpan(kKDFS2RKeyInfo), + ReturnErrorOnFailure(GenerateSigmaResumeMIC(initiatorRandom, mResumptionId, ByteSpan(kKDFS2RKeyInfo), ByteSpan(kResume2MIC_Nonce), resumeMICSpan)); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), resumeMICSpan)); @@ -685,8 +669,8 @@ CHIP_ERROR CASESession::SendSigma2() msg_R2_Signed.Free(); // Construct Sigma2 TBE Data - size_t msg_r2_signed_enc_len = - TLV::EstimateStructOverhead(nocCert.size(), icaCert.size(), tbsData2Signature.Length(), kCASEResumptionIDSize); + size_t msg_r2_signed_enc_len = TLV::EstimateStructOverhead(nocCert.size(), icaCert.size(), tbsData2Signature.Length(), + SessionResumptionStorage::kResumptionIdSize); chip::Platform::ScopedMemoryBuffer msg_R2_Encrypted; VerifyOrReturnError(msg_R2_Encrypted.Alloc(msg_r2_signed_enc_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES), CHIP_ERROR_NO_MEMORY); @@ -705,9 +689,8 @@ CHIP_ERROR CASESession::SendSigma2() static_cast(tbsData2Signature.Length()))); // Generate a new resumption ID - ReturnErrorOnFailure(DRBG_get_bytes(mResumptionId, sizeof(mResumptionId))); - ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_ResumptionID), mResumptionId, - static_cast(sizeof(mResumptionId)))); + ReturnErrorOnFailure(DRBG_get_bytes(mResumptionId.data(), mResumptionId.size())); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_ResumptionID), mResumptionId)); ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize()); @@ -780,15 +763,16 @@ CHIP_ERROR CASESession::HandleSigma2Resume(System::PacketBufferHandle && msg) SuccessOrExit(err = tlvReader.Next()); VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); - VerifyOrExit(tlvReader.GetLength() == kCASEResumptionIDSize, err = CHIP_ERROR_INVALID_TLV_ELEMENT); - SuccessOrExit(err = tlvReader.GetBytes(mResumptionId, kCASEResumptionIDSize)); + SessionResumptionStorage::ResumptionIdStorage resumptionId; + VerifyOrExit(tlvReader.GetLength() == resumptionId.size(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); + SuccessOrExit(err = tlvReader.GetBytes(resumptionId.data(), resumptionId.size())); SuccessOrExit(err = tlvReader.Next()); VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); VerifyOrExit(tlvReader.GetLength() == CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, err = CHIP_ERROR_INVALID_TLV_ELEMENT); SuccessOrExit(err = tlvReader.GetBytes(sigma2ResumeMIC, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES)); - SuccessOrExit(err = ValidateSigmaResumeMIC(ByteSpan(sigma2ResumeMIC), ByteSpan(mInitiatorRandom), ByteSpan(mResumptionId), + SuccessOrExit(err = ValidateSigmaResumeMIC(ByteSpan(sigma2ResumeMIC), ByteSpan(mInitiatorRandom), resumptionId, ByteSpan(kKDFS2RKeyInfo), ByteSpan(kResume2MIC_Nonce))); SuccessOrExit(err = tlvReader.Next()); @@ -803,6 +787,14 @@ CHIP_ERROR CASESession::HandleSigma2Resume(System::PacketBufferHandle && msg) ChipLogDetail(SecureChannel, "Peer assigned session session ID %d", responderSessionId); SetPeerSessionId(responderSessionId); + if (mSessionResumptionStorage != nullptr) + { + CHIP_ERROR err2 = mSessionResumptionStorage->Save(ScopedNodeId(GetPeerNodeId(), GetFabricIndex()), resumptionId, + mSharedSecret, mPeerCATs); + if (err2 != CHIP_NO_ERROR) + ChipLogError(SecureChannel, "Unable to save session resumption state: %" CHIP_ERROR_FORMAT, err2.Format()); + } + SendStatusReport(mExchangeCtxt, kProtocolCodeSuccess); // TODO: Set timestamp on the new session, to allow selecting a least-recently-used session for eviction @@ -971,14 +963,10 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg) // Retrieve session resumption ID SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBEData_ResumptionID))); - SuccessOrExit(err = decryptedDataTlvReader.GetBytes(mResumptionId, static_cast(sizeof(mResumptionId)))); + SuccessOrExit(err = decryptedDataTlvReader.GetBytes(mResumptionId.data(), mResumptionId.size())); // Retrieve peer CASE Authenticated Tags (CATs) from peer's NOC. - { - CATValues peerCATs; - SuccessOrExit(err = ExtractCATsFromOpCert(responderNOC, peerCATs)); - SetPeerCATs(peerCATs); - } + SuccessOrExit(err = ExtractCATsFromOpCert(responderNOC, mPeerCATs)); // Retrieve responderMRPParams if present if (tlvReader.Next() != CHIP_END_OF_TLV) @@ -1242,9 +1230,15 @@ CHIP_ERROR CASESession::HandleSigma3(System::PacketBufferHandle && msg) // Retrieve peer CASE Authenticated Tags (CATs) from peer's NOC. { - CATValues peerCATs; - SuccessOrExit(err = ExtractCATsFromOpCert(initiatorNOC, peerCATs)); - SetPeerCATs(peerCATs); + SuccessOrExit(err = ExtractCATsFromOpCert(initiatorNOC, mPeerCATs)); + } + + if (mSessionResumptionStorage != nullptr) + { + CHIP_ERROR err2 = mSessionResumptionStorage->Save(ScopedNodeId(GetPeerNodeId(), GetFabricIndex()), mResumptionId, + mSharedSecret, mPeerCATs); + if (err2 != CHIP_NO_ERROR) + ChipLogError(SecureChannel, "Unable to save session resumption state: %" CHIP_ERROR_FORMAT, err2.Format()); } SendStatusReport(mExchangeCtxt, kProtocolCodeSuccess); @@ -1314,7 +1308,7 @@ CHIP_ERROR CASESession::ConstructSigmaResumeKey(const ByteSpan & initiatorRandom { VerifyOrReturnError(resumeKey.size() >= CHIP_CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES, CHIP_ERROR_BUFFER_TOO_SMALL); - constexpr size_t saltSize = kSigmaParamRandomNumberSize + kCASEResumptionIDSize; + constexpr size_t saltSize = kSigmaParamRandomNumberSize + SessionResumptionStorage::kResumptionIdSize; uint8_t salt[saltSize]; memset(salt, 0, saltSize); @@ -1454,6 +1448,14 @@ void CASESession::OnSuccessStatusReport() ChipLogProgress(SecureChannel, "Success status report received. Session was established"); mCASESessionEstablished = true; + if (mSessionResumptionStorage != nullptr) + { + CHIP_ERROR err2 = mSessionResumptionStorage->Save(ScopedNodeId(GetPeerNodeId(), GetFabricIndex()), mResumptionId, + mSharedSecret, mPeerCATs); + if (err2 != CHIP_NO_ERROR) + ChipLogError(SecureChannel, "Unable to save session resumption state: %" CHIP_ERROR_FORMAT, err2.Format()); + } + // Discard the exchange so that Clear() doesn't try closing it. The // exchange will handle that. DiscardExchange(); @@ -1538,7 +1540,7 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, { resumptionIDTagFound = true; ReturnErrorOnFailure(tlvReader.GetByteView(resumptionId)); - VerifyOrReturnError(resumptionId.size() == kCASEResumptionIDSize, CHIP_ERROR_INVALID_CASE_PARAMETER); + VerifyOrReturnError(resumptionId.size() == SessionResumptionStorage::kResumptionIdSize, CHIP_ERROR_INVALID_CASE_PARAMETER); err = tlvReader.Next(); } @@ -1626,6 +1628,22 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea err = HandleSigma2_and_SendSigma3(std::move(msg)); break; + case MsgType::StatusReport: + err = HandleStatusReport(std::move(msg), /* successExpected*/ false); + break; + + default: + // Return the default error that was set above + break; + }; + break; + case kSentSigma1Resume: + switch (static_cast(payloadHeader.GetMessageType())) + { + case Protocols::SecureChannel::MsgType::CASE_Sigma2: + err = HandleSigma2_and_SendSigma3(std::move(msg)); + break; + case Protocols::SecureChannel::MsgType::CASE_Sigma2Resume: err = HandleSigma2Resume(std::move(msg)); break; diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index a6651b1f1fb9a6..6971cf06bdd824 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -48,24 +49,10 @@ namespace chip { -constexpr size_t kCASEResumptionIDSize = 16; - #ifdef ENABLE_HSM_CASE_EPHEMERAL_KEY #define CASE_EPHEMERAL_KEY 0xCA5EECD0 #endif -struct CASESessionCachable -{ - uint16_t mSharedSecretLen = 0; - uint8_t mSharedSecret[Crypto::kMax_ECDH_Secret_Length] = { 0 }; - FabricIndex mLocalFabricIndex = 0; - NodeId mPeerNodeId = kUndefinedNodeId; - CATValues mPeerCATs; - uint8_t mResumptionId[kCASEResumptionIDSize] = { 0 }; - uint64_t mSessionSetupTimeStamp = 0; - uint8_t mIPK[kIPKSize] = { 0 }; -}; - class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public PairingSession { public: @@ -86,7 +73,8 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin * @return CHIP_ERROR The result of initialization */ CHIP_ERROR ListenForSessionEstablishment( - SessionManager & sessionManager, FabricTable * fabrics, SessionEstablishmentDelegate * delegate, + SessionManager & sessionManager, FabricTable * fabrics, SessionResumptionStorage * sessionResumptionStorage, + SessionEstablishmentDelegate * delegate, Optional mrpConfig = Optional::Missing()); /** @@ -104,7 +92,8 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin */ CHIP_ERROR EstablishSession(SessionManager & sessionManager, const Transport::PeerAddress peerAddress, FabricInfo * fabric, - NodeId peerNodeId, Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate, + NodeId peerNodeId, Messaging::ExchangeContext * exchangeCtxt, + SessionResumptionStorage * sessionResumptionStorage, SessionEstablishmentDelegate * delegate, Optional mrpConfig = Optional::Missing()); /** @@ -152,16 +141,6 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin */ CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override; - /** - * @brief Serialize the CASESession to the given cachableSession data structure for secure pairing - **/ - CHIP_ERROR ToCachable(CASESessionCachable & output); - - /** - * @brief Reconstruct secure pairing class from the cachableSession data structure. - **/ - CHIP_ERROR FromCachable(const CASESessionCachable & output); - //// ExchangeDelegate Implementation //// CHIP_ERROR OnMessageReceived(Messaging::ExchangeContext * ec, const PayloadHeader & payloadHeader, System::PacketBufferHandle && payload) override; @@ -175,11 +154,6 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin **/ void Clear(); - /** - * Parse the TLV for Sigma1 message. - */ - CHIP_ERROR ParseSigma1(); - private: enum State : uint8_t { @@ -187,7 +161,8 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin kSentSigma1 = 1, kSentSigma2 = 2, kSentSigma3 = 3, - kSentSigma2Resume = 4, + kSentSigma1Resume = 4, + kSentSigma2Resume = 5, }; CHIP_ERROR Init(SessionManager & sessionManager, SessionEstablishmentDelegate * delegate); @@ -201,6 +176,8 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin CHIP_ERROR SendSigma1(); CHIP_ERROR HandleSigma1_and_SendSigma2(System::PacketBufferHandle && msg); CHIP_ERROR HandleSigma1(System::PacketBufferHandle && msg); + CHIP_ERROR TryResumeSession(SessionResumptionStorage::ConstResumptionIdView resumptionId, ByteSpan resume1MIC, + ByteSpan initiatorRandom); CHIP_ERROR SendSigma2(); CHIP_ERROR HandleSigma2_and_SendSigma3(System::PacketBufferHandle && msg); CHIP_ERROR HandleSigma2(System::PacketBufferHandle && msg); @@ -260,26 +237,23 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin uint8_t mMessageDigest[Crypto::kSHA256_Hash_Length]; uint8_t mIPK[kIPKSize]; - Messaging::ExchangeContext * mExchangeCtxt = nullptr; + Messaging::ExchangeContext * mExchangeCtxt = nullptr; + SessionResumptionStorage * mSessionResumptionStorage = nullptr; FabricTable * mFabricsTable = nullptr; const FabricInfo * mFabricInfo = nullptr; - uint8_t mResumptionId[kCASEResumptionIDSize]; + // This field is only used for CASE responder, when during sending sigma2 and waiting for sigma3 + SessionResumptionStorage::ResumptionIdStorage mResumptionId; // Sigma1 initiator random, maintained to be reused post-Sigma1, such as when generating Sigma2 S2RK key uint8_t mInitiatorRandom[kSigmaParamRandomNumberSize]; State mState; - uint8_t mLocalFabricIndex = 0; - uint64_t mSessionSetupTimeStamp = 0; - Optional mLocalMRPConfig; protected: bool mCASESessionEstablished = false; - - void SetSessionTimeStamp(uint64_t timestamp) { mSessionSetupTimeStamp = timestamp; } }; } // namespace chip diff --git a/src/protocols/secure_channel/CASESessionCache.cpp b/src/protocols/secure_channel/CASESessionCache.cpp deleted file mode 100644 index 74c1662ff09e32..00000000000000 --- a/src/protocols/secure_channel/CASESessionCache.cpp +++ /dev/null @@ -1,105 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace chip { - -CASESessionCache::CASESessionCache() {} - -CASESessionCache::~CASESessionCache() -{ - mCachePool.ForEachActiveObject([&](auto * ec) { - mCachePool.ReleaseObject(ec); - return Loop::Continue; - }); -} - -CASESessionCachable * CASESessionCache::GetLRUSession() -{ - uint64_t minTimeStamp = UINT64_MAX; - CASESessionCachable * lruSession = nullptr; - mCachePool.ForEachActiveObject([&](auto * ec) { - if (minTimeStamp > ec->mSessionSetupTimeStamp) - { - minTimeStamp = ec->mSessionSetupTimeStamp; - lruSession = ec; - } - return Loop::Continue; - }); - return lruSession; -} - -CHIP_ERROR CASESessionCache::Add(CASESessionCachable & cachableSession) -{ - // It's not an error if a device doesn't have cache for storing the sessions. - VerifyOrReturnError(mCachePool.Capacity() > 0, CHIP_NO_ERROR); - - // If the cache is full, get the least recently used session index and release that. - if (mCachePool.Allocated() >= kCacheSize) - { - mCachePool.ReleaseObject(GetLRUSession()); - } - - mCachePool.CreateObject(cachableSession); - return CHIP_NO_ERROR; -} - -CHIP_ERROR CASESessionCache::Remove(ResumptionID resumptionID) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - CASESession session; - mCachePool.ForEachActiveObject([&](auto * ec) { - if (resumptionID.data_equal(ResumptionID(ec->mResumptionId))) - { - mCachePool.ReleaseObject(ec); - } - return Loop::Continue; - }); - - return err; -} - -CHIP_ERROR CASESessionCache::Get(ResumptionID resumptionID, CASESessionCachable & outSessionCachable) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - bool found = false; - mCachePool.ForEachActiveObject([&](auto * ec) { - if (resumptionID.data_equal(ResumptionID(ec->mResumptionId))) - { - found = true; - outSessionCachable = *ec; - return Loop::Break; - } - return Loop::Continue; - }); - - if (!found) - { - err = CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND; - } - - return err; -} - -CHIP_ERROR CASESessionCache::Get(const PeerId & peer, CASESessionCachable & outSessionCachable) -{ - // TODO: Implement this based on peer id - return CHIP_NO_ERROR; -} - -} // namespace chip diff --git a/src/protocols/secure_channel/CASESessionCache.h b/src/protocols/secure_channel/CASESessionCache.h deleted file mode 100644 index f4bbca22c431e1..00000000000000 --- a/src/protocols/secure_channel/CASESessionCache.h +++ /dev/null @@ -1,45 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -namespace chip { - -using ResumptionID = FixedByteSpan; - -class CASESessionCache -{ -public: - CASESessionCache(); - virtual ~CASESessionCache(); - - CHIP_ERROR Add(CASESessionCachable & cachableSession); - CHIP_ERROR Remove(ResumptionID resumptionID); - CHIP_ERROR Get(ResumptionID resumptionID, CASESessionCachable & outCachableSession); - CHIP_ERROR Get(const PeerId & peer, CASESessionCachable & outCachableSession); - -private: - static constexpr size_t kCacheSize = CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE; - ObjectPool mCachePool; - CASESessionCachable * GetLRUSession(); -}; - -} // namespace chip diff --git a/src/protocols/secure_channel/SessionResumptionStorage.cpp b/src/protocols/secure_channel/SessionResumptionStorage.cpp new file mode 100644 index 00000000000000..368cc1930fd186 --- /dev/null +++ b/src/protocols/secure_channel/SessionResumptionStorage.cpp @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines the CHIP CASE Session object that provides + * APIs for constructing a secure session using a certificate from the device's + * operational credentials. + */ + +#include + +#include +#include + +namespace chip { + +CHIP_ERROR SessionResumptionStorage::FindByScopedNodeId(const ScopedNodeId & node, ResumptionIdStorage & resumptionId, + Crypto::P256ECDHDerivedSecret & sharedSecret, CATValues & peerCATs) +{ + ReturnErrorOnFailure(LoadState(node, resumptionId, sharedSecret, peerCATs)); + return CHIP_NO_ERROR; +} + +CHIP_ERROR SessionResumptionStorage::FindByResumptionId(ConstResumptionIdView resumptionId, ScopedNodeId & node, + Crypto::P256ECDHDerivedSecret & sharedSecret, CATValues & peerCATs) +{ + ReturnErrorOnFailure(FindNodeByResumptionId(resumptionId, node)); + ResumptionIdStorage tmpResumptionId; + ReturnErrorOnFailure(FindByScopedNodeId(node, tmpResumptionId, sharedSecret, peerCATs)); + VerifyOrReturnError(std::equal(tmpResumptionId.begin(), tmpResumptionId.end(), resumptionId.begin(), resumptionId.end()), + CHIP_ERROR_KEY_NOT_FOUND); + return CHIP_NO_ERROR; +} + +CHIP_ERROR SessionResumptionStorage::FindNodeByResumptionId(ConstResumptionIdView resumptionId, ScopedNodeId & node) +{ + ReturnErrorOnFailure(LoadLink(resumptionId, node)); + return CHIP_NO_ERROR; +} + +CHIP_ERROR SessionResumptionStorage::Save(const ScopedNodeId & node, ConstResumptionIdView resumptionId, + const Crypto::P256ECDHDerivedSecret & sharedSecret, const CATValues & peerCATs) +{ + SessionIndex index; + ReturnErrorOnFailure(LoadIndex(index)); + + if (index.mSize == CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE) + { + // TODO: implement LRU for resumption + ReturnErrorOnFailure(Delete(index.mNodes[0])); + ReturnErrorOnFailure(LoadIndex(index)); + } + + ReturnErrorOnFailure(SaveState(node, resumptionId, sharedSecret, peerCATs)); + ReturnErrorOnFailure(SaveLink(resumptionId, node)); + + index.mNodes[index.mSize++] = node; + ReturnErrorOnFailure(SaveIndex(index)); + + return CHIP_NO_ERROR; +} + +CHIP_ERROR SessionResumptionStorage::Delete(const ScopedNodeId & node) +{ + SessionIndex index; + ReturnErrorOnFailure(LoadIndex(index)); + + ResumptionIdStorage resumptionId; + Crypto::P256ECDHDerivedSecret sharedSecret; + CATValues peerCATs; + CHIP_ERROR err = LoadState(node, resumptionId, sharedSecret, peerCATs); + if (err == CHIP_NO_ERROR) + { + err = DeleteLink(resumptionId); + if (err != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, + "Unable to delete session resumption link for node " ChipLogFormatX64 ": %" CHIP_ERROR_FORMAT, + ChipLogValueX64(node.GetNodeId()), err.Format()); + } + } + else + { + ChipLogError(SecureChannel, + "Unable to load session resumption state during session deletion for node " ChipLogFormatX64 + ": %" CHIP_ERROR_FORMAT, + ChipLogValueX64(node.GetNodeId()), err.Format()); + } + + err = DeleteState(node); + if (err != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, "Unable to delete session resumption state for node " ChipLogFormatX64 ": %" CHIP_ERROR_FORMAT, + ChipLogValueX64(node.GetNodeId()), err.Format()); + } + + bool found = false; + for (size_t i = 0; i < index.mSize; ++i) + { + if (found) + { + index.mNodes[i] = index.mNodes[i + 1]; + } + else + { + if (index.mNodes[i] == node) + { + found = true; + index.mSize -= 1; + if (i + 1 != index.mSize) + { + index.mNodes[i] = index.mNodes[i + 1]; + } + } + } + } + + if (found) + { + err = SaveIndex(index); + if (err != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, "Unable to save session resumption index: %" CHIP_ERROR_FORMAT, err.Format()); + } + } + else + { + ChipLogError(SecureChannel, + "Unable to find session resumption state for node in index" ChipLogFormatX64 ": %" CHIP_ERROR_FORMAT, + ChipLogValueX64(node.GetNodeId()), err.Format()); + } + + return CHIP_NO_ERROR; +} + +} // namespace chip diff --git a/src/protocols/secure_channel/SessionResumptionStorage.h b/src/protocols/secure_channel/SessionResumptionStorage.h new file mode 100644 index 00000000000000..f42fb6976385d6 --- /dev/null +++ b/src/protocols/secure_channel/SessionResumptionStorage.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines the CHIP CASE Session object that provides + * APIs for constructing a secure session using a certificate from the device's + * operational credentials. + */ + +#pragma once + +#include +#include +#include + +namespace chip { + +/** + * @brief Stores assets for sessoin resumption. The resumption data are indexed by 2 indexes: ScopedNodeId and . The index of ScopedNodeId is used when initiating a CASE session, it will look up the storage and check whether it + * is able to resume a previous session. The index of ResumptionId is used when receiving a Sigma1 with ResumptionId. + * + * The implementation saves 2 maps: + * * => + * * => + */ +class SessionResumptionStorage +{ +public: + static constexpr size_t kResumptionIdSize = 16; + using ResumptionIdStorage = std::array; + using ResumptionIdView = FixedSpan; + using ConstResumptionIdView = FixedSpan; + + struct SessionIndex + { + size_t mSize; + ScopedNodeId mNodes[CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE]; + }; + + virtual ~SessionResumptionStorage() {} + + CHIP_ERROR FindByScopedNodeId(const ScopedNodeId & node, ResumptionIdStorage & resumptionId, + Crypto::P256ECDHDerivedSecret & sharedSecret, CATValues & peerCATs); + CHIP_ERROR FindByResumptionId(ConstResumptionIdView resumptionId, ScopedNodeId & node, + Crypto::P256ECDHDerivedSecret & sharedSecret, CATValues & peerCATs); + CHIP_ERROR FindNodeByResumptionId(ConstResumptionIdView resumptionId, ScopedNodeId & node); + CHIP_ERROR Save(const ScopedNodeId & node, ConstResumptionIdView resumptionId, + const Crypto::P256ECDHDerivedSecret & sharedSecret, const CATValues & peerCATs); + CHIP_ERROR Delete(const ScopedNodeId & node); + +protected: + CHIP_ERROR virtual SaveIndex(const SessionIndex & index) = 0; + CHIP_ERROR virtual LoadIndex(SessionIndex & index) = 0; + + CHIP_ERROR virtual SaveLink(ConstResumptionIdView resumptionId, const ScopedNodeId & node) = 0; + CHIP_ERROR virtual LoadLink(ConstResumptionIdView resumptionId, ScopedNodeId & node) = 0; + CHIP_ERROR virtual DeleteLink(ConstResumptionIdView resumptionId) = 0; + + CHIP_ERROR virtual SaveState(const ScopedNodeId & node, ConstResumptionIdView resumptionId, + const Crypto::P256ECDHDerivedSecret & sharedSecret, const CATValues & peerCATs) = 0; + CHIP_ERROR virtual LoadState(const ScopedNodeId & node, ResumptionIdStorage & resumptionId, + Crypto::P256ECDHDerivedSecret & sharedSecret, CATValues & peerCATs) = 0; + CHIP_ERROR virtual DeleteState(const ScopedNodeId & node) = 0; +}; + +} // namespace chip diff --git a/src/protocols/secure_channel/SimpleSessionResumptionStorage.cpp b/src/protocols/secure_channel/SimpleSessionResumptionStorage.cpp new file mode 100644 index 00000000000000..033a53675bda0e --- /dev/null +++ b/src/protocols/secure_channel/SimpleSessionResumptionStorage.cpp @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines the CHIP CASE Session object that provides + * APIs for constructing a secure session using a certificate from the device's + * operational credentials. + */ + +#include + +#include +#include + +namespace chip { + +constexpr TLV::Tag SimpleSessionResumptionStorage::kIndexContentTag; +constexpr TLV::Tag SimpleSessionResumptionStorage::kFabricIndexTag; +constexpr TLV::Tag SimpleSessionResumptionStorage::kPeerNodeIdTag; +constexpr TLV::Tag SimpleSessionResumptionStorage::kResumptionIdTag; +constexpr TLV::Tag SimpleSessionResumptionStorage::kSharedSecretTag; +constexpr TLV::Tag SimpleSessionResumptionStorage::kCATTag; + +const char * SimpleSessionResumptionStorage::StorageKey(DefaultStorageKeyAllocator & keyAlloc, const ScopedNodeId & node) +{ + return keyAlloc.FabricSession(node.GetFabricIndex(), node.GetNodeId()); +} + +const char * SimpleSessionResumptionStorage::StorageKey(DefaultStorageKeyAllocator & keyAlloc, ConstResumptionIdView resumptionId) +{ + char resumptionIdBase64[BASE64_ENCODED_LEN(resumptionId.size()) + 1]; + auto len = Base64Encode(resumptionId.data(), resumptionId.size(), resumptionIdBase64); + resumptionIdBase64[len] = '\0'; + return keyAlloc.SessionResumption(resumptionIdBase64); +} + +CHIP_ERROR SimpleSessionResumptionStorage::SaveIndex(const SessionIndex & index) +{ + uint8_t buf[MaxIndexSize()]; + TLV::TLVWriter writer; + writer.Init(buf); + + TLV::TLVType arrayType; + ReturnErrorOnFailure(writer.StartContainer(kIndexContentTag, TLV::kTLVType_Array, arrayType)); + + for (size_t i = index.mSize; i < index.mSize; ++i) + { + TLV::TLVType innerType; + ReturnErrorOnFailure(writer.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, innerType)); + ReturnErrorOnFailure(writer.Put(kFabricIndexTag, index.mNodes[i].GetFabricIndex())); + ReturnErrorOnFailure(writer.Put(kPeerNodeIdTag, index.mNodes[i].GetNodeId())); + ReturnErrorOnFailure(writer.EndContainer(innerType)); + } + + ReturnErrorOnFailure(writer.EndContainer(arrayType)); + + const auto len = writer.GetLengthWritten(); + VerifyOrReturnError(CanCastTo(len), CHIP_ERROR_BUFFER_TOO_SMALL); + + DefaultStorageKeyAllocator keyAlloc; + ReturnErrorOnFailure(mStorage->SyncSetKeyValue(keyAlloc.SessionResumptionIndex(), buf, static_cast(len))); + + return CHIP_NO_ERROR; +} + +CHIP_ERROR SimpleSessionResumptionStorage::LoadIndex(SessionIndex & index) +{ + uint8_t buf[MaxIndexSize()]; + uint16_t len = static_cast(MaxStateSize()); + + DefaultStorageKeyAllocator keyAlloc; + ReturnErrorOnFailure(mStorage->SyncGetKeyValue(keyAlloc.SessionResumptionIndex(), buf, len)); + + TLV::ContiguousBufferTLVReader reader; + reader.Init(buf, len); + + ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Array, kIndexContentTag)); + TLV::TLVType arrayType; + ReturnErrorOnFailure(reader.EnterContainer(arrayType)); + + size_t count = 0; + CHIP_ERROR err; + while ((err = reader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag())) == CHIP_NO_ERROR) + { + if (count >= ArraySize(index.mNodes)) + { + return CHIP_ERROR_NO_MEMORY; + } + + TLV::TLVType containerType; + ReturnErrorOnFailure(reader.EnterContainer(containerType)); + + FabricIndex fabricIndex; + ReturnErrorOnFailure(reader.Next(kFabricIndexTag)); + ReturnErrorOnFailure(reader.Get(fabricIndex)); + + NodeId peerNodeId; + ReturnErrorOnFailure(reader.Next(kPeerNodeIdTag)); + ReturnErrorOnFailure(reader.Get(peerNodeId)); + + index.mNodes[count++] = ScopedNodeId(peerNodeId, fabricIndex); + + ReturnErrorOnFailure(reader.ExitContainer(containerType)); + } + + if (err != CHIP_END_OF_TLV) + { + return err; + } + + ReturnErrorOnFailure(reader.ExitContainer(arrayType)); + ReturnErrorOnFailure(reader.VerifyEndOfContainer()); + + index.mSize = count; + + return CHIP_NO_ERROR; +} + +CHIP_ERROR SimpleSessionResumptionStorage::SaveLink(ConstResumptionIdView resumptionId, const ScopedNodeId & node) +{ + // Save a link from resumptionId to node, in key: /f//r/ + uint8_t buf[MaxScopedNodeIdSize()]; + + TLV::TLVWriter writer; + writer.Init(buf); + + TLV::TLVType outerType; + ReturnErrorOnFailure(writer.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerType)); + ReturnErrorOnFailure(writer.Put(kFabricIndexTag, node.GetFabricIndex())); + ReturnErrorOnFailure(writer.Put(kPeerNodeIdTag, node.GetNodeId())); + ReturnErrorOnFailure(writer.EndContainer(outerType)); + + const auto len = writer.GetLengthWritten(); + VerifyOrDie(CanCastTo(len)); + + DefaultStorageKeyAllocator keyAlloc; + ReturnErrorOnFailure(mStorage->SyncSetKeyValue(StorageKey(keyAlloc, resumptionId), buf, static_cast(len))); + return CHIP_NO_ERROR; +} + +CHIP_ERROR SimpleSessionResumptionStorage::LoadLink(ConstResumptionIdView resumptionId, ScopedNodeId & node) +{ + uint8_t buf[MaxScopedNodeIdSize()]; + uint16_t len = static_cast(MaxStateSize()); + + DefaultStorageKeyAllocator keyAlloc; + ReturnErrorOnFailure(mStorage->SyncGetKeyValue(StorageKey(keyAlloc, resumptionId), buf, len)); + + TLV::ContiguousBufferTLVReader reader; + reader.Init(buf, len); + + ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag())); + TLV::TLVType containerType; + ReturnErrorOnFailure(reader.EnterContainer(containerType)); + + FabricIndex fabricIndex; + ReturnErrorOnFailure(reader.Next(kFabricIndexTag)); + ReturnErrorOnFailure(reader.Get(fabricIndex)); + + NodeId peerNodeId; + ReturnErrorOnFailure(reader.Next(kPeerNodeIdTag)); + ReturnErrorOnFailure(reader.Get(peerNodeId)); + + ReturnErrorOnFailure(reader.ExitContainer(containerType)); + ReturnErrorOnFailure(reader.VerifyEndOfContainer()); + + node = ScopedNodeId(peerNodeId, fabricIndex); + + return CHIP_NO_ERROR; +} + +CHIP_ERROR SimpleSessionResumptionStorage::DeleteLink(ConstResumptionIdView resumptionId) +{ + DefaultStorageKeyAllocator keyAlloc; + ReturnErrorOnFailure(mStorage->SyncDeleteKeyValue(StorageKey(keyAlloc, resumptionId))); + return CHIP_NO_ERROR; +} + +CHIP_ERROR SimpleSessionResumptionStorage::SaveState(const ScopedNodeId & node, ConstResumptionIdView resumptionId, + const Crypto::P256ECDHDerivedSecret & sharedSecret, const CATValues & peerCATs) +{ + // Save session state into key: /f//s/ + uint8_t buf[MaxStateSize()]; + + TLV::TLVWriter writer; + writer.Init(buf); + + TLV::TLVType outerType; + ReturnErrorOnFailure(writer.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerType)); + + ReturnErrorOnFailure(writer.Put(kResumptionIdTag, resumptionId)); + + ReturnErrorOnFailure(writer.Put(kSharedSecretTag, ByteSpan(sharedSecret.ConstBytes(), sharedSecret.Length()))); + + CATValues::Serialized cat; + peerCATs.Serialize(cat); + ReturnErrorOnFailure(writer.Put(kCATTag, ByteSpan(cat))); + + ReturnErrorOnFailure(writer.EndContainer(outerType)); + + const auto len = writer.GetLengthWritten(); + VerifyOrDie(CanCastTo(len)); + + DefaultStorageKeyAllocator keyAlloc; + ReturnErrorOnFailure(mStorage->SyncSetKeyValue(StorageKey(keyAlloc, node), buf, static_cast(len))); + return CHIP_NO_ERROR; +} + +CHIP_ERROR SimpleSessionResumptionStorage::LoadState(const ScopedNodeId & node, ResumptionIdStorage & resumptionId, + Crypto::P256ECDHDerivedSecret & sharedSecret, CATValues & peerCATs) +{ + uint8_t buf[MaxStateSize()]; + uint16_t len = static_cast(MaxStateSize()); + + DefaultStorageKeyAllocator keyAlloc; + ReturnErrorOnFailure(mStorage->SyncGetKeyValue(StorageKey(keyAlloc, node), buf, len)); + + TLV::ContiguousBufferTLVReader reader; + reader.Init(buf, len); + + ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag())); + TLV::TLVType containerType; + ReturnErrorOnFailure(reader.EnterContainer(containerType)); + + ByteSpan resumptionIdSpan; + ReturnErrorOnFailure(reader.Next(kResumptionIdTag)); + ReturnErrorOnFailure(reader.Get(resumptionIdSpan)); + std::copy(resumptionIdSpan.begin(), resumptionIdSpan.end(), resumptionId.begin()); + + ByteSpan sharedSecretSpan; + ReturnErrorOnFailure(reader.Next(kSharedSecretTag)); + ReturnErrorOnFailure(reader.Get(sharedSecretSpan)); + VerifyOrReturnError(sharedSecretSpan.size() <= sharedSecret.Capacity(), CHIP_ERROR_BUFFER_TOO_SMALL); + ::memcpy(sharedSecret.Bytes(), sharedSecretSpan.data(), sharedSecretSpan.size()); + sharedSecret.SetLength(sharedSecretSpan.size()); + + ByteSpan catSpan; + ReturnErrorOnFailure(reader.Next(kCATTag)); + ReturnErrorOnFailure(reader.Get(catSpan)); + CATValues::Serialized cat; + VerifyOrReturnError(sizeof(cat) == catSpan.size(), CHIP_ERROR_INVALID_TLV_ELEMENT); + ::memcpy(cat, catSpan.data(), catSpan.size()); + peerCATs.Deserialize(cat); + + ReturnErrorOnFailure(reader.ExitContainer(containerType)); + ReturnErrorOnFailure(reader.VerifyEndOfContainer()); + + return CHIP_NO_ERROR; +} + +CHIP_ERROR SimpleSessionResumptionStorage::DeleteState(const ScopedNodeId & node) +{ + DefaultStorageKeyAllocator keyAlloc; + ReturnErrorOnFailure(mStorage->SyncDeleteKeyValue(StorageKey(keyAlloc, node))); + return CHIP_NO_ERROR; +} + +} // namespace chip diff --git a/src/protocols/secure_channel/SimpleSessionResumptionStorage.h b/src/protocols/secure_channel/SimpleSessionResumptionStorage.h new file mode 100644 index 00000000000000..68daeeb1206336 --- /dev/null +++ b/src/protocols/secure_channel/SimpleSessionResumptionStorage.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines the CHIP CASE Session object that provides + * APIs for constructing a secure session using a certificate from the device's + * operational credentials. + */ + +#pragma once + +#include +#include +#include + +namespace chip { + +/** + * An example SessionResumptionStorage using PersistentStorageDelegate as it backend. + */ +class SimpleSessionResumptionStorage : public SessionResumptionStorage +{ +public: + CHIP_ERROR Init(PersistentStorageDelegate * storage) + { + VerifyOrReturnError(storage != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + mStorage = storage; + return CHIP_NO_ERROR; + } + + CHIP_ERROR SaveIndex(const SessionIndex & index) override; + CHIP_ERROR LoadIndex(SessionIndex & index) override; + + CHIP_ERROR SaveLink(ConstResumptionIdView resumptionId, const ScopedNodeId & node) override; + CHIP_ERROR LoadLink(ConstResumptionIdView resumptionId, ScopedNodeId & node) override; + CHIP_ERROR DeleteLink(ConstResumptionIdView resumptionId) override; + + CHIP_ERROR SaveState(const ScopedNodeId & node, ConstResumptionIdView resumptionId, + const Crypto::P256ECDHDerivedSecret & sharedSecret, const CATValues & peerCATs) override; + CHIP_ERROR LoadState(const ScopedNodeId & node, ResumptionIdStorage & resumptionId, + Crypto::P256ECDHDerivedSecret & sharedSecret, CATValues & peerCATs) override; + CHIP_ERROR DeleteState(const ScopedNodeId & node) override; + +private: + static const char * StorageKey(DefaultStorageKeyAllocator & keyAlloc, const ScopedNodeId & node); + static const char * StorageKey(DefaultStorageKeyAllocator & keyAlloc, ConstResumptionIdView resumptionId); + + static constexpr size_t MaxScopedNodeIdSize() { return TLV::EstimateStructOverhead(sizeof(NodeId), sizeof(FabricIndex)); } + + static constexpr size_t MaxIndexSize() + { + // The max size of the list is (1 byte control + bytes for actual value) times max number of list items, plus one byte for + // the list terminator. + return TLV::EstimateStructOverhead((1 + MaxScopedNodeIdSize()) * CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE + 1); + } + + static constexpr size_t MaxStateSize() + { + return TLV::EstimateStructOverhead(kResumptionIdSize, Crypto::P256ECDHDerivedSecret::Capacity(), + CATValues::kSerializedLength); + } + + static constexpr TLV::Tag kIndexContentTag = TLV::ContextTag(1); + static constexpr TLV::Tag kFabricIndexTag = TLV::ContextTag(2); + static constexpr TLV::Tag kPeerNodeIdTag = TLV::ContextTag(3); + static constexpr TLV::Tag kResumptionIdTag = TLV::ContextTag(4); + static constexpr TLV::Tag kSharedSecretTag = TLV::ContextTag(5); + static constexpr TLV::Tag kCATTag = TLV::ContextTag(6); + + PersistentStorageDelegate * mStorage; +}; + +} // namespace chip diff --git a/src/protocols/secure_channel/tests/BUILD.gn b/src/protocols/secure_channel/tests/BUILD.gn index 070a757fbdcf97..41a75512571fa7 100644 --- a/src/protocols/secure_channel/tests/BUILD.gn +++ b/src/protocols/secure_channel/tests/BUILD.gn @@ -10,11 +10,11 @@ chip_test_suite("tests") { test_sources = [ "TestCASESession.cpp", - "TestCASESessionCache.cpp", # TODO - Fix Message Counter Sync to use group key # "TestMessageCounterManager.cpp", "TestPASESession.cpp", + "TestSimpleSessionResumptionStorage.cpp", "TestStatusReport.cpp", ] diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 418591d2a66e7b..65725b9eaf5562 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -187,10 +187,11 @@ void CASE_SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kUndefinedCATs, sizeof(CATValues)) == 0); pairing.SetGroupDataProvider(&gDeviceGroupDataProvider); - NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(sessionManager, nullptr, nullptr) == CHIP_ERROR_INVALID_ARGUMENT); NL_TEST_ASSERT(inSuite, - pairing.ListenForSessionEstablishment(sessionManager, nullptr, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); - NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(sessionManager, &fabrics, &delegate) == CHIP_NO_ERROR); + pairing.ListenForSessionEstablishment(sessionManager, nullptr, nullptr, nullptr) == CHIP_ERROR_INVALID_ARGUMENT); + NL_TEST_ASSERT( + inSuite, pairing.ListenForSessionEstablishment(sessionManager, nullptr, nullptr, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); + NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(sessionManager, &fabrics, nullptr, &delegate) == CHIP_NO_ERROR); } void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) @@ -210,17 +211,17 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, pairing.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), nullptr, Node01_01, - nullptr, nullptr) != CHIP_NO_ERROR); + nullptr, nullptr, nullptr) != CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, pairing.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, - nullptr, nullptr) != CHIP_NO_ERROR); + nullptr, nullptr, nullptr) != CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, pairing.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, - context, &delegate) == CHIP_NO_ERROR); + context, nullptr, &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); @@ -241,7 +242,7 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, pairing1.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, - context1, &delegate) == CHIP_ERROR_BAD_REQUEST); + context1, nullptr, &delegate) == CHIP_ERROR_BAD_REQUEST); ctx.DrainAndServiceIO(); gLoopback.mMessageSendError = CHIP_NO_ERROR; @@ -255,8 +256,6 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegateAccessory; CASESession pairingAccessory; - CASESessionCachable serializableCommissioner; - CASESessionCachable serializableAccessory; SessionManager sessionManager; gLoopback.mSentMessageCount = 0; @@ -272,22 +271,17 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte pairingAccessory.SetGroupDataProvider(&gDeviceGroupDataProvider); NL_TEST_ASSERT(inSuite, - pairingAccessory.ListenForSessionEstablishment(sessionManager, &gDeviceFabrics, &delegateAccessory) == + pairingAccessory.ListenForSessionEstablishment(sessionManager, &gDeviceFabrics, nullptr, &delegateAccessory) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, pairingCommissioner.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, - Node01_01, contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); + Node01_01, contextCommissioner, nullptr, + &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 5); NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 1); NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1); - - NL_TEST_ASSERT(inSuite, pairingCommissioner.ToCachable(serializableCommissioner) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, pairingAccessory.ToCachable(serializableAccessory) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, - memcmp(serializableCommissioner.mSharedSecret, serializableAccessory.mSharedSecret, - serializableCommissioner.mSharedSecretLen) == 0); } void CASE_SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) @@ -320,7 +314,7 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte #if CONFIG_NETWORK_LAYER_BLE nullptr, #endif - &ctx.GetSecureSessionManager(), &gDeviceFabrics, + &ctx.GetSecureSessionManager(), &gDeviceFabrics, nullptr, &gDeviceGroupDataProvider) == CHIP_NO_ERROR); ExchangeContext * contextCommissioner = ctx.NewUnauthenticatedExchangeToBob(pairingCommissioner); @@ -330,7 +324,8 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte NL_TEST_ASSERT(inSuite, pairingCommissioner->EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, - Node01_01, contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); + Node01_01, contextCommissioner, nullptr, + &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 5); @@ -342,7 +337,8 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte NL_TEST_ASSERT(inSuite, pairingCommissioner1->EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, - Node01_01, contextCommissioner1, &delegateCommissioner) == CHIP_NO_ERROR); + Node01_01, contextCommissioner1, nullptr, + &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); chip::Platform::Delete(pairingCommissioner); diff --git a/src/protocols/secure_channel/tests/TestCASESessionCache.cpp b/src/protocols/secure_channel/tests/TestCASESessionCache.cpp deleted file mode 100644 index 1ed60e8fcfb4d1..00000000000000 --- a/src/protocols/secure_channel/tests/TestCASESessionCache.cpp +++ /dev/null @@ -1,233 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file - * This file implements unit tests for the CASESession implementation. - */ - -#include -#include - -#include -#include -#include -#include -#include -#include - -using namespace chip; - -namespace { -NodeId sTest_PeerId = 0xEDEDEDED00010001; - -uint8_t sTest_SharedSecret[] = { - 0x7d, 0x73, 0x5b, 0xef, 0xe9, 0x16, 0xa1, 0xc0, 0xca, 0x02, 0xf8, 0xca, 0x98, 0x81, 0xe4, 0x26, - 0x63, 0xaa, 0xaf, 0x9a, 0xb9, 0xc4, 0x33, 0xb2, 0x89, 0xbe, 0x26, 0x70, 0x10, 0x75, 0x74, 0x10, -}; - -uint8_t sTest_ResumptionId[kCASEResumptionIDSize] = { 0 }; - -} // namespace - -class CASESessionTest : public CASESession -{ -public: - void createCASESessionTestCachable(uint8_t i) - { - uint16_t sharedSecretLen = sizeof(sTest_SharedSecret); - sTest_SharedSecret[sharedSecretLen - 1] = static_cast(sTest_SharedSecret[sharedSecretLen - 1] + i); - uint64_t timestamp = static_cast(4000 + i * 1000); - sTest_ResumptionId[kCASEResumptionIDSize - 1] = static_cast(sTest_ResumptionId[kCASEResumptionIDSize - 1] + i); - - mCASESessionCachableArray[i].mSharedSecretLen = sharedSecretLen; - memcpy(mCASESessionCachableArray[i].mSharedSecret, sTest_SharedSecret, sharedSecretLen); - mCASESessionCachableArray[i].mPeerNodeId = static_cast(sTest_PeerId + i); - mCASESessionCachableArray[i].mPeerCATs.values[0] = (uint32_t) i; - memcpy(mCASESessionCachableArray[i].mResumptionId, sTest_ResumptionId, kCASEResumptionIDSize); - mCASESessionCachableArray[i].mLocalFabricIndex = 0; - mCASESessionCachableArray[i].mSessionSetupTimeStamp = timestamp; - } - - bool isEqual(int index, CASESessionCachable cachableSession) - { - return (cachableSession.mSharedSecretLen == mCASESessionCachableArray[index].mSharedSecretLen) && - ((ByteSpan(cachableSession.mSharedSecret)).data_equal(ByteSpan(mCASESessionCachableArray[index].mSharedSecret))) && - (cachableSession.mPeerNodeId == mCASESessionCachableArray[index].mPeerNodeId) && - cachableSession.mPeerCATs.values[0] == mCASESessionCachableArray[index].mPeerCATs.values[0] && - ((ResumptionID(cachableSession.mResumptionId)) - .data_equal(ResumptionID(mCASESessionCachableArray[index].mResumptionId))) && - (cachableSession.mLocalFabricIndex == mCASESessionCachableArray[index].mLocalFabricIndex) && - (cachableSession.mSessionSetupTimeStamp == mCASESessionCachableArray[index].mSessionSetupTimeStamp); - } - - void InitializeCASESessionCachableArray() - { - for (size_t j = 0; j < kCASEResumptionIDSize; j++) - { - sTest_ResumptionId[j] = 0x01; - } - for (uint8_t i = 0; i < CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE; i++) - { - createCASESessionTestCachable(i); - } - } - - CASESessionCachable mCASESessionCachableArray[CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE + 1] = { { 0 } }; - CASESessionCache mCASESessionCache; -}; - -CASESessionTest mCASESessionTest; - -static void CASESessionCache_Create_Test(nlTestSuite * inSuite, void * inContext) -{ - mCASESessionTest.InitializeCASESessionCachableArray(); -} - -static void CASESessionCache_Add_Test(nlTestSuite * inSuite, void * inContext) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - for (uint8_t i = 0; i < CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE; i++) - { - CASESession session; - err = mCASESessionTest.mCASESessionCache.Add(mCASESessionTest.mCASESessionCachableArray[i]); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - } -} - -static void CASESessionCache_Get_Test(nlTestSuite * inSuite, void * inContext) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - for (uint8_t i = 0; i < CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE; i++) - { - CASESessionCachable outCachableSession; - err = mCASESessionTest.mCASESessionCache.Get(ResumptionID(mCASESessionTest.mCASESessionCachableArray[i].mResumptionId), - outCachableSession); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, true == mCASESessionTest.isEqual(i, outCachableSession)); - } -} - -static void CASESessionCache_Add_When_Full_Test(nlTestSuite * inSuite, void * inContext) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - mCASESessionTest.createCASESessionTestCachable(CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE); - err = mCASESessionTest.mCASESessionCache.Add( - mCASESessionTest.mCASESessionCachableArray[CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE]); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - // Check if the entry with lowest timestamp has been removed - CASESessionCachable outCachableSession; - err = mCASESessionTest.mCASESessionCache.Get(ResumptionID(mCASESessionTest.mCASESessionCachableArray[0].mResumptionId), - outCachableSession); - NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND); - - // Check if the new entry has been added. - err = mCASESessionTest.mCASESessionCache.Get( - ResumptionID(mCASESessionTest.mCASESessionCachableArray[CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE].mResumptionId), - outCachableSession); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, true == mCASESessionTest.isEqual(CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE, outCachableSession)); -} - -static void CASESessionCache_Remove_Test(nlTestSuite * inSuite, void * inContext) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - for (uint8_t i = 1; i < CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE + 1; i++) - { - err = mCASESessionTest.mCASESessionCache.Remove(ResumptionID(mCASESessionTest.mCASESessionCachableArray[i].mResumptionId)); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - CASESessionCachable outCachableSession; - err = mCASESessionTest.mCASESessionCache.Get(ResumptionID(mCASESessionTest.mCASESessionCachableArray[i].mResumptionId), - outCachableSession); - NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND); - } -} - -// Test Suite - -/** - * Test Suite that lists all the test functions. - */ -// clang-format off -static const nlTest sTests[] = -{ - NL_TEST_DEF("Create", CASESessionCache_Create_Test), - NL_TEST_DEF("Add", CASESessionCache_Add_Test), - NL_TEST_DEF("Get", CASESessionCache_Get_Test), - NL_TEST_DEF("AddWhenFull", CASESessionCache_Add_When_Full_Test), - NL_TEST_DEF("Remove", CASESessionCache_Remove_Test), - - NL_TEST_SENTINEL() -}; -// clang-format on - -int CASESessionCache_Test_Setup(void * inContext); -int CASESessionCache_Test_Teardown(void * inContext); - -// clang-format off -static nlTestSuite sSuite = -{ - "Test-CHIP-SecurePairing-CASECache", - &sTests[0], - CASESessionCache_Test_Setup, - CASESessionCache_Test_Teardown, -}; -// clang-format on - -namespace { -/* - * Set up the test suite. - */ -CHIP_ERROR CASETestCacheSetup(void * inContext) -{ - ReturnErrorOnFailure(chip::Platform::MemoryInit()); - return CHIP_NO_ERROR; -} -} // anonymous namespace - -/** - * Set up the test suite. - */ -int CASESessionCache_Test_Setup(void * inContext) -{ - return CASETestCacheSetup(inContext) == CHIP_NO_ERROR ? SUCCESS : FAILURE; -} - -/** - * Tear down the test suite. - */ -int CASESessionCache_Test_Teardown(void * inContext) -{ - chip::Platform::MemoryShutdown(); - return SUCCESS; -} - -/** - * Main - */ -int TestCASESessionCache() -{ - // Run test suit against one context - nlTestRunner(&sSuite, nullptr); - - return (nlTestRunnerStats(&sSuite)); -} - -CHIP_REGISTER_TEST_SUITE(TestCASESessionCache) diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 838b24804cdf63..4cc6ffc1cf5a7a 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -123,9 +123,6 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) SessionManager sessionManager; NL_TEST_ASSERT(inSuite, pairing.GetSecureSessionType() == SecureSession::Type::kPASE); - CATValues peerCATs; - peerCATs = pairing.GetPeerCATs(); - NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kUndefinedCATs, sizeof(CATValues)) == 0); gLoopback.Reset(); diff --git a/src/protocols/secure_channel/tests/TestSimpleSessionResumptionStorage.cpp b/src/protocols/secure_channel/tests/TestSimpleSessionResumptionStorage.cpp new file mode 100644 index 00000000000000..f967308e73692c --- /dev/null +++ b/src/protocols/secure_channel/tests/TestSimpleSessionResumptionStorage.cpp @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +constexpr chip::FabricIndex fabric1 = 10; +constexpr chip::NodeId node1 = 12344321; +constexpr chip::FabricIndex fabric2 = 14; +constexpr chip::NodeId node2 = 11223344; + +void TestLink(nlTestSuite * inSuite, void * inContext) +{ + chip::TestPersistentStorageDelegate storage; + chip::SimpleSessionResumptionStorage sessionStorage; + sessionStorage.Init(&storage); + + chip::SimpleSessionResumptionStorage::ResumptionIdStorage resumptionId; + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == chip::Crypto::DRBG_get_bytes(resumptionId.data(), resumptionId.size())); + + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == sessionStorage.SaveLink(resumptionId, chip::ScopedNodeId(node1, fabric1))); + + chip::ScopedNodeId node; + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == sessionStorage.LoadLink(resumptionId, node)); + NL_TEST_ASSERT(inSuite, node == chip::ScopedNodeId(node1, fabric1)); + + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == sessionStorage.DeleteLink(resumptionId)); + + NL_TEST_ASSERT(inSuite, CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND == sessionStorage.LoadLink(resumptionId, node)); +} + +void TestState(nlTestSuite * inSuite, void * inContext) +{ + chip::TestPersistentStorageDelegate storage; + chip::SimpleSessionResumptionStorage sessionStorage; + sessionStorage.Init(&storage); + + chip::ScopedNodeId node(node1, fabric1); + + chip::SimpleSessionResumptionStorage::ResumptionIdStorage resumptionId; + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == chip::Crypto::DRBG_get_bytes(resumptionId.data(), resumptionId.size())); + + chip::Crypto::P256ECDHDerivedSecret sharedSecret; + sharedSecret.SetLength(sharedSecret.Capacity()); + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == chip::Crypto::DRBG_get_bytes(sharedSecret.Bytes(), sharedSecret.Length())); + + chip::CATValues peerCATs; + + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == sessionStorage.SaveState(node, resumptionId, sharedSecret, peerCATs)); + + chip::SimpleSessionResumptionStorage::ResumptionIdStorage resumptionId2; + chip::Crypto::P256ECDHDerivedSecret sharedSecret2; + chip::CATValues peerCATs2; + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == sessionStorage.LoadState(node, resumptionId2, sharedSecret2, peerCATs2)); + NL_TEST_ASSERT(inSuite, resumptionId == resumptionId2); + NL_TEST_ASSERT(inSuite, memcmp(sharedSecret.Bytes(), sharedSecret2.Bytes(), sharedSecret.Length()) == 0); + + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == sessionStorage.DeleteState(node)); + + NL_TEST_ASSERT(inSuite, + CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND == + sessionStorage.LoadState(node, resumptionId2, sharedSecret2, peerCATs2)); +} + +void TestIndex(nlTestSuite * inSuite, void * inContext) +{ + chip::TestPersistentStorageDelegate storage; + chip::SimpleSessionResumptionStorage sessionStorage; + sessionStorage.Init(&storage); + + chip::ScopedNodeId node(node1, fabric1); + + chip::SessionResumptionStorage::SessionIndex index1; + index1.mSize = 0; + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == sessionStorage.SaveIndex(index1)); + chip::SessionResumptionStorage::SessionIndex index1o; + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == sessionStorage.LoadIndex(index1o)); + NL_TEST_ASSERT(inSuite, index1o.mSize == 0); + + chip::SessionResumptionStorage::SessionIndex index2; + index2.mSize = 2; + index2.mNodes[0] = chip::ScopedNodeId(node1, fabric1); + index2.mNodes[1] = chip::ScopedNodeId(node2, fabric2); + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == sessionStorage.SaveIndex(index2)); + chip::SessionResumptionStorage::SessionIndex index2o; + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == sessionStorage.LoadIndex(index2o)); + NL_TEST_ASSERT(inSuite, index2o.mSize == 2); + NL_TEST_ASSERT(inSuite, index2o.mNodes[0] == chip::ScopedNodeId(node1, fabric1)); + NL_TEST_ASSERT(inSuite, index2o.mNodes[1] == chip::ScopedNodeId(node2, fabric2)); +} + +// Test Suite + +/** + * Test Suite that lists all the test functions. + */ +// clang-format off +static const nlTest sTests[] = +{ + NL_TEST_DEF("TestLink", TestLink), + NL_TEST_DEF("TestState", TestState), + + NL_TEST_SENTINEL() +}; +// clang-format on + +// clang-format off +static nlTestSuite sSuite = +{ + "Test-CHIP-SimpleSessionResumptionStorage", + &sTests[0], + nullptr, + nullptr, +}; +// clang-format on + +/** + * Main + */ +int TestSimpleSessionResumptionStorage() +{ + // Run test suit against one context + nlTestRunner(&sSuite, nullptr); + + return (nlTestRunnerStats(&sSuite)); +} + +CHIP_REGISTER_TEST_SUITE(TestSimpleSessionResumptionStorage) diff --git a/src/transport/PairingSession.h b/src/transport/PairingSession.h index 2b256857850248..fee363f0ec714b 100644 --- a/src/transport/PairingSession.h +++ b/src/transport/PairingSession.h @@ -124,7 +124,6 @@ class DLL_EXPORT PairingSession CHIP_ERROR AllocateSecureSession(SessionManager & sessionManager, uint16_t sessionId); void SetPeerNodeId(NodeId peerNodeId) { mPeerNodeId = peerNodeId; } - void SetPeerCATs(CATValues peerCATs) { mPeerCATs = peerCATs; } void SetPeerSessionId(uint16_t id) { mPeerSessionId.SetValue(id); } void SetPeerAddress(const Transport::PeerAddress & address) { mPeerAddress = address; } virtual void OnSuccessStatusReport() {} @@ -203,9 +202,12 @@ class DLL_EXPORT PairingSession private: const Transport::SecureSession::Type mSecureSessionType; + +protected: NodeId mPeerNodeId = kUndefinedNodeId; CATValues mPeerCATs; +private: SessionHolder mSecureSessionHolder; // TODO: decouple peer address into transport, such that pairing session do not need to handle peer address