Skip to content

Commit

Permalink
feat(cohere): update cohere client implementation to v2 and to use ch…
Browse files Browse the repository at this point in the history
…at endpoint
  • Loading branch information
miguelvr committed Apr 11, 2024
1 parent 6df0169 commit 96b05eb
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 22 deletions.
4 changes: 1 addition & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.1
github.com/aws/aws-sdk-go v1.51.14
github.com/cohere-ai/cohere-go v0.2.0
github.com/cohere-ai/cohere-go/v2 v2.7.1
github.com/google/generative-ai-go v0.10.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1
github.com/hupe1980/go-huggingface v0.0.15
Expand Down Expand Up @@ -62,11 +62,9 @@ require (
github.com/Microsoft/hcsshim v0.11.4 // indirect
github.com/alecthomas/units v0.0.0-20231202071711-9a357b53e9c9 // indirect
github.com/anchore/go-struct-converter v0.0.0-20230627203149-c72ef8859ca9 // indirect
github.com/cohere-ai/tokenizer v1.1.1 // indirect
github.com/containerd/console v1.0.3 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/distribution/reference v0.5.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/evanphx/json-patch/v5 v5.7.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-kit/log v0.2.1 // indirect
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1378,10 +1378,8 @@ github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWH
github.com/cncf/xds/go v0.0.0-20231109132714-523115ebc101/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa h1:jQCWAUqqlij9Pgj2i/PB79y4KOPYVyFYdROxgaCwdTQ=
github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa/go.mod h1:x/1Gn8zydmfq8dk6e9PdstVsDgu9RuyIIJqAaF//0IM=
github.com/cohere-ai/cohere-go v0.2.0 h1:Gljkn8LTtsAPy79ks1AVmZH9Av4kuQuXEgzEJ/1Ea34=
github.com/cohere-ai/cohere-go v0.2.0/go.mod h1:DFcCu5rwro4wAlluIXY9l17NLGiVBGb2bRio46RXBm8=
github.com/cohere-ai/tokenizer v1.1.1 h1:wCtmCj07O82TMrIiA/CORhIlEYsvMMM8ey+sUdEapHc=
github.com/cohere-ai/tokenizer v1.1.1/go.mod h1:9MNFPd9j1fuiEK3ua2HSCUxxcrfGMlSqpa93livg/C0=
github.com/cohere-ai/cohere-go/v2 v2.7.1 h1:w2osOaSXZLGmmLuIAIR3hepaJENZSLrpU9VSh8rPJ2s=
github.com/cohere-ai/cohere-go/v2 v2.7.1/go.mod h1:dlDCT66i8BqZDuuskFvYzsrc+O0M4l5J9Ibckoflvt4=
github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM=
github.com/containerd/cgroups v1.1.0/go.mod h1:6ppBcbh/NOOUU+dMKrykgaBnK9lCIBxHqJDGwsa1mIw=
github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARubLw=
Expand Down
37 changes: 22 additions & 15 deletions pkg/ai/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ import (
"context"
"errors"

"github.com/cohere-ai/cohere-go"
api "github.com/cohere-ai/cohere-go/v2"
cohere "github.com/cohere-ai/cohere-go/v2/client"
"github.com/cohere-ai/cohere-go/v2/option"
)

const cohereAIClientName = "cohere"
Expand All @@ -28,45 +30,50 @@ type CohereClient struct {
client *cohere.Client
model string
temperature float32
topP float32
maxTokens int
}

func (c *CohereClient) Configure(config IAIConfig) error {
token := config.GetPassword()

client, err := cohere.CreateClient(token)
if err != nil {
return err
opts := []option.RequestOption{
cohere.WithToken(token),
}

baseURL := config.GetBaseURL()
if baseURL != "" {
client.BaseURL = baseURL
opts = append(opts, cohere.WithBaseURL(baseURL))
}

client := cohere.NewClient(opts...)
if client == nil {
return errors.New("error creating Cohere client")
}

c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
c.maxTokens = config.GetMaxTokens()

return nil
}

func (c *CohereClient) GetCompletion(_ context.Context, prompt string) (string, error) {
func (c *CohereClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
// Create a completion request
resp, err := c.client.Generate(cohere.GenerateOptions{
Model: c.model,
Prompt: prompt,
MaxTokens: cohere.Uint(2048),
Temperature: cohere.Float64(float64(c.temperature)),
K: cohere.Int(0),
StopSequences: []string{},
ReturnLikelihoods: "NONE",
response, err := c.client.Chat(ctx, &api.ChatRequest{
Message: prompt,
Model: &c.model,
K: api.Int(0),
Preamble: api.String(""),
Temperature: api.Float64(float64(c.temperature)),
RawPrompting: api.Bool(false),
MaxTokens: api.Int(c.maxTokens),
})
if err != nil {
return "", err
}
return resp.Generations[0].Text, nil
return response.Text, nil
}

func (c *CohereClient) GetName() string {
Expand Down

0 comments on commit 96b05eb

Please sign in to comment.