diff --git a/.github/actions/spelling/allow/allow.txt b/.github/actions/spelling/allow/allow.txt index 7bff8f0cfa0..bccfe086aeb 100644 --- a/.github/actions/spelling/allow/allow.txt +++ b/.github/actions/spelling/allow/allow.txt @@ -1,3 +1,4 @@ +aci admins allcolors Apc @@ -8,6 +9,7 @@ breadcrumbs bsd calt ccmp +ccon changelog clickable clig @@ -91,6 +93,7 @@ reserialize reserializes rlig runtimes +servicebus shcha slnt Sos @@ -117,6 +120,7 @@ vsdevcmd walkthrough walkthroughs We'd +westus wildcards XBox YBox diff --git a/src/cascadia/TerminalConnection/AzureConnection.cpp b/src/cascadia/TerminalConnection/AzureConnection.cpp index 3844979b488..708518509b5 100644 --- a/src/cascadia/TerminalConnection/AzureConnection.cpp +++ b/src/cascadia/TerminalConnection/AzureConnection.cpp @@ -398,6 +398,7 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation switch (bufferType) { + case WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE: case WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE: case WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE: { @@ -797,7 +798,7 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation // - an optional HTTP method (defaults to POST if content is present, GET otherwise) // Return value: // - the response from the server as a json value - WDJ::JsonObject AzureConnection::_SendRequestReturningJson(std::wstring_view uri, const WWH::IHttpContent& content, WWH::HttpMethod method) + WDJ::JsonObject AzureConnection::_SendRequestReturningJson(std::wstring_view uri, const WWH::IHttpContent& content, WWH::HttpMethod method, const Windows::Foundation::Uri referer) { if (!method) { @@ -810,6 +811,11 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation auto headers{ request.Headers() }; headers.Accept().TryParseAdd(L"application/json"); + if (referer) + { + headers.Referer(referer); + } + const auto response{ _httpClient.SendRequestAsync(request).get() }; const auto string{ response.Content().ReadAsStringAsync().get() }; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; @@ -974,17 +980,56 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation auto uri{ fmt::format(L"{}terminals?cols={}&rows={}&version=2019-01-01&shell={}", _cloudShellUri, _initialCols, _initialRows, shellType) }; WWH::HttpStringContent content{ - L"", + L"{}", WSS::UnicodeEncoding::Utf8, // LOAD-BEARING. the API returns "'content-type' should be 'application/json' or 'multipart/form-data'" L"application/json" }; - const auto terminalResponse = _SendRequestReturningJson(uri, content); + const auto terminalResponse = _SendRequestReturningJson(uri, content, WWH::HttpMethod::Post(), Windows::Foundation::Uri(_cloudShellUri)); _terminalID = terminalResponse.GetNamedString(L"id"); + // we have to do some post-handling to get the proper socket endpoint + // the logic here is based on the way the cloud shell team itself does it + winrt::hstring finalSocketUri; + const std::wstring_view wCloudShellUri{ _cloudShellUri }; + + if (wCloudShellUri.find(L"servicebus") == std::wstring::npos) + { + // wCloudShellUri does not contain the word "servicebus", we can just use it to make the final URI + + // remove the "https" from the cloud shell URI + const auto uriWithoutProtocol = wCloudShellUri.substr(5); + + finalSocketUri = fmt::format(FMT_COMPILE(L"wss{}terminals/{}"), uriWithoutProtocol, _terminalID); + } + else + { + // if wCloudShellUri contains the word "servicebus", that means the returned socketUri is of the form + // wss://ccon-prod-westus-aci-03.servicebus.windows.net/cc-AAAA-AAAAAAAA//aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + // we need to change it to: + // wss://ccon-prod-westus-aci-03.servicebus.windows.net/$hc/cc-AAAA-AAAAAAAA/terminals/aaaaaaaaaaaaaaaaaaaaaa + + const auto socketUri = terminalResponse.GetNamedString(L"socketUri"); + const std::wstring_view wSocketUri{ socketUri }; + + // get the substring up until the ".net" + const auto dotNetStart = wSocketUri.find(L".net"); + THROW_HR_IF(E_UNEXPECTED, dotNetStart == std::wstring::npos); + const auto dotNetEnd = dotNetStart + 4; + const auto wSocketUriBody = wSocketUri.substr(0, dotNetEnd); + + // get the portion between the ".net" and the "//" (this is the cc-AAAA-AAAAAAAA part) + const auto lastDoubleSlashPos = wSocketUri.find_last_of(L"//"); + THROW_HR_IF(E_UNEXPECTED, lastDoubleSlashPos == std::wstring::npos); + const auto wSocketUriMiddle = wSocketUri.substr(dotNetEnd, lastDoubleSlashPos - (dotNetEnd)); + + // piece together the final uri, adding in the "$hc" and "terminals" where needed + finalSocketUri = fmt::format(FMT_COMPILE(L"{}/$hc{}terminals/{}"), wSocketUriBody, wSocketUriMiddle, _terminalID); + } + // Return the uri - return terminalResponse.GetNamedString(L"socketUri"); + return winrt::hstring{ finalSocketUri }; } // Method description: diff --git a/src/cascadia/TerminalConnection/AzureConnection.h b/src/cascadia/TerminalConnection/AzureConnection.h index cedd76757af..9d179073019 100644 --- a/src/cascadia/TerminalConnection/AzureConnection.h +++ b/src/cascadia/TerminalConnection/AzureConnection.h @@ -68,7 +68,7 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation void _WriteStringWithNewline(const std::wstring_view str); void _WriteCaughtExceptionRecord(); - winrt::Windows::Data::Json::JsonObject _SendRequestReturningJson(std::wstring_view uri, const winrt::Windows::Web::Http::IHttpContent& content = nullptr, winrt::Windows::Web::Http::HttpMethod method = nullptr); + winrt::Windows::Data::Json::JsonObject _SendRequestReturningJson(std::wstring_view uri, const winrt::Windows::Web::Http::IHttpContent& content = nullptr, winrt::Windows::Web::Http::HttpMethod method = nullptr, const winrt::Windows::Foundation::Uri referer = nullptr); void _setAccessToken(std::wstring_view accessToken); winrt::Windows::Data::Json::JsonObject _GetDeviceCode(); winrt::Windows::Data::Json::JsonObject _WaitForUser(const winrt::hstring& deviceCode, int pollInterval, int expiresIn);