Add ability to configure custom OpenAI API endpoint for #186 (#194)
* Add ability to configure custom OpenAI API endpoint for #186

* Ensure the AiCompletionEndpoint field is always initialized
ddworken authored Mar 27, 2024
1 parent 46e9280 commit 21b401b
Showing 9 changed files with 52 additions and 26 deletions.
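
The end-to-end behaviour added here is easiest to see from the new client test further down: the endpoint defaults to the public OpenAI URL and can be read or overridden with the existing config commands, e.g.

    hishtory config-get ai-completion-endpoint
    hishtory config-set ai-completion-endpoint https://example.com/foo/bar

Both commands are taken verbatim from the added test in client/client_test.go.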
2 changes: 1 addition & 1 deletion backend/server/internal/server/api_handlers.go
@@ -331,7 +331,7 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
     if numDevices == 0 {
         panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
     }
-    suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.ShellName, req.OsName, req.NumberCompletions)
+    suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(ai.DefaultOpenAiEndpoint, req.Query, req.ShellName, req.OsName, req.NumberCompletions)
     if err != nil {
         panic(fmt.Errorf("failed to query OpenAI API: %w", err))
     }
4 changes: 2 additions & 2 deletions client/ai/ai.go
@@ -27,10 +27,10 @@ func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, num
 }

 func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
-    if os.Getenv("OPENAI_API_KEY") == "" {
+    if os.Getenv("OPENAI_API_KEY") == "" && hctx.GetConf(ctx).AiCompletionEndpoint == ai.DefaultOpenAiEndpoint {
         return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
     } else {
-        suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, shellName, getOsName(), numberCompletions)
+        suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), numberCompletions)
         return suggestions, err
     }
 }
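
The routing change above boils down to a single predicate: the hosted hishtory API is used only when the user has neither an OPENAI_API_KEY nor a non-default completion endpoint configured; otherwise the client calls the configured endpoint directly. A minimal sketch of that rule, with simplified names that are not the actual hishtory package layout:

package sketch

// Sketch only: mirrors the condition in GetAiSuggestions above.
const defaultOpenAiEndpoint = "https://api.openai.com/v1/chat/completions"

// usesHostedHishtoryApi reports whether the request should go through the
// hosted hishtory backend rather than straight to an OpenAI-compatible API.
func usesHostedHishtoryApi(openAiApiKey, configuredEndpoint string) bool {
    return openAiApiKey == "" && configuredEndpoint == defaultOpenAiEndpoint
}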
12 changes: 11 additions & 1 deletion client/client_test.go
@@ -1572,6 +1572,17 @@ func testConfigGetSet(t *testing.T, tester shellTester) {
     if out != "Command \"Exit Code\" Timestamp foobar \n" {
         t.Fatalf("unexpected config-get output: %#v", out)
     }
+
+    // For OpenAI endpoints
+    out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
+    if out != "https://api.openai.com/v1/chat/completions\n" {
+        t.Fatalf("unexpected config-get output: %#v", out)
+    }
+    tester.RunInteractiveShell(t, `hishtory config-set ai-completion-endpoint https://example.com/foo/bar`)
+    out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
+    if out != "https://example.com/foo/bar\n" {
+        t.Fatalf("unexpected config-get output: %#v", out)
+    }
 }

 func clearControlRSearchFromConfig(t testing.TB) {
@@ -2166,7 +2177,6 @@ func testTui_ai(t *testing.T) {
     })
     out = stripTuiCommandPrefix(t, out)
     testutils.CompareGoldens(t, out, "TestTui-AiQuery-Disabled")
-
 }

 func testControlR(t *testing.T, tester shellTester, shellName string, onlineStatus OnlineStatus) {
11 changes: 11 additions & 0 deletions client/cmd/configGet.go
@@ -146,6 +146,16 @@ var getColorScheme = &cobra.Command{
     },
 }

+var getAiCompletionEndpoint = &cobra.Command{
+    Use:   "ai-completion-endpoint",
+    Short: "The AI endpoint to use for AI completions",
+    Run: func(cmd *cobra.Command, args []string) {
+        ctx := hctx.MakeContext()
+        config := hctx.GetConf(ctx)
+        fmt.Println(config.AiCompletionEndpoint)
+    },
+}
+
 func init() {
     rootCmd.AddCommand(configGetCmd)
     configGetCmd.AddCommand(getEnableControlRCmd)
@@ -159,4 +169,5 @@ func init() {
     configGetCmd.AddCommand(getPresavingCmd)
     configGetCmd.AddCommand(getColorScheme)
     configGetCmd.AddCommand(getDefaultFilterCmd)
+    configGetCmd.AddCommand(getAiCompletionEndpoint)
 }
13 changes: 13 additions & 0 deletions client/cmd/configSet.go
@@ -217,6 +217,18 @@ func validateColor(color string) error {
     return nil
 }

+var setAiCompletionEndpoint = &cobra.Command{
+    Use:   "ai-completion-endpoint",
+    Short: "The AI endpoint to use for AI completions",
+    Args:  cobra.ExactArgs(1),
+    Run: func(cmd *cobra.Command, args []string) {
+        ctx := hctx.MakeContext()
+        config := hctx.GetConf(ctx)
+        config.AiCompletionEndpoint = args[0]
+        lib.CheckFatalError(hctx.SetConfig(config))
+    },
+}
+
 func init() {
     rootCmd.AddCommand(configSetCmd)
     configSetCmd.AddCommand(setEnableControlRCmd)
@@ -229,6 +241,7 @@ func init() {
     configSetCmd.AddCommand(setPresavingCmd)
     configSetCmd.AddCommand(setColorSchemeCmd)
     configSetCmd.AddCommand(setDefaultFilterCommand)
+    configSetCmd.AddCommand(setAiCompletionEndpoint)
     setColorSchemeCmd.AddCommand(setColorSchemeSelectedText)
     setColorSchemeCmd.AddCommand(setColorSchemeSelectedBackground)
     setColorSchemeCmd.AddCommand(setColorSchemeBorderColor)
5 changes: 5 additions & 0 deletions client/hctx/hctx.go
@@ -205,6 +205,8 @@ type ClientConfig struct {
     ColorScheme ColorScheme `json:"color_scheme"`
     // A default filter that will be applied to all search queries
     DefaultFilter string `json:"default_filter"`
+    // The endpoint to use for AI suggestions
+    AiCompletionEndpoint string `json:"ai_completion_endpoint"`
 }

 type ColorScheme struct {
@@ -272,6 +274,9 @@ func GetConfig() (ClientConfig, error) {
     if config.ColorScheme.BorderColor == "" {
         config.ColorScheme.BorderColor = GetDefaultColorScheme().BorderColor
     }
+    if config.AiCompletionEndpoint == "" {
+        config.AiCompletionEndpoint = "https://api.openai.com/v1/chat/completions"
+    }
     return config, nil
 }
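
The second commit bullet ("Ensure the AiCompletionEndpoint field is always initialized") is what the GetConfig backfill above implements: a config written before this change has no ai_completion_endpoint key, so the field decodes to the empty string and is then set to the OpenAI default. A tiny self-contained illustration; the struct here is a hypothetical trimmed-down stand-in, only the JSON tag comes from the real ClientConfig:

package main

import (
    "encoding/json"
    "fmt"
)

type oldStyleConfig struct {
    AiCompletionEndpoint string `json:"ai_completion_endpoint"`
}

func main() {
    var cfg oldStyleConfig
    _ = json.Unmarshal([]byte(`{}`), &cfg) // pre-existing config: key absent, field stays ""
    if cfg.AiCompletionEndpoint == "" {
        cfg.AiCompletionEndpoint = "https://api.openai.com/v1/chat/completions"
    }
    fmt.Println(cfg.AiCompletionEndpoint) // prints the OpenAI default
}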
17 changes: 0 additions & 17 deletions scripts/aimain.go

This file was deleted.

12 changes: 8 additions & 4 deletions shared/ai/ai.go
@@ -12,6 +12,8 @@ import (
     "golang.org/x/exp/slices"
 )

+const DefaultOpenAiEndpoint = "https://api.openai.com/v1/chat/completions"
+
 type openAiRequest struct {
     Model    string          `json:"model"`
     Messages []openAiMessage `json:"messages"`
@@ -51,7 +53,7 @@ type TestOnlyOverrideAiSuggestionRequest struct {

 var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string)

-func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
+func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
     if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
         return results, OpenAiUsage{}, nil
     }
@@ -63,7 +65,7 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
         shellName = "bash"
     }
     apiKey := os.Getenv("OPENAI_API_KEY")
-    if apiKey == "" {
+    if apiKey == "" && apiEndpoint == DefaultOpenAiEndpoint {
         return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
     }
     client := &http.Client{}
@@ -82,12 +84,14 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
     if err != nil {
         return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
     }
-    req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
+    req, err := http.NewRequest("POST", apiEndpoint, bytes.NewBuffer(apiReqStr))
     if err != nil {
         return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
     }
     req.Header.Set("Content-Type", "application/json")
-    req.Header.Set("Authorization", "Bearer "+apiKey)
+    if apiKey != "" {
+        req.Header.Set("Authorization", "Bearer "+apiKey)
+    }
     resp, err := client.Do(req)
     if err != nil {
         return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)
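
With the endpoint as a parameter and the Authorization header only attached when an API key is present, the same code path can target any OpenAI-compatible /v1/chat/completions server. A hedged sketch of a caller follows; the import path and the local server URL are assumptions for illustration, only the function signature comes from this commit:

package main

import (
    "fmt"

    "github.com/ddworken/hishtory/shared/ai" // assumed import path for the shared/ai package
)

func main() {
    // Hypothetical self-hosted OpenAI-compatible server; no OPENAI_API_KEY is
    // required because the endpoint differs from ai.DefaultOpenAiEndpoint.
    endpoint := "http://localhost:8080/v1/chat/completions"
    suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(endpoint, "list files in the current directory", "bash", "Linux", 3)
    if err != nil {
        panic(err)
    }
    fmt.Println(suggestions, usage)
}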
2 changes: 1 addition & 1 deletion shared/ai/ai_test.go
@@ -13,7 +13,7 @@ func TestLiveOpenAiApi(t *testing.T) {
     if os.Getenv("OPENAI_API_KEY") == "" {
         t.Skip("Skipping test since OPENAI_API_KEY is not set")
     }
-    results, _, err := GetAiSuggestionsViaOpenAiApi("list files in the current directory", "bash", "Linux", 3)
+    results, _, err := GetAiSuggestionsViaOpenAiApi("https://api.openai.com/v1/chat/completions", "list files in the current directory", "bash", "Linux", 3)
     require.NoError(t, err)
     resultsContainsLs := false
     for _, result := range results {
