Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow OpenAI to be used with Terminal Chat #17540

Open
wants to merge 38 commits into
base: feature/llm
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
b44216b
got an idl
PankajBhojwani Jun 6, 2024
e78e4d0
works
PankajBhojwani Jun 7, 2024
ac83d76
allow
PankajBhojwani Jun 7, 2024
4b944e9
conflict
PankajBhojwani Jun 7, 2024
ef406ee
use id here too
PankajBhojwani Jun 7, 2024
4fb4ca4
have terminal page initialize the llmprovider
PankajBhojwani Jun 7, 2024
e5afbae
format
PankajBhojwani Jun 7, 2024
32b3d68
consts
PankajBhojwani Jun 8, 2024
1700a92
works on palette side
PankajBhojwani Jun 11, 2024
08dd951
open ai in settings, llmprovider enum
PankajBhojwani Jun 13, 2024
1934b30
aiconfig struct
PankajBhojwani Jun 20, 2024
96a489a
move more things to ai info
PankajBhojwani Jun 21, 2024
8e560e2
active provider buttons
PankajBhojwani Jul 9, 2024
f9e9326
rename to lmprovider
PankajBhojwani Jul 9, 2024
4133239
spelling conflict
PankajBhojwani Jul 9, 2024
e545ca2
conflicts
PankajBhojwani Jul 9, 2024
dfad8d9
spell
PankajBhojwani Jul 9, 2024
5139b88
array of accepted models
PankajBhojwani Jul 9, 2024
dfbdc75
cleanup this comment
PankajBhojwani Jul 9, 2024
1f305ab
first round of comments
PankajBhojwani Jul 15, 2024
242b964
combine setter/getter
PankajBhojwani Jul 15, 2024
0094db0
Merge branch 'dev/pabhoj/llm_provider_interface' of https://github.co…
PankajBhojwani Jul 15, 2024
6479c47
auth related functions
PankajBhojwani Jul 16, 2024
d392aab
dispatcher
PankajBhojwani Jul 17, 2024
8faf8b4
domain check
PankajBhojwani Jul 17, 2024
dd6b46d
hot reload works now
PankajBhojwani Jul 18, 2024
4f81775
conflict and updates to openAI provider
PankajBhojwani Jul 18, 2024
46ab508
password box
PankajBhojwani Jul 18, 2024
ac1b4a3
logic for updating providers based on settings changes
PankajBhojwani Jul 18, 2024
fc8e36d
use static strings for strings used more than once
PankajBhojwani Jul 19, 2024
fdee6c2
Merge branch 'feature/llm' of https://github.com/microsoft/terminal i…
PankajBhojwani Jul 19, 2024
fff97a4
merge
PankajBhojwani Jul 20, 2024
45ce94d
don't use &
PankajBhojwani Jul 22, 2024
c89a306
update with base changes
PankajBhojwani Jul 22, 2024
de88842
conflict
PankajBhojwani Sep 9, 2024
edf08a9
missed this
PankajBhojwani Sep 9, 2024
830e655
newline...
PankajBhojwani Sep 9, 2024
16d3035
conflict
PankajBhojwani Oct 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/actions/spelling/allow/allow.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
aci
AIIs
AILLM
allcolors
breadcrumb
breadcrumbs
Expand Down
9 changes: 7 additions & 2 deletions src/cascadia/QueryExtension/ExtensionPalette.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ const std::wregex azureOpenAIEndpointRegex{ LR"(^https.*openai\.azure\.com)" };

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
ExtensionPalette::ExtensionPalette(const Extension::ILMProvider lmProvider) :
_lmProvider{ lmProvider }
ExtensionPalette::ExtensionPalette()
{
InitializeComponent();

Expand Down Expand Up @@ -86,6 +85,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
});
}

void ExtensionPalette::SetProvider(const Extension::ILMProvider lmProvider)
{
_lmProvider = lmProvider;
_clearAndInitializeMessages(nullptr, nullptr);
}

