Skip to content

Commit

Permalink
[core] Protect CUnit::m_iFlag from data race
Browse files Browse the repository at this point in the history
using an atomic.
Refactored common allocation code CUnitQueue::allocateEntry(..).
  • Loading branch information
maxsharabayko committed Jul 13, 2022
1 parent b5055db commit a51ec39
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 121 deletions.
6 changes: 3 additions & 3 deletions srtcore/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5651,9 +5651,9 @@ bool srt::CUDT::prepareConnectionObjects(const CHandShake &hs, HandshakeSide hsd
m_pSndBuffer = new CSndBuffer(32, m_iMaxSRTPayloadSize);
#if ENABLE_NEW_RCVBUFFER
SRT_ASSERT(m_iISN != -1);
m_pRcvBuffer = new srt::CRcvBufferNew(m_iISN, m_config.iRcvBufSize, &(m_pRcvQueue->m_UnitQueue), m_config.bMessageAPI);
m_pRcvBuffer = new srt::CRcvBufferNew(m_iISN, m_config.iRcvBufSize, m_pRcvQueue->m_pUnitQueue, m_config.bMessageAPI);
#else
m_pRcvBuffer = new CRcvBuffer(&(m_pRcvQueue->m_UnitQueue), m_config.iRcvBufSize);
m_pRcvBuffer = new CRcvBuffer(m_pRcvQueue->m_pUnitQueue, m_config.iRcvBufSize);
#endif
// after introducing lite ACK, the sndlosslist may not be cleared in time, so it requires twice space.
m_pSndLossList = new CSndLossList(m_iFlowWindowSize * 2);
Expand Down Expand Up @@ -5943,7 +5943,7 @@ SRT_REJECT_REASON srt::CUDT::setupCC()
{
// The filter configurer is build the way that allows to quit immediately
// exit by exception, but the exception is meant for the filter only.
status = m_PacketFilter.configure(this, &(m_pRcvQueue->m_UnitQueue), m_config.sPacketFilterConfig.str());
status = m_PacketFilter.configure(this, m_pRcvQueue->m_pUnitQueue, m_config.sPacketFilterConfig.str());
}
catch (CUDTException& )
{
Expand Down
124 changes: 49 additions & 75 deletions srtcore/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,22 @@ using namespace std;
using namespace srt::sync;
using namespace srt_logging;

srt::CUnitQueue::CUnitQueue()
: m_pQEntry(NULL)
, m_pCurrQueue(NULL)
, m_pLastQueue(NULL)
, m_iSize(0)
, m_iCount(0)
, m_iMSS()
srt::CUnitQueue::CUnitQueue(int initNumUnits, int mss)
: m_iNumTaken(0)
, m_iMSS(mss)
, m_iBlockSize(initNumUnits)
{
CQEntry* tempq = allocateEntry(m_iBlockSize, m_iMSS);

if (tempq == NULL)
throw CUDTException(MJ_SYSTEMRES, MN_MEMORY);

m_pQEntry = m_pCurrQueue = m_pLastQueue = tempq;
m_pQEntry->m_pNext = m_pQEntry;

m_pAvailUnit = m_pCurrQueue->m_pUnit;

m_iSize = m_iBlockSize;
}

srt::CUnitQueue::~CUnitQueue()
Expand All @@ -93,104 +101,70 @@ srt::CUnitQueue::~CUnitQueue()
}
}

