Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Cohere backend #563

Merged
merged 2 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,27 @@ k8sgpt analyze --explain --backend azureopenai



</details>

<details>
<summary>Cohere provider</summary>

<em>Prerequisites:</em> a Cohere API key is needed, please visit the [Cohere dashboard](https://dashboard.cohere.ai/api-keys) to create one.

To run k8sgpt, run `k8sgpt auth` with the `cohere` backend:

```
k8sgpt auth add --backend cohere --model command-nightly
```

Lastly, enter your Cohere API key, after the prompt.

Now you are ready to analyze with the Cohere backend:

```
k8sgpt analyze --explain --backend cohere
```

</details>

<details>
Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@ require (
buf.build/gen/go/k8sgpt-ai/k8sgpt/grpc/go v1.3.0-20230620082254-6f80f9533908.1
buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go v1.30.0-20230620082254-6f80f9533908.1
github.com/aws/aws-sdk-go v1.44.300
github.com/cohere-ai/cohere-go v0.2.0
)

require (
github.com/anchore/go-struct-converter v0.0.0-20221118182256-c68fdcfa2092 // indirect
github.com/cohere-ai/tokenizer v1.1.1 // indirect
github.com/dlclark/regexp2 v1.4.0 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
)

Expand Down
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,10 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20220314180256-7f1daf1720fc/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20230105202645-06c439db220b/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cohere-ai/cohere-go v0.2.0 h1:Gljkn8LTtsAPy79ks1AVmZH9Av4kuQuXEgzEJ/1Ea34=
github.com/cohere-ai/cohere-go v0.2.0/go.mod h1:DFcCu5rwro4wAlluIXY9l17NLGiVBGb2bRio46RXBm8=
github.com/cohere-ai/tokenizer v1.1.1 h1:wCtmCj07O82TMrIiA/CORhIlEYsvMMM8ey+sUdEapHc=
github.com/cohere-ai/tokenizer v1.1.1/go.mod h1:9MNFPd9j1fuiEK3ua2HSCUxxcrfGMlSqpa93livg/C0=
github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM=
github.com/containerd/containerd v1.7.0 h1:G/ZQr3gMZs6ZT0qPUZ15znx5QSdQdASW11nXTLTM2Pg=
github.com/containerd/containerd v1.7.0/go.mod h1:QfR7Efgb/6X2BDpTPJRvPTYDE9rsF0FsXX9J8sIs/sc=
Expand All @@ -518,6 +522,7 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
github.com/distribution/distribution/v3 v3.0.0-20221208165359-362910506bc2 h1:aBfCb7iqHmDEIp6fBvC/hQUddQfg+3qdYjwzaiP9Hnc=
github.com/dlclark/regexp2 v1.4.0 h1:F1rxgk7p4uKjwIQxBs9oAXe5CqrXlCduYEJvrF4u93E=
github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc=
github.com/docker/cli v23.0.5+incompatible h1:ufWmAOuD3Vmr7JP2G5K3cyuNC4YZWiAsuDEvFVVDafE=
github.com/docker/cli v23.0.5+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8=
Expand Down
116 changes: 116 additions & 0 deletions pkg/ai/cohere.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
Copyright 2023 The K8sGPT Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ai

import (
"context"
"encoding/base64"
"errors"
"fmt"
"strings"

"github.com/cohere-ai/cohere-go"
"github.com/fatih/color"

"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"
)

type CohereClient struct {
client *cohere.Client
language string
model string
}

func (c *CohereClient) Configure(config IAIConfig, language string) error {
token := config.GetPassword()

client, err := cohere.CreateClient(token)
if err != nil {
return err
}

baseURL := config.GetBaseURL()
if baseURL != "" {
client.BaseURL = baseURL
}

if client == nil {
return errors.New("error creating Cohere client")
}
c.language = language
c.client = client
c.model = config.GetModel()
return nil
}

func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl string) (string, error) {
// Create a completion request
if len(promptTmpl) == 0 {
promptTmpl = PromptMap["default"]
}
resp, err := c.client.Generate(cohere.GenerateOptions{
Model: c.model,
Prompt: fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
MaxTokens: cohere.Uint(2048),
Temperature: cohere.Float64(0.75),
K: cohere.Int(0),
StopSequences: []string{},
ReturnLikelihoods: "NONE",
})
if err != nil {
return "", err
}
return resp.Generations[0].Text, nil
}

func (a *CohereClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
inputKey := strings.Join(prompt, " ")
// Check for cached data
cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)

if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
response, err := cache.Load(cacheKey)
if err != nil {
return "", err
}

if response != "" {
output, err := base64.StdEncoding.DecodeString(response)
if err != nil {
color.Red("error decoding cached data: %v", err)
return "", nil
}
return string(output), nil
}
}

response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil {
return "", err
}

err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))

if err != nil {
color.Red("error storing value to cache: %v", err)
return "", nil
}

return response, nil
}

func (a *CohereClient) GetName() string {
return "cohere"
}
2 changes: 2 additions & 0 deletions pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ var (
&AzureAIClient{},
&LocalAIClient{},
&NoOpAIClient{},
&CohereClient{},
}
Backends = []string{
"openai",
"localai",
"azureopenai",
"noopai",
"cohere",
}
)

Expand Down