Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added Google GenAI client; simplified IAI/clients API surface. #829

Merged
merged 5 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/analyze/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ var AnalyzeCmd = &cobra.Command{
color.Red("Error: %v", err)
os.Exit(1)
}
defer config.Close()

config.RunAnalysis()

Expand Down
6 changes: 3 additions & 3 deletions cmd/auth/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,15 @@ func init() {
// add flag for url
addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for endpointName
addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, (e.g `endpoint-xxxxxxxxxxxx`)")
addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, e.g. `endpoint-xxxxxxxxxxxx` (only for amazonbedrock, amazonsagemaker backends)")
// add flag for topP
addCmd.Flags().Float32VarP(&topP, "topp", "c", 0.5, "Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability.")
// max tokens
addCmd.Flags().IntVarP(&maxTokens, "maxtokens", "l", 2048, "Specify a maximum output length. Adjust (1-...) to control text length. Higher values produce longer output, lower values limit length")
// add flag for temperature
addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// add flag for azure open ai engine/deployment name
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
//add flag for amazonbedrock region name
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name")
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock backend)")
}
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.2.1
github.com/aws/aws-sdk-go v1.49.15
github.com/cohere-ai/cohere-go v0.2.0
github.com/google/generative-ai-go v0.5.0
github.com/olekukonko/tablewriter v0.0.5
google.golang.org/api v0.155.0
sigs.k8s.io/controller-runtime v0.16.3
Expand All @@ -39,9 +40,11 @@ require (

require (
cloud.google.com/go v0.110.10 // indirect
cloud.google.com/go/ai v0.3.0 // indirect
cloud.google.com/go/compute v1.23.3 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
cloud.google.com/go/iam v1.1.5 // indirect
cloud.google.com/go/longrunning v0.5.4 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.1 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.0 // indirect
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ cloud.google.com/go/accesscontextmanager v1.3.0/go.mod h1:TgCBehyr5gNMz7ZaH9xubp
cloud.google.com/go/accesscontextmanager v1.4.0/go.mod h1:/Kjh7BBu/Gh83sv+K60vN9QE5NJcd80sU33vIe2IFPE=
cloud.google.com/go/accesscontextmanager v1.6.0/go.mod h1:8XCvZWfYw3K/ji0iVnp+6pu7huxoQTLmxAbVjbloTtM=
cloud.google.com/go/accesscontextmanager v1.7.0/go.mod h1:CEGLewx8dwa33aDAZQujl7Dx+uYhS0eay198wB/VumQ=
cloud.google.com/go/ai v0.3.0 h1:M617N0brv+XFch2KToZUhv6ggzgFZMUnmDkNQjW2pYg=
cloud.google.com/go/ai v0.3.0/go.mod h1:dTuQIBA8Kljuas5z1WNot1QZOl476A9TsFqEi6pzJlI=
cloud.google.com/go/aiplatform v1.22.0/go.mod h1:ig5Nct50bZlzV6NvKaTwmplLLddFx0YReh9WfTO5jKw=
cloud.google.com/go/aiplatform v1.24.0/go.mod h1:67UUvRBKG6GTayHKV8DBv2RtR1t93YRu5B1P3x99mYY=
cloud.google.com/go/aiplatform v1.27.0/go.mod h1:Bvxqtl40l0WImSb04d0hXFU7gDOiq9jQmorivIiWcKg=
Expand Down Expand Up @@ -351,6 +353,8 @@ cloud.google.com/go/logging v1.7.0/go.mod h1:3xjP2CjkM3ZkO73aj4ASA5wRPGGCRrPIAeN
cloud.google.com/go/longrunning v0.1.1/go.mod h1:UUFxuDWkv22EuY93jjmDMFT5GPQKeFVJBIF6QlTqdsE=
cloud.google.com/go/longrunning v0.3.0/go.mod h1:qth9Y41RRSUE69rDcOn6DdK3HfQfsUI0YSmW3iIlLJc=
cloud.google.com/go/longrunning v0.4.1/go.mod h1:4iWDqhBZ70CvZ6BfETbvam3T8FMvLK+eFj0E6AaRQTo=
cloud.google.com/go/longrunning v0.5.4 h1:w8xEcbZodnA2BbW6sVirkkoC+1gP8wS57EUUgGS0GVg=
cloud.google.com/go/longrunning v0.5.4/go.mod h1:zqNVncI0BOP8ST6XQD1+VcvuShMmq7+xFSzOL++V0dI=
cloud.google.com/go/managedidentities v1.3.0/go.mod h1:UzlW3cBOiPrzucO5qWkNkh0w33KFtBJU281hacNvsdE=
cloud.google.com/go/managedidentities v1.4.0/go.mod h1:NWSBYbEMgqmbZsLIyKvxrYbtqOsxY1ZrGM+9RgDqInM=
cloud.google.com/go/managedidentities v1.5.0/go.mod h1:+dWcZ0JlUmpuxpIDfyP5pP5y0bLdRwOS4Lp7gMni/LA=
Expand Down Expand Up @@ -935,6 +939,8 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/flatbuffers v2.0.8+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/generative-ai-go v0.5.0 h1:PfzPuSGdsmcSyPG7RIoijcKWZ7/x2kvgyNryvmXMUmA=
github.com/google/generative-ai-go v0.5.0/go.mod h1:8fXQk4w+eyTzFokGGJrBFL0/xwXqm3QNhTqOWyX11zs=
github.com/google/gnostic v0.7.0 h1:d7EpuFp8vVdML+y0JJJYiKeOLjKTdH/GvVkLOBWqJpw=
github.com/google/gnostic v0.7.0/go.mod h1:IAcUyMl6vtC95f60EZ8oXyqTsOersP6HbwjeG7EyDPM=
github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 h1:0VpGH+cDhbDtdcweoyCVsF3fhN8kejK6rFe/2FFX2nU=
Expand Down
55 changes: 5 additions & 50 deletions pkg/ai/amazonbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,8 @@ package ai

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"strings"

"github.com/fatih/color"

"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -19,8 +12,9 @@ import (

// AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service.
type AmazonBedRockClient struct {
nopCloser

client *bedrockruntime.BedrockRuntime
language string
model string
temperature float32
}
Expand Down Expand Up @@ -91,8 +85,8 @@ func GetRegionOrDefault(region string) string {
return BEDROCK_DEFAULT_REGION
}

// Configure configures the AmazonBedRockClient with the provided configuration and language.
func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error {
// Configure configures the AmazonBedRockClient with the provided configuration.
func (a *AmazonBedRockClient) Configure(config IAIConfig) error {

// Create a new AWS session
providerRegion := GetRegionOrDefault(config.GetProviderRegion())
Expand All @@ -107,15 +101,14 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error

// Create a new BedrockRuntime client
a.client = bedrockruntime.New(sess)
a.language = language
a.model = GetModelOrDefault(config.GetModel())
a.temperature = config.GetTemperature()

return nil
}

// GetCompletion sends a request to the model for generating completion based on the provided prompt.
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {

// Prepare the input data for the model invocation
request := map[string]interface{}{
Expand Down Expand Up @@ -152,44 +145,6 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string,
return output.Completion, nil
}

// Parse generates a completion for the provided prompt using the Amazon Bedrock model.
func (a *AmazonBedRockClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
inputKey := strings.Join(prompt, " ")
// Check for cached data
cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)

if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
response, err := cache.Load(cacheKey)
if err != nil {
return "", err
}

if response != "" {
output, err := base64.StdEncoding.DecodeString(response)
if err != nil {
color.Red("error decoding cached data: %v", err)
return "", nil
}
return string(output), nil
}
}

response, err := a.GetCompletion(ctx, inputKey, promptTmpl)

if err != nil {
return "", err
}

err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))

if err != nil {
color.Red("error storing value to cache: %v", err)
return "", nil
}

return response, nil
}

