Skip to content

Commit

Permalink
Stop using winrt cancellation because it doesn't work, use std::stop_…
Browse files Browse the repository at this point in the history
…token and a jthread instead. This is massively simpler and seems to just work.
  • Loading branch information
leeter committed Nov 25, 2022
1 parent 09a6528 commit 706a380
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 38 deletions.
44 changes: 19 additions & 25 deletions WinMTRDialog.ixx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import <memory>;
import <mutex>;
import <optional>;
import <atomic>;
import <thread>;
import WinMTROptionsProvider;
import WinMTRStatusBar;
import WinMTR.Net;
Expand Down Expand Up @@ -138,12 +139,8 @@ private:
CButton m_buttonExpH;
std::wstring msz_defaulthostname;
std::shared_ptr<WinMTRNet> wmtrnet;
struct tracer_lacky {
winrt::Windows::Foundation::IAsyncAction tracer;
winrt::apartment_context context;
};
std::mutex tracer_mutex;
std::optional<tracer_lacky> trace_lacky;
std::optional<std::jthread> trace_lacky;
HICON m_hIcon;
double interval;
STATES state;
Expand All @@ -162,7 +159,7 @@ private:
std::atomic_bool tracing;

void ClearHistory();
winrt::Windows::Foundation::IAsyncAction pingThread(std::wstring shost);
winrt::Windows::Foundation::IAsyncAction pingThread(std::stop_token token, std::wstring shost);
winrt::fire_and_forget stopTrace();
public:

Expand Down Expand Up @@ -213,6 +210,7 @@ import <iterator>;
import <string_view>;
import <fstream>;
import <format>;
import <stop_token>;
import WinMTRIPUtils;
import WinMTRSNetHost;
import WinMTRDnsUtil;
Expand Down Expand Up @@ -1138,7 +1136,7 @@ void WinMTRDialog::ClearHistory()
m_comboHost.AddString(CString((LPCSTR)IDS_STRING_CLEAR_HISTORY));
}

