feat: added Google GenAI client; simplified IAI/clients API surface. (#829)

* refactor: Simplified IAI; made caching and processing consistent.


Signed-off-by: bwplotka <bwplotka@gmail.com>

* feat: Added Google AI API e.g. for Gemini models.

Signed-off-by: bwplotka <bwplotka@gmail.com>

---------

Signed-off-by: bwplotka <bwplotka@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
Co-authored-by: Thomas Schuetz <38893055+thschue@users.noreply.github.com>
3 people committed Jan 5, 2024
1 parent e78ff05 commit e7d4149
Showing 14 changed files with 240 additions and 309 deletions.
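Read together, the hunks below imply a deliberately narrowed client contract: Configure() loses its language parameter, GetCompletion() loses promptTmpl, the near-identical per-client Parse() methods are deleted, and clients gain a Close() method (backends with nothing to release embed the new nopCloser). A sketch of the resulting interface, inferred only from the signatures visible in this diff — the actual declaration lives in pkg/ai/iai.go, which is not shown on this page, so the names here are assumptions:

package ai

import "context"

// Sketch of the simplified IAI surface implied by this commit; inferred
// from the method signatures in the hunks below, not copied from the
// (unshown) pkg/ai/iai.go. IAIConfig is the existing config interface.
type IAI interface {
	// Configure no longer takes a language argument.
	Configure(config IAIConfig) error
	// GetCompletion no longer takes a prompt template; callers render
	// the final prompt before invoking the backend.
	GetCompletion(ctx context.Context, prompt string) (string, error)
	GetName() string
	// Close releases backend resources; cmd/analyze now defers it.
	Close()
}

// nopCloser, embedded by the clients below, satisfies Close for
// backends that hold nothing to release.
type nopCloser struct{}

func (nopCloser) Close() {}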
1 change: 1 addition & 0 deletions cmd/analyze/analyze.go
@@ -59,6 +59,7 @@ var AnalyzeCmd = &cobra.Command{
color.Red("Error: %v", err)
os.Exit(1)
}
defer config.Close()

config.RunAnalysis()

6 changes: 3 additions & 3 deletions cmd/auth/add.go
@@ -149,15 +149,15 @@ func init() {
// add flag for url
addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for endpointName
addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, (e.g `endpoint-xxxxxxxxxxxx`)")
addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, e.g. `endpoint-xxxxxxxxxxxx` (only for amazonbedrock, amazonsagemaker backends)")
// add flag for topP
addCmd.Flags().Float32VarP(&topP, "topp", "c", 0.5, "Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability.")
// max tokens
addCmd.Flags().IntVarP(&maxTokens, "maxtokens", "l", 2048, "Specify a maximum output length. Adjust (1-...) to control text length. Higher values produce longer output, lower values limit length")
// add flag for temperature
addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// add flag for azure open ai engine/deployment name
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
//add flag for amazonbedrock region name
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name")
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock backend)")
}
3 changes: 3 additions & 0 deletions go.mod
@@ -31,6 +31,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.2.1
github.com/aws/aws-sdk-go v1.49.15
github.com/cohere-ai/cohere-go v0.2.0
github.com/google/generative-ai-go v0.5.0
github.com/olekukonko/tablewriter v0.0.5
google.golang.org/api v0.155.0
sigs.k8s.io/controller-runtime v0.16.3
@@ -39,9 +40,11 @@

require (
cloud.google.com/go v0.110.10 // indirect
cloud.google.com/go/ai v0.3.0 // indirect
cloud.google.com/go/compute v1.23.3 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
cloud.google.com/go/iam v1.1.5 // indirect
cloud.google.com/go/longrunning v0.5.4 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.1 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.0 // indirect
6 changes: 6 additions & 0 deletions go.sum
@@ -49,6 +49,8 @@ cloud.google.com/go/accesscontextmanager v1.3.0/go.mod h1:TgCBehyr5gNMz7ZaH9xubp
cloud.google.com/go/accesscontextmanager v1.4.0/go.mod h1:/Kjh7BBu/Gh83sv+K60vN9QE5NJcd80sU33vIe2IFPE=
cloud.google.com/go/accesscontextmanager v1.6.0/go.mod h1:8XCvZWfYw3K/ji0iVnp+6pu7huxoQTLmxAbVjbloTtM=
cloud.google.com/go/accesscontextmanager v1.7.0/go.mod h1:CEGLewx8dwa33aDAZQujl7Dx+uYhS0eay198wB/VumQ=
cloud.google.com/go/ai v0.3.0 h1:M617N0brv+XFch2KToZUhv6ggzgFZMUnmDkNQjW2pYg=
cloud.google.com/go/ai v0.3.0/go.mod h1:dTuQIBA8Kljuas5z1WNot1QZOl476A9TsFqEi6pzJlI=
cloud.google.com/go/aiplatform v1.22.0/go.mod h1:ig5Nct50bZlzV6NvKaTwmplLLddFx0YReh9WfTO5jKw=
cloud.google.com/go/aiplatform v1.24.0/go.mod h1:67UUvRBKG6GTayHKV8DBv2RtR1t93YRu5B1P3x99mYY=
cloud.google.com/go/aiplatform v1.27.0/go.mod h1:Bvxqtl40l0WImSb04d0hXFU7gDOiq9jQmorivIiWcKg=
@@ -351,6 +353,8 @@ cloud.google.com/go/logging v1.7.0/go.mod h1:3xjP2CjkM3ZkO73aj4ASA5wRPGGCRrPIAeN
cloud.google.com/go/longrunning v0.1.1/go.mod h1:UUFxuDWkv22EuY93jjmDMFT5GPQKeFVJBIF6QlTqdsE=
cloud.google.com/go/longrunning v0.3.0/go.mod h1:qth9Y41RRSUE69rDcOn6DdK3HfQfsUI0YSmW3iIlLJc=
cloud.google.com/go/longrunning v0.4.1/go.mod h1:4iWDqhBZ70CvZ6BfETbvam3T8FMvLK+eFj0E6AaRQTo=
cloud.google.com/go/longrunning v0.5.4 h1:w8xEcbZodnA2BbW6sVirkkoC+1gP8wS57EUUgGS0GVg=
cloud.google.com/go/longrunning v0.5.4/go.mod h1:zqNVncI0BOP8ST6XQD1+VcvuShMmq7+xFSzOL++V0dI=
cloud.google.com/go/managedidentities v1.3.0/go.mod h1:UzlW3cBOiPrzucO5qWkNkh0w33KFtBJU281hacNvsdE=
cloud.google.com/go/managedidentities v1.4.0/go.mod h1:NWSBYbEMgqmbZsLIyKvxrYbtqOsxY1ZrGM+9RgDqInM=
cloud.google.com/go/managedidentities v1.5.0/go.mod h1:+dWcZ0JlUmpuxpIDfyP5pP5y0bLdRwOS4Lp7gMni/LA=
@@ -935,6 +939,8 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/flatbuffers v2.0.8+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/generative-ai-go v0.5.0 h1:PfzPuSGdsmcSyPG7RIoijcKWZ7/x2kvgyNryvmXMUmA=
github.com/google/generative-ai-go v0.5.0/go.mod h1:8fXQk4w+eyTzFokGGJrBFL0/xwXqm3QNhTqOWyX11zs=
github.com/google/gnostic v0.7.0 h1:d7EpuFp8vVdML+y0JJJYiKeOLjKTdH/GvVkLOBWqJpw=
github.com/google/gnostic v0.7.0/go.mod h1:IAcUyMl6vtC95f60EZ8oXyqTsOersP6HbwjeG7EyDPM=
github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 h1:0VpGH+cDhbDtdcweoyCVsF3fhN8kejK6rFe/2FFX2nU=
55 changes: 5 additions & 50 deletions pkg/ai/amazonbedrock.go
@@ -2,15 +2,8 @@ package ai

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"strings"

"github.com/fatih/color"

"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
@@ -19,8 +12,9 @@ import (

// AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service.
type AmazonBedRockClient struct {
nopCloser

client *bedrockruntime.BedrockRuntime
language string
model string
temperature float32
}
@@ -91,8 +85,8 @@ func GetRegionOrDefault(region string) string {
return BEDROCK_DEFAULT_REGION
}

// Configure configures the AmazonBedRockClient with the provided configuration and language.
func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error {
// Configure configures the AmazonBedRockClient with the provided configuration.
func (a *AmazonBedRockClient) Configure(config IAIConfig) error {

// Create a new AWS session
providerRegion := GetRegionOrDefault(config.GetProviderRegion())
@@ -107,15 +101,14 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error

// Create a new BedrockRuntime client
a.client = bedrockruntime.New(sess)
a.language = language
a.model = GetModelOrDefault(config.GetModel())
a.temperature = config.GetTemperature()

return nil
}

// GetCompletion sends a request to the model for generating completion based on the provided prompt.
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {

// Prepare the input data for the model invocation
request := map[string]interface{}{
@@ -152,44 +145,6 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string,
return output.Completion, nil
}

// Parse generates a completion for the provided prompt using the Amazon Bedrock model.
func (a *AmazonBedRockClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
inputKey := strings.Join(prompt, " ")
// Check for cached data
cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)

if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
response, err := cache.Load(cacheKey)
if err != nil {
return "", err
}

if response != "" {
output, err := base64.StdEncoding.DecodeString(response)
if err != nil {
color.Red("error decoding cached data: %v", err)
return "", nil
}
return string(output), nil
}
}

response, err := a.GetCompletion(ctx, inputKey, promptTmpl)

if err != nil {
return "", err
}

err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))

if err != nil {
color.Red("error storing value to cache: %v", err)
return "", nil
}

return response, nil
}

// GetName returns the name of the AmazonBedRockClient.
func (a *AmazonBedRockClient) GetName() string {
return "amazonbedrock"
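The Parse method deleted above — base64 cache lookup, completion call, base64 cache store — was duplicated almost verbatim in every client, which is what the commit message means by making caching and processing consistent. A plausible shape for the consolidated path, assembled only from calls visible in the deleted code (the helper's real name and location are not shown in this diff):

package ai

import (
	"context"
	"encoding/base64"
	"strings"

	"github.com/fatih/color"

	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
	"github.com/k8sgpt-ai/k8sgpt/pkg/util"
)

// parseWithCache is a hypothetical shared helper reproducing the steps
// each deleted Parse performed: check the cache, fall back to the
// backend, then store the base64-encoded response under the same key.
func parseWithCache(ctx context.Context, ai IAI, language string, prompt []string, c cache.ICache) (string, error) {
	inputKey := strings.Join(prompt, " ")
	cacheKey := util.GetCacheKey(ai.GetName(), language, inputKey)

	if !c.IsCacheDisabled() && c.Exists(cacheKey) {
		cached, err := c.Load(cacheKey)
		if err != nil {
			return "", err
		}
		if cached != "" {
			output, err := base64.StdEncoding.DecodeString(cached)
			if err != nil {
				color.Red("error decoding cached data: %v", err)
				return "", nil
			}
			return string(output), nil
		}
	}

	response, err := ai.GetCompletion(ctx, inputKey) // promptTmpl no longer threaded through
	if err != nil {
		return "", err
	}
	if err := c.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response))); err != nil {
		color.Red("error storing value to cache: %v", err)
		return "", nil
	}
	return response, nil
}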
49 changes: 7 additions & 42 deletions pkg/ai/amazonsagemaker.go
@@ -15,24 +15,18 @@ package ai

import (
"context"
"encoding/base64"
"fmt"
"strings"

"encoding/json"

"github.com/fatih/color"
"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sagemakerruntime"
)

type SageMakerAIClient struct {
nopCloser

client *sagemakerruntime.SageMakerRuntime
language string
model string
temperature float32
endpoint string
@@ -63,15 +57,14 @@ type Parameters struct {
Temperature float64 `json:"temperature"`
}

func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
func (c *SageMakerAIClient) Configure(config IAIConfig) error {

// Create a new AWS session
sess := session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(config.GetProviderRegion())},
SharedConfigState: session.SharedConfigEnable,
}))

c.language = language
// Create a new SageMaker runtime client
c.client = sagemakerruntime.New(sess)
c.model = config.GetModel()
@@ -82,18 +75,13 @@ func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
return nil
}

func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
func (c *SageMakerAIClient) GetCompletion(_ context.Context, prompt string) (string, error) {
// Create a completion request

if len(promptTmpl) == 0 {
promptTmpl = PromptMap["default"]
}

request := Request{
Inputs: [][]Message{
{
{Role: "system", Content: "DEFAULT_PROMPT"},
{Role: "user", Content: fmt.Sprintf(promptTmpl, c.language, prompt)},
{Role: "user", Content: prompt},
},
},

@@ -142,29 +130,6 @@ func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, pr
return content, nil
}

func (a *SageMakerAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
// parse the text with the AI backend
inputKey := strings.Join(prompt, " ")
// Check for cached data
sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)

response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil {
color.Red("error getting completion: %v", err)
return "", err
}

err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))

if err != nil {
color.Red("error storing value to cache: %v", err)
return "", err
}

return response, nil
}

func (a *SageMakerAIClient) GetName() string {
func (c *SageMakerAIClient) GetName() string {
return "amazonsagemaker"
}
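The GetCompletion change above shows where the prompt templating went: the deleted lines defaulted promptTmpl to PromptMap["default"] and rendered it with the client's stored language, while the new body forwards the caller's prompt untouched. Rendering now plausibly happens once at the call site, roughly like this — PromptMap and the template's (language, prompt) argument order come from the deleted code; the enclosing caller is assumed:

package ai

import (
	"context"
	"fmt"
	"strings"
)

// buildPrompt is a hypothetical caller-side helper: the template is
// rendered once, before the backend is invoked, so every client
// receives identical final text.
func buildPrompt(language string, failureText []string) string {
	promptTmpl := PromptMap["default"] // map referenced by the deleted SageMaker code
	return fmt.Sprintf(promptTmpl, language, strings.Join(failureText, " "))
}

func runCompletion(ctx context.Context, client IAI, language string, failureText []string) (string, error) {
	return client.GetCompletion(ctx, buildPrompt(language, failureText))
}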
56 changes: 6 additions & 50 deletions pkg/ai/azureopenai.go
@@ -2,27 +2,20 @@ package ai

import (
"context"
"encoding/base64"
"errors"
"fmt"
"strings"

"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"

"github.com/fatih/color"

"github.com/sashabaranov/go-openai"
)

type AzureAIClient struct {
nopCloser

client *openai.Client
language string
model string
temperature float32
}

func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
func (c *AzureAIClient) Configure(config IAIConfig) error {
token := config.GetPassword()
baseURL := config.GetBaseURL()
engine := config.GetEngine()
@@ -40,21 +33,20 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
if client == nil {
return errors.New("error creating Azure OpenAI client")
}
c.language = lang
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}

func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
// Create a completion request
resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: c.model,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: fmt.Sprintf(default_prompt, c.language, prompt),
Content: prompt,
},
},
Temperature: c.temperature,
@@ -65,42 +57,6 @@
return resp.Choices[0].Message.Content, nil
}

func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
inputKey := strings.Join(prompt, " ")
// Check for cached data
cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)

if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
response, err := cache.Load(cacheKey)
if err != nil {
return "", err
}

if response != "" {
output, err := base64.StdEncoding.DecodeString(response)
if err != nil {
color.Red("error decoding cached data: %v", err)
return "", nil
}
return string(output), nil
}
}

response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil {
return "", err
}

err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))

if err != nil {
color.Red("error storing value to cache: %v", err)
return "", nil
}

return response, nil
}

func (a *AzureAIClient) GetName() string {
func (c *AzureAIClient) GetName() string {
return "azureopenai"
}
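The file listing is truncated past this point, so the commit's headline addition — a Google GenAI client for Gemini models — does not appear above. From the new github.com/google/generative-ai-go v0.5.0 dependency in go.mod and the simplified interface, the client plausibly looks like the sketch below; every name in it is an assumption, since the real file sits outside the visible diff:

package ai

import (
	"context"
	"errors"
	"fmt"

	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/option"
)

// GoogleGenAIClient sketches the new Gemini-backed client; field and
// method names are guesses consistent with the interface above.
type GoogleGenAIClient struct {
	client      *genai.Client
	model       string
	temperature float32
}

func (c *GoogleGenAIClient) Configure(config IAIConfig) error {
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(config.GetPassword()))
	if err != nil {
		return fmt.Errorf("creating genai client: %w", err)
	}
	c.client = client
	c.model = config.GetModel()
	c.temperature = config.GetTemperature()
	return nil
}

func (c *GoogleGenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
	model := c.client.GenerativeModel(c.model)
	model.SetTemperature(c.temperature)
	resp, err := model.GenerateContent(ctx, genai.Text(prompt))
	if err != nil {
		return "", err
	}
	if len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil || len(resp.Candidates[0].Content.Parts) == 0 {
		return "", errors.New("no completion returned")
	}
	if text, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok {
		return string(text), nil
	}
	return "", errors.New("unexpected response part type")
}

func (c *GoogleGenAIClient) GetName() string {
	return "google" // assumed registered backend name
}

// A real gRPC connection is held here, which is why the interface grew
// Close() and why cmd/analyze now defers config.Close().
func (c *GoogleGenAIClient) Close() {
	_ = c.client.Close()
}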