// GetName returns the name of the AmazonBedRockClient.
func (a *AmazonBedRockClient) GetName() string {
return "amazonbedrock"
Expand Down
49 changes: 7 additions & 42 deletions pkg/ai/amazonsagemaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,18 @@ package ai

import (
"context"
"encoding/base64"
"fmt"
"strings"

"encoding/json"

"github.com/fatih/color"
"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sagemakerruntime"
)

type SageMakerAIClient struct {
nopCloser

client *sagemakerruntime.SageMakerRuntime
language string
model string
temperature float32
endpoint string
Expand Down Expand Up @@ -63,15 +57,14 @@ type Parameters struct {
Temperature float64 `json:"temperature"`
}

func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
func (c *SageMakerAIClient) Configure(config IAIConfig) error {

// Create a new AWS session
sess := session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(config.GetProviderRegion())},
SharedConfigState: session.SharedConfigEnable,
}))

c.language = language
// Create a new SageMaker runtime client
c.client = sagemakerruntime.New(sess)
c.model = config.GetModel()
Expand All @@ -82,18 +75,13 @@ func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
return nil
}

func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
func (c *SageMakerAIClient) GetCompletion(_ context.Context, prompt string) (string, error) {
// Create a completion request

if len(promptTmpl) == 0 {
promptTmpl = PromptMap["default"]
}

request := Request{
Inputs: [][]Message{
{
{Role: "system", Content: "DEFAULT_PROMPT"},
{Role: "user", Content: fmt.Sprintf(promptTmpl, c.language, prompt)},
{Role: "user", Content: prompt},
},
},

Expand Down Expand Up @@ -142,29 +130,6 @@ func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, pr
return content, nil
}

func (a *SageMakerAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
// parse the text with the AI backend
inputKey := strings.Join(prompt, " ")
// Check for cached data
sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)

response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil {
color.Red("error getting completion: %v", err)
return "", err
}

err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))

