Skip to content

Commit

Permalink
feat: add custom http headers to openai related api backends (#1174)
Browse files Browse the repository at this point in the history
* feat: add custom http headers to openai related api backends

Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com>

* ci: add custom headers test

Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com>

* add error handling

Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com>

* chore(deps): update docker/setup-buildx-action digest to 4fd8129 (#1173)

Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com>

* fix(deps): update module buf.build/gen/go/k8sgpt-ai/k8sgpt/grpc-ecosystem/gateway/v2 to v2.20.0-20240406062209-1cc152efbf5c.1 (#1147)

Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com>

* chore(deps): update anchore/sbom-action action to v0.16.0 (#1146)

Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com>

* Update README.md

Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com>

---------

Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com>
Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
  • Loading branch information
3 people committed Jul 10, 2024
1 parent fef8539 commit 02e754e
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 26 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,12 @@ _Analysis with serve mode_
```
grpcurl -plaintext -d '{"namespace": "k8sgpt", "explain": false}' localhost:8080 schema.v1.ServerService/Analyze
```

_Analysis with custom headers_

```
k8sgpt analyze --explain --custom-headers CustomHeaderKey:CustomHeaderValue
```
</details>

## LLM AI Backends
Expand Down
5 changes: 4 additions & 1 deletion cmd/analyze/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var (
withDoc bool
interactiveMode bool
customAnalysis bool
customHeaders []string
)

// AnalyzeCmd represents the problems command
Expand All @@ -59,6 +60,7 @@ var AnalyzeCmd = &cobra.Command{
maxConcurrency,
withDoc,
interactiveMode,
customHeaders,
)

if err != nil {
Expand Down Expand Up @@ -138,5 +140,6 @@ func init() {
AnalyzeCmd.Flags().BoolVarP(&interactiveMode, "interactive", "i", false, "Enable interactive mode that allows further conversation with LLM about the problem. Works only with --explain flag")
// custom analysis flag
AnalyzeCmd.Flags().BoolVarP(&customAnalysis, "custom-analysis", "z", false, "Enable custom analyzers")

// add custom headers flag
AnalyzeCmd.Flags().StringSliceVarP(&customHeaders, "custom-headers", "r", []string{}, "Custom Headers, <key>:<value> (e.g CustomHeaderKey:CustomHeaderValue AnotherHeader:AnotherValue)")
}
39 changes: 23 additions & 16 deletions pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package ai

import (
"context"
"net/http"
)

var (
Expand Down Expand Up @@ -83,6 +84,7 @@ type IAIConfig interface {
GetProviderId() string
GetCompartmentId() string
GetOrganizationId() string
GetCustomHeaders() []http.Header
}

func NewClient(provider string) IAI {
Expand All @@ -101,22 +103,23 @@ type AIConfiguration struct {
}

type AIProvider struct {
Name string `mapstructure:"name"`
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"`
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
Name string `mapstructure:"name"`
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"`
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
CustomHeaders []http.Header `mapstructure:"customHeaders"`
}

func (p *AIProvider) GetBaseURL() string {
Expand Down Expand Up @@ -174,6 +177,10 @@ func (p *AIProvider) GetOrganizationId() string {
return p.OrganizationId
}

func (p *AIProvider) GetCustomHeaders() []http.Header {
return p.CustomHeaders
}

var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}

func NeedPassword(backend string) bool {
Expand Down
39 changes: 32 additions & 7 deletions pkg/ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,27 @@ func (c *OpenAIClient) Configure(config IAIConfig) error {
defaultConfig.BaseURL = baseURL
}

transport := &http.Transport{}
if proxyEndpoint != "" {
proxyUrl, err := url.Parse(proxyEndpoint)
if err != nil {
return err
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyUrl),
}

defaultConfig.HTTPClient = &http.Client{
Transport: transport,
}
transport.Proxy = http.ProxyURL(proxyUrl)
}

if orgId != "" {
defaultConfig.OrgID = orgId
}

customHeaders := config.GetCustomHeaders()
defaultConfig.HTTPClient = &http.Client{
Transport: &OpenAIHeaderTransport{
Origin: transport,
Headers: customHeaders,
},
}

client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating OpenAI client")
Expand Down Expand Up @@ -106,3 +109,25 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string
func (c *OpenAIClient) GetName() string {
return openAIClientName
}

// OpenAIHeaderTransport is an http.RoundTripper that adds the given headers to each request.
type OpenAIHeaderTransport struct {
Origin http.RoundTripper
Headers []http.Header
}

// RoundTrip implements the http.RoundTripper interface.
func (t *OpenAIHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Clone the request to avoid modifying the original request
clonedReq := req.Clone(req.Context())
for _, header := range t.Headers {
for key, values := range header {
// Possible values per header: RFC 2616
for _, value := range values {
clonedReq.Header.Add(key, value)
}
}
}

return t.Origin.RoundTrip(clonedReq)
}
106 changes: 106 additions & 0 deletions pkg/ai/openai_header_transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package ai

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

// Mock configuration
type mockConfig struct {
baseURL string
}

func (m *mockConfig) GetPassword() string {
return ""
}

func (m *mockConfig) GetOrganizationId() string {
return ""
}

func (m *mockConfig) GetProxyEndpoint() string {
return ""
}

func (m *mockConfig) GetBaseURL() string {
return m.baseURL
}

func (m *mockConfig) GetCustomHeaders() []http.Header {
return []http.Header{
{"X-Custom-Header-1": []string{"Value1"}},
{"X-Custom-Header-2": []string{"Value2"}},
{"X-Custom-Header-2": []string{"Value3"}}, // Testing multiple values for the same header
}
}

func (m *mockConfig) GetModel() string {
return ""
}

func (m *mockConfig) GetTemperature() float32 {
return 0.0
}

func (m *mockConfig) GetTopP() float32 {
return 0.0
}
func (m *mockConfig) GetCompartmentId() string {
return ""
}

func (m *mockConfig) GetTopK() int32 {
return 0.0
}

func (m *mockConfig) GetMaxTokens() int {
return 0
}

func (m *mockConfig) GetEndpointName() string {
return ""
}
func (m *mockConfig) GetEngine() string {
return ""
}

func (m *mockConfig) GetProviderId() string {
return ""
}

func (m *mockConfig) GetProviderRegion() string {
return ""
}

func TestOpenAIClient_CustomHeaders(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "Value1", r.Header.Get("X-Custom-Header-1"))
assert.ElementsMatch(t, []string{"Value2", "Value3"}, r.Header["X-Custom-Header-2"])
w.WriteHeader(http.StatusOK)
// Mock response for openai completion
mockResponse := `{"choices": [{"message": {"content": "test"}}]}`
n, err := w.Write([]byte(mockResponse))
if err != nil {
t.Fatalf("error writing response: %v", err)
}
if n != len(mockResponse) {
t.Fatalf("expected to write %d bytes but wrote %d bytes", len(mockResponse), n)
}
}))
defer server.Close()

config := &mockConfig{baseURL: server.URL}

client := &OpenAIClient{}
err := client.Configure(config)
assert.NoError(t, err)

// Make a completion request to trigger the headers
ctx := context.Background()
_, err = client.GetCompletion(ctx, "foo prompt")
assert.NoError(t, err)
}
3 changes: 3 additions & 0 deletions pkg/analysis/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func NewAnalysis(
maxConcurrency int,
withDoc bool,
interactiveMode bool,
httpHeaders []string,
) (*Analysis, error) {
// Get kubernetes client from viper.
kubecontext := viper.GetString("kubecontext")
Expand Down Expand Up @@ -146,6 +147,8 @@ func NewAnalysis(
}

aiClient := ai.NewClient(aiProvider.Name)
customHeaders := util.NewHeaders(httpHeaders)
aiProvider.CustomHeaders = customHeaders
if err := aiClient.Configure(&aiProvider); err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/server/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ func (h *handler) Analyze(ctx context.Context, i *schemav1.AnalyzeRequest) (
i.Nocache,
i.Explain,
int(i.MaxConcurrency),
false, // Kubernetes Doc disabled in server mode
false, // Interactive mode disabled in server mode
false, // Kubernetes Doc disabled in server mode
false, // Interactive mode disabled in server mode
[]string{}, //TODO: add custom http headers in server mode
)
config.Context = ctx // Replace context for correct timeouts.
if err != nil {
Expand Down
34 changes: 34 additions & 0 deletions pkg/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"net/http"
"os"
"regexp"
"strings"
Expand Down Expand Up @@ -261,3 +262,36 @@ func FetchLatestEvent(ctx context.Context, kubernetesClient *kubernetes.Client,
}
return latestEvent, nil
}

// NewHeaders parses a slice of strings in the format "key:value" into []http.Header
// It handles headers with the same key by appending values
func NewHeaders(customHeaders []string) []http.Header {
headers := make(map[string][]string)

for _, header := range customHeaders {
vals := strings.SplitN(header, ":", 2)
if len(vals) != 2 {
//TODO: Handle error instead of ignoring it
continue
}
key := strings.TrimSpace(vals[0])
value := strings.TrimSpace(vals[1])

if _, ok := headers[key]; !ok {
headers[key] = []string{}
}
headers[key] = append(headers[key], value)
}

// Convert map to []http.Header format
var result []http.Header
for key, values := range headers {
header := make(http.Header)
for _, value := range values {
header.Add(key, value)
}
result = append(result, header)
}

return result
}

0 comments on commit 02e754e

Please sign in to comment.