Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: amazonsagemaker AI provider #731

Merged
merged 9 commits into from
Nov 5, 2023
62 changes: 60 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,6 @@ In addition to this you will need to set the follow local environmental variable
k8sgpt auth add --backend amazonbedrock --model anthropic.claude-v2
```

TODO: Currently access key will be requested in the CLI, you can enter anything into this.

#### Usage

```
Expand All @@ -398,6 +396,66 @@ k8sgpt analyze -e -b amazonbedrock

You're right, I don't have enough context to determine if a StatefulSet is correctly configured to use a non-existent service. A StatefulSet manages Pods with persistent storage, and the Pods are created from the same spec. The service name referenced in the StatefulSet configuration would need to match an existing Kubernetes service for the Pods to connect to. Without more details on the specific StatefulSet and environment, I can't confirm whether the configuration is valid or not.
```

</details>

<details>
<summary>Amazon SageMaker Provider</summary>

#### Prerequisites

1. **AWS CLI Configuration**: Make sure you have the AWS Command Line Interface (CLI) configured on your machine. If you haven't already configured the AWS CLI, you can follow the official AWS documentation for instructions on how to do it: [AWS CLI Configuration Guide](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html).

2. **SageMaker Instance**: You need to have an Amazon SageMaker instance set up. If you don't have one already, you can follow the step-by-step instructions provided in this repository for creating a SageMaker instance: [llm-sagemaker-jumpstart-cdk](https://github.com/zaremb/llm-sagemaker-jumpstart-cdk).

#### Backend Configuration

To add amazonsagemaker backend two parameters are required:

* `--endpointname` Amazon SageMaker endpoint name.
* `--providerRegion` AWS region where SageMaker instance is created. `k8sgpt` uses this region to connect to SageMaker (not the one defined with AWS CLI or environment variables )

To add amazonsagemaker as a backend run:

```bash
k8sgpt auth add --backend amazonsagemaker --providerRegion eu-west-1 --endpointname endpoint-xxxxxxxxxx
```

#### Optional params

Optionally, when adding the backend and later by changing the configuration file, you can set the following parameters:

`-l, --maxtokens int` Specify a maximum output length. Adjust (1-...) to control text length. Higher values produce longer output, lower values limit length (default 2048)

`-t, --temperature float32` The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random) (default 0.7)

`-c, --topp float32` Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability. (default 0.5)

To make amazonsagemaker as a default backend run:

```bash
k8sgpt auth default -p amazonsagemaker
```

#### AmazonSageMaker Usage

```bash
./k8sgpt analyze -e -b amazonsagemaker
100% |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| (1/1, 14 it/min)
AI Provider: amazonsagemaker

0 default/nginx(nginx)
- Error: Back-off pulling image "nginxx"
Error: Back-off pulling image "nginxx"

Solution:

1. Check if the image exists in the registry by running `docker image ls nginxx`.
2. If the image is not found, try pulling it by running `docker pull nginxx`.
3. If the image is still not available, check if there are any network issues by running `docker network inspect` and `docker network list`.
4. If the issue persists, try restarting the Docker daemon by running `sudo service docker restart`.
```

</details>

<details>
Expand Down
18 changes: 17 additions & 1 deletion cmd/auth/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ var addCmd = &cobra.Command{
_ = cmd.MarkFlagRequired("engine")
_ = cmd.MarkFlagRequired("baseurl")
}

if strings.ToLower(backend) == "amazonsagemaker" {
_ = cmd.MarkFlagRequired("endpointname")
_ = cmd.MarkFlagRequired("providerRegion")
}
},
Run: func(cmd *cobra.Command, args []string) {

Expand Down Expand Up @@ -90,6 +93,10 @@ var addCmd = &cobra.Command{
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}
if topP > 1.0 || topP < 0.0 {
color.Red("Error: topP ranges from 0 to 1.")
os.Exit(1)
}

if ai.NeedPassword(backend) && password == "" {
fmt.Printf("Enter %s Key: ", backend)
Expand All @@ -108,9 +115,12 @@ var addCmd = &cobra.Command{
Model: model,
Password: password,
BaseURL: baseURL,
EndpointName: endpointName,
Engine: engine,
Temperature: temperature,
ProviderRegion: providerRegion,
TopP: topP,
MaxTokens: maxTokens,
}

if providerIndex == -1 {
Expand Down Expand Up @@ -138,6 +148,12 @@ func init() {
addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
// add flag for url
addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for endpointName
addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, (e.g `endpoint-xxxxxxxxxxxx`)")
// add flag for topP
addCmd.Flags().Float32VarP(&topP, "topp", "c", 0.5, "Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability.")
// max tokens
addCmd.Flags().IntVarP(&maxTokens, "maxtokens", "l", 2048, "Specify a maximum output length. Adjust (1-...) to control text length. Higher values produce longer output, lower values limit length")
// add flag for temperature
addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// add flag for azure open ai engine/deployment name
Expand Down
3 changes: 3 additions & 0 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ var (
backend string
password string
baseURL string
endpointName string
model string
engine string
temperature float32
providerRegion string
topP float32
maxTokens int
)

var configAI ai.AIConfiguration
Expand Down
170 changes: 170 additions & 0 deletions pkg/ai/amazonsagemaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
Copyright 2023 The K8sGPT Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ai

import (
"context"
"encoding/base64"
"fmt"
"strings"

"encoding/json"

"github.com/fatih/color"
"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sagemakerruntime"
)

type SageMakerAIClient struct {
client *sagemakerruntime.SageMakerRuntime
language string
model string
temperature float32
endpoint string
topP float32
maxTokens int
}

type Generations []struct {
Generation struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"generation"`
}