if err != nil {
color.Red("error storing value to cache: %v", err)
return "", err
}

return response, nil
}

func (a *SageMakerAIClient) GetName() string {
func (c *SageMakerAIClient) GetName() string {
return "amazonsagemaker"
}
56 changes: 6 additions & 50 deletions pkg/ai/azureopenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,20 @@ package ai

import (
"context"
"encoding/base64"
"errors"
"fmt"
"strings"

"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"

"github.com/fatih/color"

"github.com/sashabaranov/go-openai"
)

type AzureAIClient struct {
nopCloser

client *openai.Client
language string
model string
temperature float32
}

func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
func (c *AzureAIClient) Configure(config IAIConfig) error {
token := config.GetPassword()
baseURL := config.GetBaseURL()
engine := config.GetEngine()
Expand All @@ -40,21 +33,20 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
if client == nil {
return errors.New("error creating Azure OpenAI client")
}
c.language = lang
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}

func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
// Create a completion request
resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: c.model,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: fmt.Sprintf(default_prompt, c.language, prompt),
Content: prompt,
},
},
Temperature: c.temperature,
Expand All @@ -65,42 +57,6 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, prompt
return resp.Choices[0].Message.Content, nil
}

func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
inputKey := strings.Join(prompt, " ")
// Check for cached data
cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)

if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
response, err := cache.Load(cacheKey)
if err != nil {
return "", err
}

if response != "" {
output, err := base64.StdEncoding.DecodeString(response)
if err != nil {
color.Red("error decoding cached data: %v", err)
return "", nil
}
return string(output), nil
}
}

response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil {
return "", err
}

err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))

if err != nil {
color.Red("error storing value to cache: %v", err)
return "", nil
}

return response, nil
}

func (a *AzureAIClient) GetName() string {
func (c *AzureAIClient) GetName() string {
return "azureopenai"
}
Loading