-
Notifications
You must be signed in to change notification settings - Fork 8.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow OpenAI to be used with Terminal Chat #17540
base: feature/llm
Are you sure you want to change the base?
Changes from all commits
b44216b
e78e4d0
ac83d76
4b944e9
ef406ee
4fb4ca4
e5afbae
32b3d68
1700a92
08dd951
1934b30
96a489a
8e560e2
f9e9326
4133239
e545ca2
dfad8d9
5139b88
dfbdc75
1f305ab
242b964
0094db0
6479c47
d392aab
8faf8b4
dd6b46d
4f81775
46ab508
ac1b4a3
fc8e36d
fdee6c2
fff97a4
45ce94d
c89a306
de88842
edf08a9
830e655
16d3035
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
aci | ||
AIIs | ||
AILLM | ||
allcolors | ||
breadcrumb | ||
breadcrumbs | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
Check failure Code scanning / check-spelling Check File Path
[AILLM](#security-tab) is not a recognized word. \(check-file-path\)
|
||
// Licensed under the MIT license. | ||
|
||
#include "pch.h"
#include "OpenAILLMProvider.h"

#include <utility>

#include "../../types/inc/utils.hpp"
#include "LibraryResources.h"

#include "OpenAILLMProvider.g.cpp"
|
||
using namespace winrt::Windows::Foundation;
using namespace winrt::Windows::Foundation::Collections;
using namespace winrt::Windows::UI::Core;
using namespace winrt::Windows::UI::Xaml;
using namespace winrt::Windows::UI::Xaml::Controls;
using namespace winrt::Windows::System;
// Short aliases for the WinRT namespaces used throughout this file.
namespace WWH = ::winrt::Windows::Web::Http;
namespace WSS = ::winrt::Windows::Storage::Streams;
namespace WDJ = ::winrt::Windows::Data::Json;

// MIME type used for both the Accept header and the request body content type.
static constexpr std::wstring_view applicationJson{ L"application/json" };
// The model name sent in the "model" field of every chat-completions request.
static constexpr std::wstring_view acceptedModel{ L"gpt-3.5-turbo" };

// Hosted OpenAI chat-completions endpoint (distinct from the Azure OpenAI provider's endpoint).
static constexpr std::wstring_view openAIEndpoint{ L"https://api.openai.com/v1/chat/completions" };
|
||
namespace winrt::Microsoft::Terminal::Query::Extension::implementation | ||
{ | ||
// Stores the OpenAI API key from the supplied ValueSet and (re)creates the
// HTTP client used for all subsequent requests, pre-configured with a JSON
// Accept header and a "Bearer <key>" authorization header.
void OpenAILLMProvider::SetAuthentication(const Windows::Foundation::Collections::ValueSet& authValues)
{
    _AIKey = unbox_value_or<hstring>(authValues.TryLookup(L"key").try_as<IPropertyValue>(), L"");

    _httpClient = winrt::Windows::Web::Http::HttpClient{};
    const auto defaultHeaders{ _httpClient.DefaultRequestHeaders() };
    defaultHeaders.Accept().TryParseAdd(applicationJson);
    defaultHeaders.Authorization(WWH::Headers::HttpCredentialsHeaderValue{ L"Bearer", _AIKey });
}
|
||
void OpenAILLMProvider::ClearMessageHistory() | ||
{ | ||
_jsonMessages.Clear(); | ||
} | ||
|
||
// Records the "system" role message that primes the model: appends a
// { "role": "system", "content": <systemPrompt> } object to the running
// message list that is sent with every request.
void OpenAILLMProvider::SetSystemPrompt(const winrt::hstring& systemPrompt)
{
    WDJ::JsonObject systemMessageObject;
    systemMessageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"system"));
    // Use the parameter directly - the original made a redundant hstring copy
    // (systemMessageContent) before passing it on.
    systemMessageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(systemPrompt));
    _jsonMessages.Append(systemMessageObject);
}
|
||
void OpenAILLMProvider::SetContext(const Extension::IContext context) | ||
{ | ||
_context = context; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is where we'd ideally take the argument by-copy and then move it into place. This lets the caller move the instance into the argument which you then move into the member. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh and: Moving only works if the argument is not const. If it's const it'll not fail to compile, because... C++ reasons. Instead, it'll silently just create a copy. Bu With high enough compiler warnings it'll warn you about it though, but idk if it's enabled for this project. |
||
} | ||
|
||
// Sends the user's prompt (enriched with any stored context) to the OpenAI
// chat-completions endpoint and co_returns the model's reply wrapped in an
// OpenAIResponse. Failures are reported through the ErrorTypes value on the
// response so the caller (ExtensionPalette) can emit the correct telemetry
// event rather than via exceptions.
winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> OpenAILLMProvider::GetResponseAsync(const winrt::hstring userPrompt)
{
    // Keep ourselves alive for the full duration of the coroutine: we touch
    // members (_httpClient, _jsonMessages, _context) after suspension points,
    // so without a strong reference the object could be destroyed mid-flight.
    const auto strongThis{ get_strong() };

    // Flags whether the response the user receives is an error message; we
    // pass this back so the caller can handle it appropriately.
    ErrorTypes errorType{ ErrorTypes::None };
    hstring message{};

    // Hop off the calling (UI) thread before doing any network work.
    co_await winrt::resume_background();

    WWH::HttpRequestMessage request{ WWH::HttpMethod::Post(), Uri{ openAIEndpoint } };
    request.Headers().Accept().TryParseAdd(applicationJson);

    WDJ::JsonObject jsonContent;
    WDJ::JsonObject messageObject;

    // If we know the active commandline, fold it into the prompt so the model
    // can tailor its answer to the user's shell.
    winrt::hstring engineeredPrompt{ userPrompt };
    if (_context && !_context.ActiveCommandline().empty())
    {
        engineeredPrompt = userPrompt + L". The shell I am running is " + _context.ActiveCommandline();
    }
    messageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"user"));
    messageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(engineeredPrompt));
    _jsonMessages.Append(messageObject);

    jsonContent.SetNamedValue(L"model", WDJ::JsonValue::CreateStringValue(acceptedModel));
    jsonContent.SetNamedValue(L"messages", _jsonMessages);
    jsonContent.SetNamedValue(L"temperature", WDJ::JsonValue::CreateNumberValue(0));
    const auto stringContent = jsonContent.ToString();
    WWH::HttpStringContent requestContent{
        stringContent,
        WSS::UnicodeEncoding::Utf8,
        applicationJson
    };

    request.Content(requestContent);

    // Send the request
    try
    {
        // co_await instead of .get(): never block a thread-pool thread on an
        // async operation inside a coroutine.
        const auto response = co_await _httpClient.SendRequestAsync(request);

        // Parse out the suggestion from the response
        const auto string{ co_await response.Content().ReadAsStringAsync() };
        const auto jsonResult{ WDJ::JsonObject::Parse(string) };
        if (jsonResult.HasKey(L"error"))
        {
            // The API surfaced an error body; relay its message to the user.
            const auto errorObject = jsonResult.GetNamedObject(L"error");
            message = errorObject.GetNamedString(L"message");
            errorType = ErrorTypes::FromProvider;
        }
        else
        {
            // Success: the reply text lives at choices[0].message.content.
            const auto choices = jsonResult.GetNamedArray(L"choices");
            const auto firstChoice = choices.GetAt(0).GetObject();
            const auto firstChoiceMessage = firstChoice.GetNamedObject(L"message");
            message = firstChoiceMessage.GetNamedString(L"content");
        }
    }
    catch (...)
    {
        // Network failure, malformed JSON, missing keys, etc.
        message = RS_(L"UnknownErrorMessage");
        errorType = ErrorTypes::Unknown;
    }

    // Also record the assistant's reply in _jsonMessages, so the model sees
    // the full conversation so far on the next request.
    WDJ::JsonObject responseMessageObject;
    responseMessageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"assistant"));
    responseMessageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(message));
    _jsonMessages.Append(responseMessageObject);

    co_return winrt::make<OpenAIResponse>(message, errorType);
}
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
Check failure Code scanning / check-spelling Check File Path
[AILLM](#security-tab) is not a recognized word. \(check-file-path\)
|
||
// Licensed under the MIT license. | ||
|
||
#pragma once | ||
|
||
#include "OpenAILLMProvider.g.h" | ||
|
||
namespace winrt::Microsoft::Terminal::Query::Extension::implementation | ||
{ | ||
// ILMProvider implementation that talks to OpenAI's hosted chat-completions
// API (as opposed to the Azure OpenAI provider).
struct OpenAILLMProvider : OpenAILLMProviderT<OpenAILLMProvider>
{
    OpenAILLMProvider() = default;

    // Conversation management.
    void ClearMessageHistory();
    void SetSystemPrompt(const winrt::hstring& systemPrompt);
    void SetContext(const Extension::IContext context);

    // Sends the prompt (plus any stored context) to OpenAI; the returned
    // IResponse carries the reply text and an error classification.
    winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> GetResponseAsync(const winrt::hstring userPrompt);

    // Stores the API key ("key" in the ValueSet) and rebuilds the HTTP client.
    void SetAuthentication(const Windows::Foundation::Collections::ValueSet& authValues);
    TYPED_EVENT(AuthChanged, winrt::Microsoft::Terminal::Query::Extension::ILMProvider, Windows::Foundation::Collections::ValueSet);

private:
    winrt::hstring _AIKey; // OpenAI API key, set via SetAuthentication
    winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr }; // created in SetAuthentication

    Extension::IContext _context; // optional; enriches prompts with the active commandline

    winrt::Windows::Data::Json::JsonArray _jsonMessages; // full conversation, sent with every request
};
|
||
// Lightweight IResponse carrier: the reply text plus an error classification
// that tells the caller whether (and how) the request failed.
struct OpenAIResponse : public winrt::implements<OpenAIResponse, winrt::Microsoft::Terminal::Query::Extension::IResponse>
{
    OpenAIResponse(const winrt::hstring& message, const winrt::Microsoft::Terminal::Query::Extension::ErrorTypes errorType) :
        Message{ message },
        ErrorType{ errorType } {}

    til::property<winrt::hstring> Message;
    til::property<winrt::Microsoft::Terminal::Query::Extension::ErrorTypes> ErrorType;
};
} | ||
|
||
namespace winrt::Microsoft::Terminal::Query::Extension::factory_implementation | ||
{ | ||
BASIC_FACTORY(OpenAILLMProvider); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
Check failure Code scanning / check-spelling Check File Path
[AILLM](#security-tab) is not a recognized word. \(check-file-path\)
|
||
// Licensed under the MIT license. | ||
|
||
import "ILMProvider.idl"; | ||
|
||
// Projects OpenAILLMProvider - the OpenAI-backed ILMProvider implementation -
// into the Microsoft.Terminal.Query.Extension namespace.
namespace Microsoft.Terminal.Query.Extension
{
    runtimeclass OpenAILLMProvider : [default] ILMProvider
    {
        // Parameterless activation; authentication is supplied afterwards
        // through ILMProvider.SetAuthentication.
        OpenAILLMProvider();
    }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,6 +129,13 @@ namespace winrt::TerminalApp::implementation | |
p.SetActionMap(_settings.ActionMap()); | ||
} | ||
|
||
// If the active LLMProvider changed, make sure we reinitialize the provider | ||
const auto newProviderType = _settings.GlobalSettings().AIInfo().ActiveProvider(); | ||
if (_lmProvider && (newProviderType != _currentProvider)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this depend on the existence of |
||
{ | ||
_createAndSetAuthenticationForLMProvider(newProviderType); | ||
} | ||
|
||
if (needRefreshUI) | ||
{ | ||
_RefreshUIForSettingsReload(); | ||
|
@@ -5601,13 +5608,15 @@ namespace winrt::TerminalApp::implementation | |
} | ||
} | ||
|
||
// since we only support one type of llmProvider for now, just instantiate that one (the AzureLLMProvider) | ||
// in the future, we would need to query the settings here for which LLMProvider to use | ||
_lmProvider = winrt::Microsoft::Terminal::Query::Extension::AzureLLMProvider(); | ||
_setAzureOpenAIAuth(); | ||
_azureOpenAISettingChangedRevoker = Microsoft::Terminal::Settings::Model::CascadiaSettings::AzureOpenAISettingChanged(winrt::auto_revoke, { this, &TerminalPage::_setAzureOpenAIAuth }); | ||
_extensionPalette = winrt::Microsoft::Terminal::Query::Extension::ExtensionPalette(); | ||
|
||
// create the correct lm provider | ||
_createAndSetAuthenticationForLMProvider(_settings.GlobalSettings().AIInfo().ActiveProvider()); | ||
|
||
// make sure we listen for auth changes | ||
_azureOpenAISettingChangedRevoker = Microsoft::Terminal::Settings::Model::AIConfig::AzureOpenAISettingChanged(winrt::auto_revoke, { this, &TerminalPage::_setAzureOpenAIAuth }); | ||
_openAISettingChangedRevoker = Microsoft::Terminal::Settings::Model::AIConfig::OpenAISettingChanged(winrt::auto_revoke, { this, &TerminalPage::_setOpenAIAuth }); | ||
|
||
_extensionPalette = winrt::Microsoft::Terminal::Query::Extension::ExtensionPalette(_lmProvider); | ||
_extensionPalette.RegisterPropertyChangedCallback(UIElement::VisibilityProperty(), [&](auto&&, auto&&) { | ||
if (_extensionPalette.Visibility() == Visibility::Collapsed) | ||
{ | ||
|
@@ -5665,18 +5674,59 @@ namespace winrt::TerminalApp::implementation | |
_extensionPalette.ActiveCommandline(L""); | ||
} | ||
}); | ||
|
||
ExtensionPresenter().Content(_extensionPalette); | ||
} | ||
|
||
void TerminalPage::_setAzureOpenAIAuth() | ||
void TerminalPage::_createAndSetAuthenticationForLMProvider(LLMProvider providerType) | ||
{ | ||
if (_lmProvider) | ||
if (!_lmProvider || (_currentProvider != providerType)) | ||
{ | ||
Windows::Foundation::Collections::ValueSet authValues{}; | ||
authValues.Insert(L"endpoint", Windows::Foundation::PropertyValue::CreateString(_settings.AIEndpoint())); | ||
authValues.Insert(L"key", Windows::Foundation::PropertyValue::CreateString(_settings.AIKey())); | ||
_lmProvider.SetAuthentication(authValues); | ||
// we don't have a provider or our current provider is the wrong one, create a new provider | ||
switch (providerType) | ||
{ | ||
case LLMProvider::AzureOpenAI: | ||
_currentProvider = LLMProvider::AzureOpenAI; | ||
_lmProvider = winrt::Microsoft::Terminal::Query::Extension::AzureLLMProvider(); | ||
break; | ||
case LLMProvider::OpenAI: | ||
_currentProvider = LLMProvider::OpenAI; | ||
_lmProvider = winrt::Microsoft::Terminal::Query::Extension::OpenAILLMProvider(); | ||
break; | ||
default: | ||
break; | ||
} | ||
} | ||
|
||
// we now have a provider of the correct type, update that | ||
Windows::Foundation::Collections::ValueSet authValues{}; | ||
const auto settingsAIInfo = _settings.GlobalSettings().AIInfo(); | ||
switch (providerType) | ||
{ | ||
case LLMProvider::AzureOpenAI: | ||
authValues.Insert(L"endpoint", Windows::Foundation::PropertyValue::CreateString(settingsAIInfo.AzureOpenAIEndpoint())); | ||
authValues.Insert(L"key", Windows::Foundation::PropertyValue::CreateString(settingsAIInfo.AzureOpenAIKey())); | ||
break; | ||
case LLMProvider::OpenAI: | ||
authValues.Insert(L"key", Windows::Foundation::PropertyValue::CreateString(settingsAIInfo.OpenAIKey())); | ||
break; | ||
default: | ||
break; | ||
} | ||
_lmProvider.SetAuthentication(authValues); | ||
|
||
if (_extensionPalette) | ||
{ | ||
_extensionPalette.SetProvider(_lmProvider); | ||
} | ||
} | ||
|
||
void TerminalPage::_setAzureOpenAIAuth() | ||
{ | ||
_createAndSetAuthenticationForLMProvider(LLMProvider::AzureOpenAI); | ||
} | ||
|
||
void TerminalPage::_setOpenAIAuth() | ||
{ | ||
_createAndSetAuthenticationForLMProvider(LLMProvider::OpenAI); | ||
} | ||
} |
Check failure
Code scanning / check-spelling
Unrecognized Spelling