diff --git a/srtcore/group.cpp b/srtcore/group.cpp index d4598d7c1..ee28029ae 100644 --- a/srtcore/group.cpp +++ b/srtcore/group.cpp @@ -758,12 +758,15 @@ void CUDTGroup::getOpt(SRT_SOCKOPT optname, void* pw_optval, int& w_optlen) w_optlen = sizeof(uint32_t); return; + case SRTO_KMSTATE: + *(uint32_t*)pw_optval = getGroupEncryptionState(); + w_optlen = sizeof(uint32_t); + return; + // Write-only options for security reasons or // options that refer to a socket state, that // makes no sense for a group. case SRTO_PASSPHRASE: - case SRTO_KMSTATE: - case SRTO_PBKEYLEN: case SRTO_KMPREANNOUNCE: case SRTO_KMREFRESHRATE: case SRTO_BINDTODEVICE: @@ -775,6 +778,19 @@ void CUDTGroup::getOpt(SRT_SOCKOPT optname, void* pw_optval, int& w_optlen) default:; // pass on } + bool is_set_on_socket = false; + { + ScopedLock cg(m_GroupLock); + gli_t gi = m_Group.begin(); + if (gi != m_Group.end()) + { + // Return the value from the first member socket, if any is present + // Note: Will throw exception if the request is wrong. + gi->ps->core().getOpt(optname, (pw_optval), (w_optlen)); + is_set_on_socket = true; + } + } + // Check if the option is in the storage, which means that // it was modified on the group. @@ -783,12 +799,18 @@ void CUDTGroup::getOpt(SRT_SOCKOPT optname, void* pw_optval, int& w_optlen) if (i == m_config.end()) { + // Already written to the target variable. + if (is_set_on_socket) + return; + // Not found, see the defaults if (!getOptDefault(optname, (pw_optval), (w_optlen))) throw CUDTException(MJ_NOTSUP, MN_INVAL, 0); return; } + // NOTE: even if is_set_on_socket, if it was also found in the group + // settings, overwrite with the value from the group. // Found, return the value from the storage. // Check the size first. @@ -799,6 +821,52 @@ void CUDTGroup::getOpt(SRT_SOCKOPT optname, void* pw_optval, int& w_optlen) memcpy((pw_optval), &i->value[0], i->value.size()); } +SRT_KM_STATE CUDTGroup::getGroupEncryptionState() +{ + multiset kmstates; + { + ScopedLock lk (m_GroupLock); + + // First check the container. If empty, return UNSECURED + if (m_Group.empty()) + return SRT_KM_S_UNSECURED; + + for (gli_t gi = m_Group.begin(); gi != m_Group.end(); ++gi) + { + CCryptoControl* cc = gi->ps->core().m_pCryptoControl.get(); + if (!cc) + continue; + SRT_KM_STATE gst = cc->m_RcvKmState; + // A fix to NOSECRET is because this is the state when agent has set + // no password, but peer did, and ENFORCEDENCRYPTION=false allowed + // this connection to be established. UNSECURED can't be taken in this + // case because this would suggest that BOTH are unsecured, that is, + // we have established an unsecured connection (which ain't true). + if (gst == SRT_KM_S_UNSECURED && cc->m_SndKmState == SRT_KM_S_NOSECRET) + gst = SRT_KM_S_NOSECRET; + kmstates.insert(gst); + } + } + + // Criteria are: + // 1. UNSECURED, if no member sockets, or at least one UNSECURED found. + // 2. SECURED, if at least one SECURED found (cut off the previous criteria). + // 3. BADSECRET otherwise, although return NOSECRET if no BADSECRET is found. + + if (kmstates.count(SRT_KM_S_UNSECURED)) + return SRT_KM_S_UNSECURED; + + // Now we have UNSECURED ruled out. Remaining may be NOSECRET, BADSECRET or SECURED. + // NOTE: SECURING is an intermediate state for HSv4 and can't occur in groups. + if (kmstates.count(SRT_KM_S_SECURED)) + return SRT_KM_S_SECURED; + + if (kmstates.count(SRT_KM_S_BADSECRET)) + return SRT_KM_S_BADSECRET; + + return SRT_KM_S_NOSECRET; +} + SRT_SOCKSTATUS CUDTGroup::getStatus() { typedef vector > states_t; diff --git a/srtcore/group.h b/srtcore/group.h index 247423da3..56f1456f8 100644 --- a/srtcore/group.h +++ b/srtcore/group.h @@ -313,6 +313,8 @@ class CUDTGroup void send_CheckValidSockets(); + SRT_KM_STATE getGroupEncryptionState(); + public: int recv(char* buf, int len, SRT_MSGCTRL& w_mc); diff --git a/test/test_bonding.cpp b/test/test_bonding.cpp index 241d26d99..c5e362e69 100644 --- a/test/test_bonding.cpp +++ b/test/test_bonding.cpp @@ -349,6 +349,14 @@ TEST(Bonding, Options) string pass = "longenoughpassword"; // passphrase should be ok. EXPECT_NE(srt_setsockflag(grp, SRTO_PASSPHRASE, pass.c_str(), pass.size()), SRT_ERROR); + + uint32_t val = 16; + EXPECT_NE(srt_setsockflag(grp, SRTO_PBKEYLEN, &val, sizeof val), SRT_ERROR); + +#ifdef ENABLE_AEAD_API_PREVIEW + val = 1; + EXPECT_NE(srt_setsockflag(grp, SRTO_CRYPTOMODE, &val, sizeof val), SRT_ERROR); +#endif #endif int lat = 500; @@ -446,6 +454,25 @@ TEST(Bonding, Options) EXPECT_EQ(optsize, sizeof ohead); EXPECT_EQ(ohead, 12); +#if SRT_ENABLE_ENCRYPTION + + uint32_t kms = -1; + + EXPECT_NE(srt_getsockflag(grp, SRTO_KMSTATE, &kms, &optsize), SRT_ERROR); + EXPECT_EQ(optsize, sizeof kms); + EXPECT_EQ(kms, int(SRT_KM_S_SECURED)); + + EXPECT_NE(srt_getsockflag(grp, SRTO_PBKEYLEN, &kms, &optsize), SRT_ERROR); + EXPECT_EQ(optsize, sizeof kms); + EXPECT_EQ(kms, 16); + +#ifdef ENABLE_AEAD_API_PREVIEW + EXPECT_NE(srt_getsockflag(grp, SRTO_CRYPTOMODE, &kms, &optsize), SRT_ERROR); + EXPECT_EQ(optsize, sizeof kms); + EXPECT_EQ(kms, 1); +#endif +#endif + // We're done, the thread can close connection and exit { // Make sure that the thread reached the wait() call. diff --git a/test/test_crypto.cpp b/test/test_crypto.cpp index f4fa7f614..47b18dd1a 100644 --- a/test/test_crypto.cpp +++ b/test/test_crypto.cpp @@ -44,7 +44,7 @@ namespace srt m_crypt.setCryptoKeylen(cfg.iSndCryptoKeyLen); cfg.iCryptoMode = CSrtConfig::CIPHER_MODE_AES_GCM; - EXPECT_EQ(m_crypt.init(HSD_INITIATOR, cfg, true), HaiCrypt_IsAESGCM_Supported() != 0); + EXPECT_TRUE(m_crypt.init(HSD_INITIATOR, cfg, true, HaiCrypt_IsAESGCM_Supported())); const unsigned char* kmmsg = m_crypt.getKmMsg_data(0); const size_t km_len = m_crypt.getKmMsg_size(0); @@ -53,7 +53,7 @@ namespace srt std::array km_nworder; NtoHLA(km_nworder.data(), reinterpret_cast(kmmsg), km_len); - m_crypt.processSrtMsg_KMREQ(km_nworder.data(), km_len, 5, kmout, kmout_len); + m_crypt.processSrtMsg_KMREQ(km_nworder.data(), km_len, 5, SrtVersion(1, 5, 3), kmout, kmout_len); } void TearDown() override