void ExtensionPalette::IconPath(const winrt::hstring& iconPath)
{
// We don't need to store the path - just create the icon and set it,
Expand Down
3 changes: 2 additions & 1 deletion src/cascadia/QueryExtension/ExtensionPalette.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
struct ExtensionPalette : ExtensionPaletteT<ExtensionPalette>
{
ExtensionPalette(const Extension::ILMProvider lmProvider);
ExtensionPalette();
void SetProvider(const Extension::ILMProvider lmProvider);

// We don't use the winrt_property macro here because we just need the setter
void IconPath(const winrt::hstring& iconPath);
Expand Down
3 changes: 2 additions & 1 deletion src/cascadia/QueryExtension/ExtensionPalette.idl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ namespace Microsoft.Terminal.Query.Extension

[default_interface] runtimeclass ExtensionPalette : Windows.UI.Xaml.Controls.UserControl, Windows.UI.Xaml.Data.INotifyPropertyChanged
{
ExtensionPalette(ILMProvider lmProvider);
ExtensionPalette();
void SetProvider(ILMProvider lmProvider);

String ControlName { get; };
String QueryBoxPlaceholderText { get; };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@
<ClInclude Include="AzureLLMProvider.h">
<DependentUpon>AzureLLMProvider.idl</DependentUpon>
</ClInclude>
<ClInclude Include="OpenAILLMProvider.h">

Check failure

Code scanning / check-spelling

Unrecognized Spelling

[AILLM](#security-tab) is not a recognized word. \(unrecognized-spelling\)
<DependentUpon>OpenAILLMProvider.idl</DependentUpon>

Check failure

Code scanning / check-spelling

Unrecognized Spelling

[AILLM](#security-tab) is not a recognized word. \(unrecognized-spelling\)
</ClInclude>
</ItemGroup>
<!-- ========================= XAML files ======================== -->
<ItemGroup>
Expand All @@ -80,6 +83,9 @@
<ClCompile Include="AzureLLMProvider.cpp">
<DependentUpon>AzureLLMProvider.idl</DependentUpon>
</ClCompile>
<ClCompile Include="OpenAILLMProvider.cpp">

Check failure

Code scanning / check-spelling

Unrecognized Spelling

[AILLM](#security-tab) is not a recognized word. \(unrecognized-spelling\)
<DependentUpon>OpenAILLMProvider.idl</DependentUpon>

Check failure

Code scanning / check-spelling

Unrecognized Spelling

[AILLM](#security-tab) is not a recognized word. \(unrecognized-spelling\)
</ClCompile>
</ItemGroup>
<!-- ========================= idl Files ======================== -->
<ItemGroup>
Expand All @@ -96,6 +102,9 @@
<Midl Include="AzureLLMProvider.idl">
<SubType>Code</SubType>
</Midl>
<Midl Include="OpenAILLMProvider.idl">

Check failure

Code scanning / check-spelling

Unrecognized Spelling

[AILLM](#security-tab) is not a recognized word. \(unrecognized-spelling\)
<SubType>Code</SubType>
</Midl>
</ItemGroup>
<!-- ========================= Misc Files ======================== -->
<ItemGroup>
Expand Down
125 changes: 125 additions & 0 deletions src/cascadia/QueryExtension/OpenAILLMProvider.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) Microsoft Corporation.

Check failure

Code scanning / check-spelling

Check File Path

[AILLM](#security-tab) is not a recognized word. \(check-file-path\)
// Licensed under the MIT license.

#include "pch.h"
#include "OpenAILLMProvider.h"
#include "../../types/inc/utils.hpp"
#include "LibraryResources.h"

#include "OpenAILLMProvider.g.cpp"

using namespace winrt::Windows::Foundation;
using namespace winrt::Windows::Foundation::Collections;
using namespace winrt::Windows::UI::Core;
using namespace winrt::Windows::UI::Xaml;
using namespace winrt::Windows::UI::Xaml::Controls;
using namespace winrt::Windows::System;
namespace WWH = ::winrt::Windows::Web::Http;
namespace WSS = ::winrt::Windows::Storage::Streams;
namespace WDJ = ::winrt::Windows::Data::Json;
PankajBhojwani marked this conversation as resolved.
Show resolved Hide resolved

static constexpr std::wstring_view applicationJson{ L"application/json" };
static constexpr std::wstring_view acceptedModel{ L"gpt-3.5-turbo" };
Fixed Show fixed Hide fixed
static constexpr std::wstring_view openAIEndpoint{ L"https://api.openai.com/v1/chat/completions" };

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
void OpenAILLMProvider::SetAuthentication(const Windows::Foundation::Collections::ValueSet& authValues)
{
_AIKey = unbox_value_or<hstring>(authValues.TryLookup(L"key").try_as<IPropertyValue>(), L"");
_httpClient = winrt::Windows::Web::Http::HttpClient{};
_httpClient.DefaultRequestHeaders().Accept().TryParseAdd(applicationJson);
_httpClient.DefaultRequestHeaders().Authorization(WWH::Headers::HttpCredentialsHeaderValue{ L"Bearer", _AIKey });
}

void OpenAILLMProvider::ClearMessageHistory()
{
_jsonMessages.Clear();
}

void OpenAILLMProvider::SetSystemPrompt(const winrt::hstring& systemPrompt)
{
WDJ::JsonObject systemMessageObject;
winrt::hstring systemMessageContent{ systemPrompt };
systemMessageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"system"));
systemMessageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(systemMessageContent));
_jsonMessages.Append(systemMessageObject);
}

void OpenAILLMProvider::SetContext(const Extension::IContext context)
{
_context = context;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where we'd ideally take the argument by-copy and then move it into place. This lets the caller move the instance into the argument which you then move into the member.
(A very minor nit.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh and: Moving only works if the argument is not const. If it's const it'll not fail to compile, because... C++ reasons. Instead, it'll silently just create a copy. Bu With high enough compiler warnings it'll warn you about it though, but idk if it's enabled for this project.

}

winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> OpenAILLMProvider::GetResponseAsync(const winrt::hstring userPrompt)
{
// Use the ErrorTypes enum to flag whether the response the user receives is an error message
// we pass this enum back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event)
ErrorTypes errorType{ ErrorTypes::None };
hstring message{};

// Make sure we are on the background thread for the http request
co_await winrt::resume_background();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also needs the strong-this pointer for safety. You can use the weak/strong handling that we do elsewhere or hold onto a strong pointer the entire time. That's pretty much entirely up to you (the latter is probably better though).


WWH::HttpRequestMessage request{ WWH::HttpMethod::Post(), Uri{ openAIEndpoint } };
request.Headers().Accept().TryParseAdd(applicationJson);

WDJ::JsonObject jsonContent;
WDJ::JsonObject messageObject;

winrt::hstring engineeredPrompt{ userPrompt };
if (_context && !_context.ActiveCommandline().empty())
{
engineeredPrompt = userPrompt + L". The shell I am running is " + _context.ActiveCommandline();
PankajBhojwani marked this conversation as resolved.
Show resolved Hide resolved
}
messageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"user"));
messageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(engineeredPrompt));
_jsonMessages.Append(messageObject);
jsonContent.SetNamedValue(L"model", WDJ::JsonValue::CreateStringValue(acceptedModel));
jsonContent.SetNamedValue(L"messages", _jsonMessages);
jsonContent.SetNamedValue(L"temperature", WDJ::JsonValue::CreateNumberValue(0));
const auto stringContent = jsonContent.ToString();
WWH::HttpStringContent requestContent{
stringContent,
WSS::UnicodeEncoding::Utf8,
applicationJson
};

request.Content(requestContent);

// Send the request
try
{
const auto response = _httpClient.SendRequestAsync(request).get();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

co_await?

// Parse out the suggestion from the response
const auto string{ response.Content().ReadAsStringAsync().get() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(L"error"))
{
const auto errorObject = jsonResult.GetNamedObject(L"error");
message = errorObject.GetNamedString(L"message");
errorType = ErrorTypes::FromProvider;
}
else
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(L"message");
message = messageObject.GetNamedString(L"content");
}
}
catch (...)
{
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
}

// Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far
WDJ::JsonObject responseMessageObject;
responseMessageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"assistant"));
responseMessageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(message));
_jsonMessages.Append(responseMessageObject);