int srt::CUnitQueue::init(int size, int mss)
srt::CUnitQueue::CQEntry* srt::CUnitQueue::allocateEntry(const int iNumUnits, const int mss)
{
CQEntry* tempq = NULL;
CUnit* tempu = NULL;
char* tempb = NULL;
CUnit* tempu = NULL;
char* tempb = NULL;

try
{
tempq = new CQEntry;
tempu = new CUnit[size];
tempb = new char[size * mss];
tempu = new CUnit[iNumUnits];
tempb = new char[iNumUnits * mss];
}
catch (...)
{
delete tempq;
delete[] tempu;
delete[] tempb;

return -1;
LOGC(rslog.Error, log << "CUnitQueue: failed to allocate " << iNumUnits << " units.");
return NULL;
}

for (int i = 0; i < size; ++i)
for (int i = 0; i < iNumUnits; ++i)
{
tempu[i].m_iFlag = CUnit::FREE;
tempu[i].m_iFlag = CUnit::FREE;
tempu[i].m_Packet.m_pcData = tempb + i * mss;
}

tempq->m_pUnit = tempu;
tempq->m_pBuffer = tempb;
tempq->m_iSize = size;
tempq->m_iSize = iNumUnits;

m_pQEntry = m_pCurrQueue = m_pLastQueue = tempq;
m_pQEntry->m_pNext = m_pQEntry;

m_pAvailUnit = m_pCurrQueue->m_pUnit;

m_iSize = size;
m_iMSS = mss;

return 0;
return tempq;
}

// XXX Lots of common code with CUnitQueue:init.
// Consider merging.
int srt::CUnitQueue::increase()
int srt::CUnitQueue::increase_()
{
if (double(m_iCount) / m_iSize < 0.9)
return -1;

CQEntry* tempq = NULL;
CUnit* tempu = NULL;
char* tempb = NULL;

// all queues have the same size
const int size = m_pQEntry->m_iSize;
const int numUnits = m_iBlockSize;
HLOGC(qrlog.Debug, log << "CUnitQueue::increase: Capacity" << capacity() << " + " << numUnits << " new units, " << m_iNumTaken << " in use.");

try
{
tempq = new CQEntry;
tempu = new CUnit[size];
tempb = new char[size * m_iMSS];
}
catch (...)
{
delete tempq;
delete[] tempu;
delete[] tempb;

LOGC(rslog.Error,
log << "CUnitQueue:increase: failed to allocate " << size << " new units."
<< " Current size=" << m_iSize);
CQEntry* tempq = allocateEntry(numUnits, m_iMSS);
if (tempq == NULL)
return -1;
}

for (int i = 0; i < size; ++i)
{
tempu[i].m_iFlag = CUnit::FREE;
tempu[i].m_Packet.m_pcData = tempb + i * m_iMSS;
}
tempq->m_pUnit = tempu;
tempq->m_pBuffer = tempb;
tempq->m_iSize = size;

m_pLastQueue->m_pNext = tempq;
m_pLastQueue = tempq;
m_pLastQueue->m_pNext = m_pQEntry;

m_iSize += size;
m_iSize += numUnits;

return 0;
}

srt::CUnit* srt::CUnitQueue::getNextAvailUnit()
{
if (m_iCount * 10 > m_iSize * 9)
increase();
const int iNumUnitsTotal = capacity();
if (m_iNumTaken * 10 > iNumUnitsTotal * 9) // 90% or more are in use.
increase_();

if (m_iCount >= m_iSize)
if (m_iNumTaken >= capacity())
{
LOGC(qrlog.Error, log << "CUnitQueue: No free units to take. Capacity" << capacity() << ".");
return NULL;
}

int units_checked = 0;
do
Expand All @@ -208,27 +182,25 @@ srt::CUnit* srt::CUnitQueue::getNextAvailUnit()
m_pAvailUnit = m_pCurrQueue->m_pUnit;
} while (units_checked < m_iSize);

increase();

return NULL;
}

void srt::CUnitQueue::makeUnitFree(CUnit* unit)
{
SRT_ASSERT(unit != NULL);
SRT_ASSERT(unit->m_iFlag != CUnit::FREE);
unit->m_iFlag = CUnit::FREE;
unit->m_iFlag.store(CUnit::FREE);

--m_iCount;
--m_iNumTaken;
}

void srt::CUnitQueue::makeUnitGood(CUnit* unit)
{
++m_iCount;
++m_iNumTaken;

SRT_ASSERT(unit != NULL);
SRT_ASSERT(unit->m_iFlag == CUnit::FREE);
unit->m_iFlag = CUnit::GOOD;
unit->m_iFlag.store(CUnit::GOOD);
}

