Skip to content

Commit

Permalink
feat: add watsonx ai provider (#1163)
Browse files Browse the repository at this point in the history
Signed-off-by: JINSONG WANG <jswang@ibm.com>
  • Loading branch information
jinsongo committed Jul 1, 2024
1 parent ab534d1 commit ce63821
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ Unused:
> huggingface
> noopai
> googlevertexai
> watsonxai
```

For detailed documentation on how to configure and use each provider see [here](https://docs.k8sgpt.ai/reference/providers/backend/).
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
github.com/IBM/watsonx-go v1.0.0 // indirect
github.com/Microsoft/hcsshim v0.12.4 // indirect
github.com/alecthomas/units v0.0.0-20231202071711-9a357b53e9c9 // indirect
github.com/anchore/go-struct-converter v0.0.0-20230627203149-c72ef8859ca9 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,8 @@ github.com/Code-Hex/go-generics-cache v1.3.1 h1:i8rLwyhoyhaerr7JpjtYjJZUcCbWOdiY
github.com/Code-Hex/go-generics-cache v1.3.1/go.mod h1:qxcC9kRVrct9rHeiYpFWSoW1vxyillCVzX13KZG8dl4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/IBM/watsonx-go v1.0.0 h1:xG7xA2W9N0RsiztR26dwBI8/VxIX4wTBhdYmEis2Yl8=
github.com/IBM/watsonx-go v1.0.0/go.mod h1:8lzvpe/158JkrzvcoIcIj6OdNty5iC9co5nQHfkhRtM=
github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk=
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
Expand Down
4 changes: 3 additions & 1 deletion pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var (
&HuggingfaceClient{},
&GoogleVertexAIClient{},
&OCIGenAIClient{},
&WatsonxAIClient{},
}
Backends = []string{
openAIClientName,
Expand All @@ -43,6 +44,7 @@ var (
huggingfaceAIClientName,
googleVertexAIClientName,
ociClientName,
watsonxAIClientName,
}
)

Expand Down Expand Up @@ -170,7 +172,7 @@ func (p *AIProvider) GetOrganizationId() string {
return p.OrganizationId
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}
var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}

func NeedPassword(backend string) bool {
for _, b := range passwordlessProviders {
Expand Down
84 changes: 84 additions & 0 deletions pkg/ai/watsonxai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package ai

import (
"os"
"fmt"
"context"
"errors"

wx "github.com/IBM/watsonx-go/pkg/models"
)

const watsonxAIClientName = "watsonxai"

type WatsonxAIClient struct {
nopCloser

client *wx.Client
model string
temperature float32
topP float32
topK int32
maxNewTokens int
}

const (
modelMetallama = "ibm/granite-13b-chat-v2"
)

func (c *WatsonxAIClient) Configure(config IAIConfig) error {
if(config.GetModel() == "") {
c.model = config.GetModel()
} else {
c.model = modelMetallama
}
c.temperature = config.GetTemperature()
c.topP = config.GetTopP()
c.topK = config.GetTopK()
c.maxNewTokens = config.GetMaxTokens()

// WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY"
// WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID"
apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName)

if apiKey == "" {
return errors.New("No watsonx API key provided")
}
if projectID == "" {
return errors.New("No watsonx project ID provided")
}

client, err := wx.NewClient(
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
)
if err != nil {
return fmt.Errorf("Failed to create client for testing. Error: %v", err)
}
c.client = client

return nil
}

func (c *WatsonxAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
result, err := c.client.GenerateText(
c.model,
prompt,
wx.WithTemperature((float64)(c.temperature)),
wx.WithTopP((float64)(c.topP)),
wx.WithTopK((uint)(c.topK)),
wx.WithMaxNewTokens((uint)(c.maxNewTokens)),
)
if err != nil {
return "", fmt.Errorf("Expected no error, but got an error: %v", err)
}
if result.Text == "" {
return "", errors.New("Expected a result, but got an empty string")
}

return result.Text, nil
}

func (c *WatsonxAIClient) GetName() string {
return watsonxAIClientName
}

0 comments on commit ce63821

Please sign in to comment.