Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a registry of groups to protect access from a callback #95

Merged
merged 2 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions xtransmit/scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class scheduler
thread_.join();
}

void stop()
{
done_ = true;
}

template <typename Callable, typename... Args>
void schedule_on(const steady_clock::time_point time, Callable&& f, Args&&... args)
{
Expand All @@ -76,7 +81,6 @@ class scheduler
} sync_;

multimap<steady_clock::time_point, shared_ptr<task>> tasks_;
mutex lock_;
thread thread_;

void timer_loop()
Expand All @@ -102,14 +106,14 @@ class scheduler

void add_task(const steady_clock::time_point time, shared_ptr<task> t)
{
lock_guard<mutex> l(lock_);
lock_guard<mutex> l(sync_.mtx);
tasks_.emplace(time, move(t));
sync_.cv.notify_one();
}

void manage_tasks()
{
lock_guard<mutex> l(lock_);
lock_guard<mutex> l(sync_.mtx);

auto end_of_tasks_to_run = tasks_.upper_bound(steady_clock::now());
if (end_of_tasks_to_run != tasks_.begin())
Expand Down
78 changes: 71 additions & 7 deletions xtransmit/srt_socket_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,45 @@ namespace srt_logging
std::string SockStatusStr(SRT_SOCKSTATUS);
}

namespace xtransmit
{
namespace details
{
class group_registry
{
public:
void add(intptr_t p)
{
std::lock_guard<std::mutex> lck(m_mtx);
m_groups.emplace(p);
}

void remove(intptr_t p)
{
std::lock_guard<std::mutex> lck(m_mtx);
m_groups.erase(p);
}

class not_found : public std::runtime_error { public: not_found(const char* m) : std::runtime_error(m) {} };

std::unique_lock<std::mutex> scoped_lock(intptr_t p) const
{
std::unique_lock<std::mutex> lck(m_mtx);
if (!m_groups.count(p))
throw not_found("");
return lck; // Compiler will perform an RVO or move.
}

private:
mutable std::mutex m_mtx;
std::set<intptr_t> m_groups;
};


static group_registry g_group_registry;
}
}

#define LOG_SRT_GROUP "SOCKET::SRT_GROUP "

SocketOption::Mode detect_srt_mode(const UriParser& uri)
Expand Down Expand Up @@ -185,6 +224,8 @@ socket::srt_group::srt_group(const vector<UriParser>& uris)
spdlog::trace(LOG_SRT_GROUP "Creating a group of callers (type {}).", gtype_str);
create_callers(uris, gtype);
}

details::g_group_registry.add((intptr_t) this);
}

socket::srt_group::srt_group(srt_group& group, int group_id)
Expand All @@ -202,10 +243,13 @@ socket::srt_group::srt_group(srt_group& group, int group_id)
if (SRT_ERROR == srt_epoll_add_usock(m_epoll_io, m_bind_socket, &io_modes))
throw socket::exception(srt_getlasterror_str());
}

details::g_group_registry.add((intptr_t)this);
}

socket::srt_group::~srt_group()
{
m_scheduler.stop();
if (!m_blocking_mode)
{
spdlog::debug(LOG_SRT_GROUP "@{} Closing. Releasing epolls", m_bind_socket);
Expand All @@ -214,7 +258,9 @@ socket::srt_group::~srt_group()
if (m_epoll_io != -1)
srt_epoll_release(m_epoll_io);
}

spdlog::debug(LOG_SRT_GROUP "@{} Closing SRT group", m_bind_socket);
details::g_group_registry.remove((intptr_t)this);
release_targets();
release_listeners();
srt_close(m_bind_socket);
Expand Down Expand Up @@ -442,9 +488,20 @@ int socket::srt_group::listen_callback_fn(void* opaq, SRTSOCKET sock, int hsvers
netaddr_any host(host_sa.get(), host_sa_len);
spdlog::trace(LOG_SRT_GROUP "Accepted member socket @{}, host IP {}, remote IP {}", sock, host.str(), sa.str());

// TODO: this group may no longer exist. Use some global array to track valid groups.
socket::srt_group* group = reinterpret_cast<socket::srt_group*>(opaq);
return group->on_listen_callback(sock);

try
{
// The group passed via 'opaq' may no longer exist. The g_group_registry checks and holds the lifetime.
auto lck = details::g_group_registry.scoped_lock((intptr_t)opaq);
socket::srt_group* group = reinterpret_cast<socket::srt_group*>(opaq);
return group->on_listen_callback(sock);
}
catch (const details::group_registry::not_found&)
{
spdlog::warn(LOG_SRT_GROUP "listen_callback_fn: group has already been destructed.");
}

return 0;
}

void socket::srt_group::set_listen_callback()
Expand All @@ -464,10 +521,17 @@ void socket::srt_group::connect_callback_fn(void* opaq, SRTSOCKET sock, int erro
return;
}

// TODO: this group may no longer exist. Use some global array to track valid groups.
socket::srt_group* group = reinterpret_cast<socket::srt_group*>(opaq);

group->on_connect_callback(sock, error, peer, token);
try
{
// The group passed via 'opaq' may no longer exist. The g_group_registry checks and holds the lifetime.
auto lck = details::g_group_registry.scoped_lock((intptr_t)opaq);
socket::srt_group* group = reinterpret_cast<socket::srt_group*>(opaq);
return group->on_connect_callback(sock, error, peer, token);
}
catch (const details::group_registry::not_found&)
{
spdlog::warn(LOG_SRT_GROUP "connect_callback_fn: group has already been destructed.");
}
}

void socket::srt_group::on_connect_callback(SRTSOCKET sock, int error, const sockaddr* /*peer*/, int token)
Expand Down
Loading