co_return winrt::make<OpenAIResponse>(message, errorType);
}
}
46 changes: 46 additions & 0 deletions src/cascadia/QueryExtension/OpenAILLMProvider.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation.

Check failure

Code scanning / check-spelling

Check File Path

[AILLM](#security-tab) is not a recognized word. \(check-file-path\)
// Licensed under the MIT license.

#pragma once

#include "OpenAILLMProvider.g.h"

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
struct OpenAILLMProvider : OpenAILLMProviderT<OpenAILLMProvider>
{
OpenAILLMProvider() = default;

void ClearMessageHistory();
void SetSystemPrompt(const winrt::hstring& systemPrompt);
void SetContext(const Extension::IContext context);

winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> GetResponseAsync(const winrt::hstring userPrompt);

void SetAuthentication(const Windows::Foundation::Collections::ValueSet& authValues);
TYPED_EVENT(AuthChanged, winrt::Microsoft::Terminal::Query::Extension::ILMProvider, Windows::Foundation::Collections::ValueSet);

private:
winrt::hstring _AIKey;
winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr };

Extension::IContext _context;

winrt::Windows::Data::Json::JsonArray _jsonMessages;
};

struct OpenAIResponse : public winrt::implements<OpenAIResponse, winrt::Microsoft::Terminal::Query::Extension::IResponse>
{
OpenAIResponse(const winrt::hstring& message, const winrt::Microsoft::Terminal::Query::Extension::ErrorTypes errorType) :
Message{ message },
ErrorType{ errorType } {}

til::property<winrt::hstring> Message;
til::property<winrt::Microsoft::Terminal::Query::Extension::ErrorTypes> ErrorType;
};
}