type Request struct {
Inputs [][]Message `json:"inputs"`
Parameters Parameters `json:"parameters"`
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type Parameters struct {
MaxNewTokens int `json:"max_new_tokens"`
TopP float64 `json:"top_p"`
Temperature float64 `json:"temperature"`
}

func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {

// Create a new AWS session
sess := session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(config.GetProviderRegion())},
dkuroczk marked this conversation as resolved.
Show resolved Hide resolved
SharedConfigState: session.SharedConfigEnable,
}))

c.language = language
// Create a new SageMaker runtime client
c.client = sagemakerruntime.New(sess)
c.model = config.GetModel()
c.endpoint = config.GetEndpointName()
c.temperature = config.GetTemperature()
c.maxTokens = config.GetMaxTokens()
c.topP = config.GetTopP()
return nil
}

func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
// Create a completion request

if len(promptTmpl) == 0 {
promptTmpl = PromptMap["default"]
}

request := Request{
Inputs: [][]Message{
{
{Role: "system", Content: "DEFAULT_PROMPT"},
{Role: "user", Content: fmt.Sprintf(promptTmpl, c.language, prompt)},
},
},

Parameters: Parameters{
MaxNewTokens: int(c.maxTokens),
TopP: float64(c.topP),
Temperature: float64(c.temperature),
},
}

// Convert request to []byte
bytesData, err := json.Marshal(request)
if err != nil {
return "", err
}

// Create an input object
input := &sagemakerruntime.InvokeEndpointInput{
Body: bytesData,
EndpointName: aws.String(c.endpoint),
ContentType: aws.String("application/json"), // Set the content type as per your model's requirements
Accept: aws.String("application/json"), // Set the accept type as per your model's requirements
CustomAttributes: aws.String("accept_eula=true"),
}

// Call the InvokeEndpoint function
result, err := c.client.InvokeEndpoint(input)
if err != nil {
return "", err
}

// // Define a slice of Generations
var generations Generations

err = json.Unmarshal([]byte(string(result.Body)), &generations)
if err != nil {
return "", err
}
// Check for length of generations
if len(generations) != 1 {
return "", fmt.Errorf("Expected exactly one generation, but got %d", len(generations))
}

// Access the content
content := generations[0].Generation.Content
dkuroczk marked this conversation as resolved.
Show resolved Hide resolved
return content, nil
}

func (a *SageMakerAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
// parse the text with the AI backend
inputKey := strings.Join(prompt, " ")
// Check for cached data
sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)

response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil {
color.Red("error getting completion: %v", err)
return "", err
}

err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))

if err != nil {
color.Red("error storing value to cache: %v", err)
return "", err
}

return response, nil
}

func (a *SageMakerAIClient) GetName() string {
return "amazonsagemaker"
}
29 changes: 28 additions & 1 deletion pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var (
&NoOpAIClient{},
&CohereClient{},
&AmazonBedRockClient{},
&SageMakerAIClient{},
}
Backends = []string{
"openai",
Expand All @@ -35,6 +36,7 @@ var (
"noopai",
"cohere",
"amazonbedrock",
"amazonsagemaker",
}
)

Expand All @@ -49,9 +51,12 @@ type IAIConfig interface {
GetPassword() string
GetModel() string
GetBaseURL() string
GetEndpointName() string
GetEngine() string
GetTemperature() float32
GetProviderRegion() string
GetTopP() float32
GetMaxTokens() int
}

func NewClient(provider string) IAI {
Expand All @@ -74,15 +79,30 @@ type AIProvider struct {
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
}

func (p *AIProvider) GetBaseURL() string {
return p.BaseURL
}

func (p *AIProvider) GetEndpointName() string {
return p.EndpointName
}

func (p *AIProvider) GetTopP() float32 {
return p.TopP
}

func (p *AIProvider) GetMaxTokens() int {
return p.MaxTokens
}

func (p *AIProvider) GetPassword() string {
return p.Password
}
Expand All @@ -102,6 +122,13 @@ func (p *AIProvider) GetProviderRegion() string {
return p.ProviderRegion
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock"}

func NeedPassword(backend string) bool {
return backend != "localai"
for _, b := range passwordlessProviders {
if b == backend {
return false
}
}
return true
}