srt::CSndUList::CSndUList(sync::CTimer* pTimer)
Expand Down Expand Up @@ -1110,7 +1082,7 @@ bool srt::CRendezvousQueue::qualifyToHandle(EReadStatus rst,
//
srt::CRcvQueue::CRcvQueue()
: m_WorkerThread()
, m_UnitQueue()
, m_pUnitQueue(NULL)
, m_pRcvUList(NULL)
, m_pHash(NULL)
, m_pChannel(NULL)
Expand Down Expand Up @@ -1140,6 +1112,7 @@ srt::CRcvQueue::~CRcvQueue()
}
releaseCond(m_BufferCond);

delete m_pUnitQueue;
delete m_pRcvUList;
delete m_pHash;
delete m_pRendezvousQueue;
Expand All @@ -1166,7 +1139,8 @@ void srt::CRcvQueue::init(int qsize, size_t payload, int version, int hsize, CCh
m_iIPversion = version;
m_szPayloadSize = payload;

m_UnitQueue.init(qsize, (int)payload);
SRT_ASSERT(m_pUnitQueue == NULL);
m_pUnitQueue = new CUnitQueue(qsize, (int)payload);

m_pHash = new CHash;
m_pHash->init(hsize);
Expand Down Expand Up @@ -1345,7 +1319,7 @@ srt::EReadStatus srt::CRcvQueue::worker_RetrieveUnit(int32_t& w_id, CUnit*& w_un
}
}
// find next available slot for incoming packet
w_unit = m_UnitQueue.getNextAvailUnit();
w_unit = m_pUnitQueue->getNextAvailUnit();
if (!w_unit)
{
// no space, skip this packet
Expand Down
73 changes: 40 additions & 33 deletions srtcore/queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,34 +78,29 @@ struct CUnit
PASSACK = 2,
DROPPED = 3
};
Flag m_iFlag; // 0: free, 1: occupied, 2: msg read but not freed (out-of-order), 3: msg dropped
// TODO: Transition to the new RcvBuffer allows to use bool here.

// TODO: The new RcvBuffer allows to use atomic_bool here.
sync::atomic<Flag> m_iFlag; // 0: free, 1: occupied, 2: msg read but not freed (out-of-order), 3: msg dropped
};

class CUnitQueue
{

public:
CUnitQueue();
/// @brief Construct a unit queue.
/// @param mss Initial number of units to allocate.
/// @param mss Maximum segment size meaning the size of each unit.
/// @throws CUDTException SRT_ENOBUF.
CUnitQueue(int initNumUnits, int mss);
~CUnitQueue();

public: // Storage size operations
/// Initialize the unit queue.
/// @param [in] size queue size
/// @param [in] mss maximum segment size
/// @return 0: success, -1: failure.
int init(int size, int mss);

/// Increase (double) the unit queue size.
/// @return 0: success, -1: failure.
int increase();

public:
int size() const { return m_iSize - m_iCount; }
int capacity() const { return m_iSize; }
int size() const { return m_iSize - m_iNumTaken; }

public: // Operations on units
/// find an available unit for incoming packet.
public:
/// @brief Find an available unit for incoming packet. Allocate new units if 90% or more are in use.
/// @note This function is not thread-safe. Currently only CRcvQueue::worker thread calls it, thus
/// it is not an issue. However, must be protected if used from several threads in the future.
/// @return Pointer to the available unit, NULL if not found.
CUnit* getNextAvailUnit();

Expand All @@ -121,16 +116,28 @@ class CUnitQueue
int m_iSize; // size of each queue

CQEntry* m_pNext;
} * m_pQEntry, // pointer to the first unit queue
*m_pCurrQueue, // pointer to the current available queue
*m_pLastQueue; // pointer to the last unit queue
};