namespace winrt::Microsoft::Terminal::Query::Extension::factory_implementation
{
BASIC_FACTORY(OpenAILLMProvider);
}
12 changes: 12 additions & 0 deletions src/cascadia/QueryExtension/OpenAILLMProvider.idl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation.

Check failure

Code scanning / check-spelling

Check File Path

[AILLM](#security-tab) is not a recognized word. \(check-file-path\)
// Licensed under the MIT license.

import "ILMProvider.idl";

namespace Microsoft.Terminal.Query.Extension
{
runtimeclass OpenAILLMProvider : [default] ILMProvider
{
OpenAILLMProvider();
}
}
76 changes: 63 additions & 13 deletions src/cascadia/TerminalApp/TerminalPage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ namespace winrt::TerminalApp::implementation
p.SetActionMap(_settings.ActionMap());
}

// If the active LLMProvider changed, make sure we reinitialize the provider
const auto newProviderType = _settings.GlobalSettings().AIInfo().ActiveProvider();
if (_lmProvider && (newProviderType != _currentProvider))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this depend on the existence of _lmProvider?

{
_createAndSetAuthenticationForLMProvider(newProviderType);
}

if (needRefreshUI)
{
_RefreshUIForSettingsReload();
Expand Down Expand Up @@ -5601,13 +5608,15 @@ namespace winrt::TerminalApp::implementation
}
}

// since we only support one type of llmProvider for now, just instantiate that one (the AzureLLMProvider)
// in the future, we would need to query the settings here for which LLMProvider to use
_lmProvider = winrt::Microsoft::Terminal::Query::Extension::AzureLLMProvider();
_setAzureOpenAIAuth();
_azureOpenAISettingChangedRevoker = Microsoft::Terminal::Settings::Model::CascadiaSettings::AzureOpenAISettingChanged(winrt::auto_revoke, { this, &TerminalPage::_setAzureOpenAIAuth });
_extensionPalette = winrt::Microsoft::Terminal::Query::Extension::ExtensionPalette();

// create the correct lm provider
_createAndSetAuthenticationForLMProvider(_settings.GlobalSettings().AIInfo().ActiveProvider());

// make sure we listen for auth changes
_azureOpenAISettingChangedRevoker = Microsoft::Terminal::Settings::Model::AIConfig::AzureOpenAISettingChanged(winrt::auto_revoke, { this, &TerminalPage::_setAzureOpenAIAuth });
_openAISettingChangedRevoker = Microsoft::Terminal::Settings::Model::AIConfig::OpenAISettingChanged(winrt::auto_revoke, { this, &TerminalPage::_setOpenAIAuth });

