diff --git a/srtcore/api.cpp b/srtcore/api.cpp index 56c581fec..3c2db3c45 100644 --- a/srtcore/api.cpp +++ b/srtcore/api.cpp @@ -541,6 +541,7 @@ int srt::CUDTUnited::newConnection(const SRTSOCKET listen, try { + ScopedLock col(ls->core().m_ConnectionLock); ns = new CUDTSocket(*ls); // No need to check the peer, this is the address from which the request has come. ns->m_PeerAddr = peer; diff --git a/srtcore/queue.cpp b/srtcore/queue.cpp index 98999a81f..5d0ee465b 100644 --- a/srtcore/queue.cpp +++ b/srtcore/queue.cpp @@ -1130,7 +1130,7 @@ srt::CRcvQueue::CRcvQueue() , m_szPayloadSize() , m_bClosing(false) , m_LSLock() - , m_pListener(NULL) + , m_pListener() , m_pRendezvousQueue(NULL) , m_vNewEntry() , m_IDLock() @@ -1405,10 +1405,11 @@ srt::EConnectStatus srt::CRcvQueue::worker_ProcessConnectionRequest(CUnit* unit, bool have_listener = false; { ScopedLock cg(m_LSLock); - if (m_pListener) + m_pListener.lockRead(); + if (m_pListener.udt) { - LOGC(cnlog.Debug, log << "PASSING request from: " << addr.str() << " to listener:" << m_pListener->socketID()); - listener_ret = m_pListener->processConnectRequest(addr, unit->m_Packet); + LOGC(cnlog.Debug, log << "PASSING request from: " << addr.str() << " to listener:" << m_pListener.udt->socketID()); + listener_ret = m_pListener.udt->processConnectRequest(addr, unit->m_Packet); // This function does return a code, but it's hard to say as to whether // anything can be done about it. In case when it's stated possible, the @@ -1418,6 +1419,7 @@ srt::EConnectStatus srt::CRcvQueue::worker_ProcessConnectionRequest(CUnit* unit, have_listener = true; } + m_pListener.unlockRead(); } // NOTE: Rendezvous sockets do bind(), but not listen(). It means that the socket is @@ -1690,21 +1692,28 @@ int srt::CRcvQueue::recvfrom(int32_t id, CPacket& w_packet) int srt::CRcvQueue::setListener(CUDT* u) { - ScopedLock lslock(m_LSLock); - - if (NULL != m_pListener) + m_pListener.lockWrite(); + if (NULL != m_pListener.udt) + { + m_pListener.unlockWrite(); return -1; + } + + m_pListener.udt = u; + m_pListener.unlockWrite(); - m_pListener = u; return 0; } void srt::CRcvQueue::removeListener(const CUDT* u) { - ScopedLock lslock(m_LSLock); - - if (u == m_pListener) - m_pListener = NULL; + //ScopedLock lslock(m_LSLock); + m_pListener.lockWrite(); + if (u == m_pListener.udt) + { + m_pListener.udt = NULL; + } + m_pListener.unlockWrite(); } void srt::CRcvQueue::registerConnector(const SRTSOCKET& id, diff --git a/srtcore/queue.h b/srtcore/queue.h index dd68a7721..addcd665b 100644 --- a/srtcore/queue.h +++ b/srtcore/queue.h @@ -67,6 +67,36 @@ namespace srt { class CChannel; class CUDT; +class CUDTWrapper; + +class CUDTWrapper { +public: + CUDT *udt; + sync::SharedMutex mut; + +public: + CUDTWrapper() + :udt(NULL) + ,mut() + { +} +void lockRead() +{ + return mut.lockRead(); +} +void lockWrite() +{ + return mut.lockWrite(); +} +void unlockRead() +{ + return mut.unlockRead(); + +} +void unlockWrite(){ + return mut.unlockWrite(); +} +}; struct CUnit { @@ -555,7 +585,7 @@ class CRcvQueue private: sync::Mutex m_LSLock; - CUDT* m_pListener; // pointer to the (unique, if any) listening UDT entity + CUDTWrapper m_pListener; // pointer to the (unique, if any) listening UDT entity CRendezvousQueue* m_pRendezvousQueue; // The list of sockets in rendezvous mode std::vector m_vNewEntry; // newly added entries, to be inserted diff --git a/srtcore/sync.h b/srtcore/sync.h index fb6d56432..4f21d7375 100644 --- a/srtcore/sync.h +++ b/srtcore/sync.h @@ -12,6 +12,7 @@ #define INC_SRT_SYNC_H #include "platform_sys.h" +#include #include #include @@ -943,9 +944,164 @@ CUDTException& GetThreadLocalError(); /// @param[in] maxVal maximum allowed value of the resulting random number. int genRandomInt(int minVal, int maxVal); +class SharedMutex +{ + private: + Condition m_pLockWriteCond; + Condition m_pLockReadCond; + + Mutex m_pMutex; + Mutex m_pMutex2; + + int m_pCountRead; + bool m_pWriterLocked; + + + public: + SharedMutex() + :m_pLockWriteCond() + ,m_pLockReadCond() + ,m_pMutex() + ,m_pMutex2() + ,m_pCountRead(0) + ,m_pWriterLocked(false) + { + m_pCountRead = 0; + m_pWriterLocked = false; + + } + + void lockWrite() + { + UniqueLock l1(m_pMutex); + if(m_pWriterLocked) + m_pLockWriteCond.wait(l1); + m_pWriterLocked = true; + if(m_pCountRead) + m_pLockReadCond.wait(l1); + + + } + + void unlockWrite() + { + UniqueLock l2(m_pMutex); + m_pWriterLocked = false; + l2.unlock(); + std::cout << "NOTIFY ALL" << std::endl; + m_pLockWriteCond.notify_all(); + std::cout << "WRITER NOTIFIED" << std::endl; + + } + + void lockRead() + { + std::cout << "TRY LOCK READ " << this->m_pCountRead << this->m_pWriterLocked << std::endl; + UniqueLock l3(m_pMutex); + if(m_pWriterLocked) + m_pLockWriteCond.wait(l3); + m_pCountRead++; + std::cout << "LOCKED READ" << std::endl; + } + + void unlockRead() + { + std::cout << "UNLOCK READ" << std::endl; + ScopedLock l4(m_pMutex); + m_pCountRead--; + if(m_pWriterLocked && m_pCountRead == 0) + m_pLockReadCond.notify_one(); + else if (m_pCountRead > 0) + m_pLockWriteCond.notify_one(); + std::cout << "READ UNLOCKED" << std::endl; + + + } + +}; + +/* REFERENCE IMPLEMENTATION +class shared_mutex +{ + Mutex mut_; + Condition gate1_; + Condition gate2_; + unsigned state_; + + static const unsigned write_entered_ = 1U << (sizeof(unsigned)*CHAR_BIT - 1); + static const unsigned n_readers_ = ~write_entered_; + +public: + + shared_mutex() : state_(0) {} + + +// Exclusive ownership + +void +lock() +{ + UniqueLock lk(mut_); + std::cout << "LOCK WRITE " << std::endl; + while (state_ & write_entered_) + gate1_.wait(lk); + state_ |= write_entered_; + while (state_ & n_readers_) + gate2_.wait(lk); + std::cout << "LOCK WRITE DONE" << std::endl; + +} + +void +unlock() +{ + { + ScopedLock _(mut_); + state_ = 0; + } + std::cout << "UNLOCK WRITE " << std::endl; + gate1_.notify_all(); + std::cout << "UNLOCK WRITE DONE" << std::endl; + +} + +// Shared ownership + +void +lock_shared() +{ + UniqueLock lk(mut_); + while ((state_ & write_entered_) || (state_ & n_readers_) == n_readers_) + gate1_.wait(lk); + unsigned num_readers = (state_ & n_readers_) + 1; + state_ &= ~n_readers_; + state_ |= num_readers; +} + +void +unlock_shared() +{ + ScopedLock _(mut_); + unsigned num_readers = (state_ & n_readers_) - 1; + state_ &= ~n_readers_; + state_ |= num_readers; + if (state_ & write_entered_) + { + if (num_readers == 0) + gate2_.notify_one(); + } + else + { + if (num_readers == n_readers_ - 1) + gate1_.notify_one(); + } +} +};*/ + } // namespace sync } // namespace srt + #include "atomic_clock.h" #endif // INC_SRT_SYNC_H