From 1609788d68f8508c5c98184a05033a490f038562 Mon Sep 17 00:00:00 2001 From: Sektor van Skijlen Date: Fri, 16 Feb 2024 14:41:34 +0100 Subject: [PATCH] [core] Fixed the PacketFilter configuration not counting the AEAD AUTH tag (#2880). * Fixed build break due to const use of UniquePtr. --- srtcore/core.cpp | 17 ++++++++++++++++- srtcore/core.h | 1 + srtcore/fec.cpp | 3 +++ srtcore/packetfilter.cpp | 8 +++++++- srtcore/utilities.h | 2 +- 5 files changed, 28 insertions(+), 3 deletions(-) diff --git a/srtcore/core.cpp b/srtcore/core.cpp index 469d8af08..8c81156f8 100644 --- a/srtcore/core.cpp +++ b/srtcore/core.cpp @@ -5609,6 +5609,14 @@ bool srt::CUDT::prepareConnectionObjects(const CHandShake &hs, HandshakeSide hsd return true; } +int srt::CUDT::getAuthTagSize() const +{ + if (m_pCryptoControl && m_pCryptoControl->getCryptoMode() == CSrtConfig::CIPHER_MODE_AES_GCM) + return HAICRYPT_AUTHTAG_MAX; + + return 0; +} + bool srt::CUDT::prepareBuffers(CUDTException* eout) { if (m_pSndBuffer) @@ -5620,7 +5628,14 @@ bool srt::CUDT::prepareBuffers(CUDTException* eout) try { // CryptoControl has to be initialized and in case of RESPONDER the KM REQ must be processed (interpretSrtHandshake(..)) for the crypto mode to be deduced. - const int authtag = (m_pCryptoControl && m_pCryptoControl->getCryptoMode() == CSrtConfig::CIPHER_MODE_AES_GCM) ? HAICRYPT_AUTHTAG_MAX : 0; + const int authtag = getAuthTagSize(); + + SRT_ASSERT(m_iMaxSRTPayloadSize != 0); + + HLOGC(rslog.Debug, log << CONID() << "Creating buffers: snd-plsize=" << m_iMaxSRTPayloadSize + << " snd-bufsize=" << 32 + << " authtag=" << authtag); + m_pSndBuffer = new CSndBuffer(32, m_iMaxSRTPayloadSize, authtag); SRT_ASSERT(m_iPeerISN != -1); m_pRcvBuffer = new srt::CRcvBuffer(m_iPeerISN, m_config.iRcvBufSize, m_pRcvQueue->m_pUnitQueue, m_config.bMessageAPI); diff --git a/srtcore/core.h b/srtcore/core.h index ef213be60..a6b902701 100644 --- a/srtcore/core.h +++ b/srtcore/core.h @@ -494,6 +494,7 @@ class CUDT /// Allocates sender and receiver buffers and loss lists. SRT_ATR_NODISCARD SRT_ATTR_REQUIRES(m_ConnectionLock) bool prepareBuffers(CUDTException* eout); + int getAuthTagSize() const; SRT_ATR_NODISCARD SRT_ATTR_REQUIRES(m_ConnectionLock) EConnectStatus postConnect(const CPacket* response, bool rendezvous, CUDTException* eout) ATR_NOEXCEPT; diff --git a/srtcore/fec.cpp b/srtcore/fec.cpp index a41e3a33b..76a47384a 100644 --- a/srtcore/fec.cpp +++ b/srtcore/fec.cpp @@ -598,6 +598,9 @@ void FECFilterBuiltin::ClipData(Group& g, uint16_t length_net, uint8_t kflg, g.flag_clip = g.flag_clip ^ kflg; g.timestamp_clip = g.timestamp_clip ^ timestamp_hw; + HLOGC(pflog.Debug, log << "FEC CLIP: data pkt.size=" << payload_size + << " to a clip buffer size=" << payloadSize()); + // Payload goes "as is". for (size_t i = 0; i < payload_size; ++i) { diff --git a/srtcore/packetfilter.cpp b/srtcore/packetfilter.cpp index 37785f43a..9610f04e3 100644 --- a/srtcore/packetfilter.cpp +++ b/srtcore/packetfilter.cpp @@ -314,9 +314,15 @@ bool srt::PacketFilter::configure(CUDT* parent, CUnitQueue* uq, const std::strin init.socket_id = parent->socketID(); init.snd_isn = parent->sndSeqNo(); init.rcv_isn = parent->rcvSeqNo(); - init.payload_size = parent->OPT_PayloadSize(); + + // XXX This is a formula for a full "SRT payload" part that undergoes transmission, + // might be nice to have this formula as something more general. + init.payload_size = parent->OPT_PayloadSize() + parent->getAuthTagSize(); init.rcvbuf_size = parent->m_config.iRcvBufSize; + HLOGC(pflog.Debug, log << "PFILTER: @" << init.socket_id << " payload size=" + << init.payload_size << " rcvbuf size=" << init.rcvbuf_size); + // Found a filter, so call the creation function m_filter = selector->second->Create(init, m_provided, confstr); if (!m_filter) diff --git a/srtcore/utilities.h b/srtcore/utilities.h index 31e05b205..82cae300f 100644 --- a/srtcore/utilities.h +++ b/srtcore/utilities.h @@ -654,7 +654,7 @@ class UniquePtr: public std::auto_ptr bool operator==(const element_type* two) const { return get() == two; } bool operator!=(const element_type* two) const { return get() != two; } - operator bool () { return 0!= get(); } + operator bool () const { return 0!= get(); } }; // A primitive one-argument versions of Sprint and Printable