winrt::Windows::Foundation::IAsyncAction WinMTRDialog::pingThread(std::wstring sHost)
winrt::Windows::Foundation::IAsyncAction WinMTRDialog::pingThread(std::stop_token stop_token, std::wstring sHost)
{
if (tracing.exchange(true)) {
throw new std::runtime_error("Tracing started twice!");
Expand All @@ -1152,9 +1150,6 @@ winrt::Windows::Foundation::IAsyncAction WinMTRDialog::pingThread(std::wstring s

SOCKADDR_STORAGE addrstore = {};


const auto cancellation = co_await winrt::get_cancellation_token();
cancellation.enable_propagation();
for (auto af : { AF_INET, AF_INET6 }) {
INT addrSize = sizeof(addrstore);
if (auto res = WSAStringToAddressW(
Expand All @@ -1164,7 +1159,7 @@ winrt::Windows::Foundation::IAsyncAction WinMTRDialog::pingThread(std::wstring s
, reinterpret_cast<LPSOCKADDR>(&addrstore)
, &addrSize);
!res) {
co_await this->wmtrnet->DoTrace(*reinterpret_cast<LPSOCKADDR>(&addrstore));
co_await this->wmtrnet->DoTrace(stop_token, *reinterpret_cast<LPSOCKADDR>(&addrstore));
co_return;
}
}
Expand All @@ -1182,21 +1177,24 @@ winrt::Windows::Foundation::IAsyncAction WinMTRDialog::pingThread(std::wstring s
co_return;
}
addrstore = result->front();
co_await this->wmtrnet->DoTrace(*reinterpret_cast<LPSOCKADDR>(&addrstore));
co_await this->wmtrnet->DoTrace(stop_token, *reinterpret_cast<LPSOCKADDR>(&addrstore));
}

winrt::fire_and_forget WinMTRDialog::stopTrace()
{
std::optional<tracer_lacky> temp;
// grab the thread under a mutex so we don't mess this up and cause a data race
decltype(trace_lacky) temp;
{
std::unique_lock lock(this->tracer_mutex);
std::swap(temp, this->trace_lacky);
std::unique_lock trace_lock{ tracer_mutex };
std::swap(temp, trace_lacky);
}
// don't bother trying call something not there
if (!temp) {
co_return;
}
co_await temp->context;
temp->tracer.Cancel();
co_await winrt::resume_background();
temp.reset(); //trigger the stop token
co_return;
}

void WinMTRDialog::OnCbnSelendokComboHost()
Expand Down Expand Up @@ -1296,13 +1294,14 @@ void WinMTRDialog::Transit(STATES new_state)
if (sHost.IsEmpty()) [[unlikely]] { // Technically never because this is caught in the calling function
sHost = L"localhost";
}
auto thread = std::thread([this](auto sHost) noexcept {
std::unique_lock trace_lock{ tracer_mutex };
// create the jthread and stop token all in one go
trace_lacky.emplace([this](std::stop_token stop_token, auto sHost) noexcept {
winrt::init_apartment(winrt::apartment_type::multi_threaded);
try {
auto tracer_local = this->pingThread(sHost);
auto tracer_local = this->pingThread(stop_token, sHost);
{
std::unique_lock lock(this->tracer_mutex);
this->trace_lacky.emplace(tracer_local, winrt::apartment_context());
}
// keep the thread alive
tracer_local.get();
Expand All @@ -1313,12 +1312,7 @@ void WinMTRDialog::Transit(STATES new_state)
catch (winrt::hresult_illegal_method_call const&){
// don't care this happens
}
{
std::unique_lock lock(this->tracer_mutex);
this->trace_lacky.reset();
}
}, std::wstring(sHost));
thread.detach();
}
m_buttonStart.EnableWindow(TRUE);
break;
Expand Down
28 changes: 15 additions & 13 deletions WinMTRNet.ixx
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ import <memory>;
import <vector>;
import <optional>;
import <winrt/Windows.Foundation.h>;

import <stop_token>;
import WinMTROptionsProvider;
import winmtr.helper;

Expand Down Expand Up @@ -149,7 +149,7 @@ public:


[[nodiscard("The task should be awaited")]]
winrt::Windows::Foundation::IAsyncAction DoTrace(sockaddr& address);
winrt::Windows::Foundation::IAsyncAction DoTrace(std::stop_token stop_token, sockaddr& address);

void ResetHops() noexcept
{
Expand Down Expand Up @@ -211,7 +211,7 @@ private:

template<class T>
[[nodiscard("The task should be awaited")]]
winrt::Windows::Foundation::IAsyncAction handleICMP(trace_thread current);
winrt::Windows::Foundation::IAsyncAction handleICMP(std::stop_token stop_token, trace_thread current);
};

module : private;
Expand Down Expand Up @@ -257,30 +257,28 @@ int WinMTRNet::GetMax() const
}

[[nodiscard("The task should be awaited")]]
winrt::Windows::Foundation::IAsyncAction WinMTRNet::DoTrace(sockaddr& address)
winrt::Windows::Foundation::IAsyncAction WinMTRNet::DoTrace(std::stop_token stop_token, sockaddr& address)
{
auto cancellation = co_await winrt::get_cancellation_token();
cancellation.enable_propagation();
tracing = true;
ResetHops();
last_remote_addr = {};
std::memcpy(&last_remote_addr, &address, getAddressSize(address));

auto threadMaker = [&address, this](UCHAR i) {
auto threadMaker = [&address, this, stop_token](UCHAR i) {
trace_thread current(address.sa_family, this, i + 1);
using namespace std::string_view_literals;
TRACE_MSG(L"Thread with TTL="sv << current.ttl << L" started."sv);
std::memcpy(&current.address, &address, getAddressSize(address));
if (current.address.ss_family == AF_INET) {
return this->handleICMP<sockaddr_in>(std::move(current));
return this->handleICMP<sockaddr_in>(stop_token, std::move(current));
}
else if (current.address.ss_family == AF_INET6) {
return this->handleICMP<sockaddr_in6>(std::move(current));
return this->handleICMP<sockaddr_in6>(stop_token, std::move(current));
}
winrt::throw_hresult(HRESULT_FROM_WIN32(WSAEOPNOTSUPP));
};

cancellation.callback([this]() noexcept {
std::stop_callback callback(stop_token, [this]() noexcept {
this->tracing = false;
TRACE_MSG(L"Cancellation");
});
Expand All @@ -293,7 +291,7 @@ winrt::Windows::Foundation::IAsyncAction WinMTRNet::DoTrace(sockaddr& address)

template<class T>
[[nodiscard("The task should be awaited")]]
winrt::Windows::Foundation::IAsyncAction WinMTRNet::handleICMP(trace_thread current) {
winrt::Windows::Foundation::IAsyncAction WinMTRNet::handleICMP(std::stop_token stop_token, trace_thread current) {
using namespace std::literals;
using traits = icmp_ping_traits<T>;
trace_thread mine = std::move(current);
Expand All @@ -303,11 +301,14 @@ winrt::Windows::Foundation::IAsyncAction WinMTRNet::handleICMP(trace_thread curr
std::vector<std::byte> achReqData(nDataLen, static_cast<std::byte>(32)); //whitespaces
std::vector<std::byte> achRepData(reply_reply_buffer_size<T>(nDataLen));

auto cancellation = co_await winrt::get_cancellation_token();

T* addr = reinterpret_cast<T*>(&mine.address);
while (this->tracing) {
if (cancellation()) [[unlikely]]

// this is a backup for if the atomic above doesn't work
if (stop_token.stop_requested()) [[unlikely]]
{
this->tracing = false;
co_return;
}
// For some strange reason, ICMP API is not filling the TTL for icmp echo reply
Expand Down Expand Up @@ -402,6 +403,7 @@ winrt::Windows::Foundation::IAsyncAction WinMTRNet::handleICMP(trace_thread curr
}

} /* end ping loop */
co_return;
}

void WinMTRNet::SetAddr(int at, sockaddr& addr)
Expand Down

0 comments on commit 706a380

Please sign in to comment.