CUnit* m_pAvailUnit; // recent available unit
/// Increase the unit queue size (by @a m_iBlockSize units).
/// Uses m_mtx to protect access and changes of the queue state.
/// @return 0: success, -1: failure.
int increase_();

int m_iSize; // total size of the unit queue, in number of packets
sync::atomic<int> m_iCount; // total number of valid (occupied) packets in the queue
/// @brief Allocated a CQEntry of iNumUnits with each unit of mss bytes.
/// @param iNumUnits a number of units to allocate
/// @param mss the size of each unit in bytes.
/// @return a pointer to a newly allocated entry on success, NULL otherwise.
static CQEntry* allocateEntry(const int iNumUnits, const int mss);

int m_iMSS; // unit buffer size
private:
CQEntry* m_pQEntry; // pointer to the first unit queue
CQEntry* m_pCurrQueue; // pointer to the current available queue
CQEntry* m_pLastQueue; // pointer to the last unit queue
CUnit* m_pAvailUnit; // recent available unit
int m_iSize; // total size of the unit queue, in number of packets
sync::atomic<int> m_iNumTaken; // total number of valid (occupied) packets in the queue
const int m_iMSS; // unit buffer size
const int m_iBlockSize; // Number of units in each CQEntry.

private:
CUnitQueue(const CUnitQueue&);
Expand Down Expand Up @@ -523,14 +530,14 @@ class CRcvQueue
EConnectStatus worker_ProcessAddressedPacket(int32_t id, CUnit* unit, const sockaddr_any& sa);

private:
CUnitQueue m_UnitQueue; // The received packet queue
CRcvUList* m_pRcvUList; // List of UDT instances that will read packets from the queue
CHash* m_pHash; // Hash table for UDT socket looking up
CChannel* m_pChannel; // UDP channel for receving packets
sync::CTimer* m_pTimer; // shared timer with the snd queue

int m_iIPversion; // IP version
size_t m_szPayloadSize; // packet payload size
CUnitQueue* m_pUnitQueue; // The received packet queue
CRcvUList* m_pRcvUList; // List of UDT instances that will read packets from the queue
CHash* m_pHash; // Hash table for UDT socket looking up
CChannel* m_pChannel; // UDP channel for receving packets
sync::CTimer* m_pTimer; // shared timer with the snd queue

int m_iIPversion; // IP version
size_t m_szPayloadSize; // packet payload size

sync::atomic<bool> m_bClosing; // closing the worker
#if ENABLE_LOGGING
Expand Down
7 changes: 3 additions & 4 deletions test/test_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,16 @@ class CRcvBufferReadMsg
void SetUp() override
{
// make_unique is unfortunatelly C++14
m_unit_queue = unique_ptr<CUnitQueue>(new CUnitQueue);
m_unit_queue.reset(new CUnitQueue(m_buff_size_pkts, 1500));
ASSERT_NE(m_unit_queue.get(), nullptr);
m_unit_queue->init(m_buff_size_pkts, 1500);

#if ENABLE_NEW_RCVBUFFER
const bool enable_msg_api = m_use_message_api;
const bool enable_peer_rexmit = true;
m_rcv_buffer = unique_ptr<CRcvBufferNew>(new CRcvBufferNew(m_init_seqno, m_buff_size_pkts, m_unit_queue.get(), enable_msg_api));
m_rcv_buffer.reset(new CRcvBufferNew(m_init_seqno, m_buff_size_pkts, m_unit_queue.get(), enable_msg_api));
m_rcv_buffer->setPeerRexmitFlag(enable_peer_rexmit);
#else
m_rcv_buffer = unique_ptr<CRcvBuffer>(new CRcvBuffer(m_unit_queue.get(), m_buff_size_pkts));
m_rcv_buffer.reset(new CRcvBuffer(m_unit_queue.get(), m_buff_size_pkts));
#endif
ASSERT_NE(m_rcv_buffer.get(), nullptr);
}
Expand Down
Loading

0 comments on commit a51ec39

Please sign in to comment.