_extensionPalette = winrt::Microsoft::Terminal::Query::Extension::ExtensionPalette(_lmProvider);
_extensionPalette.RegisterPropertyChangedCallback(UIElement::VisibilityProperty(), [&](auto&&, auto&&) {
if (_extensionPalette.Visibility() == Visibility::Collapsed)
{
Expand Down Expand Up @@ -5665,18 +5674,59 @@ namespace winrt::TerminalApp::implementation
_extensionPalette.ActiveCommandline(L"");
}
});

ExtensionPresenter().Content(_extensionPalette);
}

void TerminalPage::_setAzureOpenAIAuth()
void TerminalPage::_createAndSetAuthenticationForLMProvider(LLMProvider providerType)
{
if (_lmProvider)
if (!_lmProvider || (_currentProvider != providerType))
{
Windows::Foundation::Collections::ValueSet authValues{};
authValues.Insert(L"endpoint", Windows::Foundation::PropertyValue::CreateString(_settings.AIEndpoint()));
authValues.Insert(L"key", Windows::Foundation::PropertyValue::CreateString(_settings.AIKey()));
_lmProvider.SetAuthentication(authValues);
// we don't have a provider or our current provider is the wrong one, create a new provider
switch (providerType)
{
case LLMProvider::AzureOpenAI:
_currentProvider = LLMProvider::AzureOpenAI;
_lmProvider = winrt::Microsoft::Terminal::Query::Extension::AzureLLMProvider();
break;
case LLMProvider::OpenAI:
_currentProvider = LLMProvider::OpenAI;
_lmProvider = winrt::Microsoft::Terminal::Query::Extension::OpenAILLMProvider();
break;
default:
break;
}
}

// we now have a provider of the correct type, update that
Windows::Foundation::Collections::ValueSet authValues{};
const auto settingsAIInfo = _settings.GlobalSettings().AIInfo();
switch (providerType)
{
case LLMProvider::AzureOpenAI:
authValues.Insert(L"endpoint", Windows::Foundation::PropertyValue::CreateString(settingsAIInfo.AzureOpenAIEndpoint()));
authValues.Insert(L"key", Windows::Foundation::PropertyValue::CreateString(settingsAIInfo.AzureOpenAIKey()));
break;
case LLMProvider::OpenAI:
authValues.Insert(L"key", Windows::Foundation::PropertyValue::CreateString(settingsAIInfo.OpenAIKey()));
break;
default:
break;
}
_lmProvider.SetAuthentication(authValues);

if (_extensionPalette)
{
_extensionPalette.SetProvider(_lmProvider);
}
}

void TerminalPage::_setAzureOpenAIAuth()
{
_createAndSetAuthenticationForLMProvider(LLMProvider::AzureOpenAI);
}

void TerminalPage::_setOpenAIAuth()
{
_createAndSetAuthenticationForLMProvider(LLMProvider::OpenAI);
}
}
6 changes: 5 additions & 1 deletion src/cascadia/TerminalApp/TerminalPage.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,12 @@ namespace winrt::TerminalApp::implementation
winrt::Windows::UI::Xaml::FrameworkElement::Loaded_revoker _extensionPaletteLoadedRevoker;
Microsoft::Terminal::Settings::Model::CascadiaSettings _settings{ nullptr };

winrt::Microsoft::Terminal::Settings::Model::CascadiaSettings::AzureOpenAISettingChanged_revoker _azureOpenAISettingChangedRevoker;
winrt::Microsoft::Terminal::Settings::Model::LLMProvider _currentProvider;
winrt::Microsoft::Terminal::Settings::Model::AIConfig::AzureOpenAISettingChanged_revoker _azureOpenAISettingChangedRevoker;
void _setAzureOpenAIAuth();
winrt::Microsoft::Terminal::Settings::Model::AIConfig::OpenAISettingChanged_revoker _openAISettingChangedRevoker;
void _setOpenAIAuth();
void _createAndSetAuthenticationForLMProvider(winrt::Microsoft::Terminal::Settings::Model::LLMProvider providerType);

Windows::Foundation::Collections::IObservableVector<TerminalApp::TabBase> _tabs;
Windows::Foundation::Collections::IObservableVector<TerminalApp::TabBase> _mruTabs;
Expand Down
Loading
Loading