Skip to content

Commit

Permalink
feat: amazonsagemaker AI provider (#731)
Browse files Browse the repository at this point in the history
* feat(amazonsagemaker): Add AmazonSageMaker AI provider

Co-authored-by: NAME 18630245+zaremb@users.noreply.github.com
Signed-off-by: Damian Kuroczko <7778327+dkuroczk@users.noreply.github.com>

* feat(amazonsagemaker): Add AmazonSageMaker AI provider

Co-authored-by: Mateusz Zaremba <18630245+zaremb@users.noreply.github.com>
Signed-off-by: Damian Kuroczko <7778327+dkuroczk@users.noreply.github.com>

* feat(auth): add top p and max tokens to auth and use them in sagemaker backend

Signed-off-by: Mateusz Zaremba <18630245+zaremb@users.noreply.github.com>

* feat: Updates SageMaker docs, validate topP, ident

Signed-off-by: Damian Kuroczko <7778327+dkuroczk@users.noreply.github.com>

* feat: list of passwordlessProviders

Signed-off-by: Damian Kuroczko <7778327+dkuroczk@users.noreply.github.com>

* feat: returns err

Signed-off-by: Damian Kuroczko <7778327+dkuroczk@users.noreply.github.com>

* fix: remove log.Fatal(err)

Signed-off-by: Damian Kuroczko <7778327+dkuroczk@users.noreply.github.com>

---------

Signed-off-by: Damian Kuroczko <7778327+dkuroczk@users.noreply.github.com>
Signed-off-by: Mateusz Zaremba <18630245+zaremb@users.noreply.github.com>
Co-authored-by: Mateusz Zaremba <18630245+zaremb@users.noreply.github.com>
  • Loading branch information
dkuroczk and zaremb authored Nov 5, 2023
1 parent 188a8a2 commit ccef7f6
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 4 deletions.
62 changes: 60 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,6 @@ In addition to this you will need to set the follow local environmental variable
k8sgpt auth add --backend amazonbedrock --model anthropic.claude-v2
```

TODO: Currently access key will be requested in the CLI, you can enter anything into this.

#### Usage

```
Expand All @@ -398,6 +396,66 @@ k8sgpt analyze -e -b amazonbedrock
You're right, I don't have enough context to determine if a StatefulSet is correctly configured to use a non-existent service. A StatefulSet manages Pods with persistent storage, and the Pods are created from the same spec. The service name referenced in the StatefulSet configuration would need to match an existing Kubernetes service for the Pods to connect to. Without more details on the specific StatefulSet and environment, I can't confirm whether the configuration is valid or not.
```

</details>

<details>
<summary>Amazon SageMaker Provider</summary>

#### Prerequisites

1. **AWS CLI Configuration**: Make sure you have the AWS Command Line Interface (CLI) configured on your machine. If you haven't already configured the AWS CLI, you can follow the official AWS documentation for instructions on how to do it: [AWS CLI Configuration Guide](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html).

2. **SageMaker Instance**: You need to have an Amazon SageMaker instance set up. If you don't have one already, you can follow the step-by-step instructions provided in this repository for creating a SageMaker instance: [llm-sagemaker-jumpstart-cdk](https://github.com/zaremb/llm-sagemaker-jumpstart-cdk).

#### Backend Configuration

To add amazonsagemaker backend two parameters are required:

* `--endpointname` Amazon SageMaker endpoint name.
* `--providerRegion` AWS region where SageMaker instance is created. `k8sgpt` uses this region to connect to SageMaker (not the one defined with AWS CLI or environment variables )

To add amazonsagemaker as a backend run:

```bash
k8sgpt auth add --backend amazonsagemaker --providerRegion eu-west-1 --endpointname endpoint-xxxxxxxxxx
```

#### Optional params

Optionally, when adding the backend and later by changing the configuration file, you can set the following parameters:

`-l, --maxtokens int` Specify a maximum output length. Adjust (1-...) to control text length. Higher values produce longer output, lower values limit length (default 2048)

`-t, --temperature float32` The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random) (default 0.7)

`-c, --topp float32` Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability. (default 0.5)

To make amazonsagemaker as a default backend run:

```bash
k8sgpt auth default -p amazonsagemaker
```

#### AmazonSageMaker Usage

```bash
./k8sgpt analyze -e -b amazonsagemaker
100% |███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| (1/1, 14 it/min)
AI Provider: amazonsagemaker

0 default/nginx(nginx)
- Error: Back-off pulling image "nginxx"
Error: Back-off pulling image "nginxx"

Solution:

1. Check if the image exists in the registry by running `docker image ls nginxx`.
2. If the image is not found, try pulling it by running `docker pull nginxx`.
3. If the image is still not available, check if there are any network issues by running `docker network inspect` and `docker network list`.
4. If the issue persists, try restarting the Docker daemon by running `sudo service docker restart`.
```
</details>
<details>
Expand Down
18 changes: 17 additions & 1 deletion cmd/auth/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ var addCmd = &cobra.Command{
_ = cmd.MarkFlagRequired("engine")
_ = cmd.MarkFlagRequired("baseurl")
}

if strings.ToLower(backend) == "amazonsagemaker" {
_ = cmd.MarkFlagRequired("endpointname")
_ = cmd.MarkFlagRequired("providerRegion")
}
},
Run: func(cmd *cobra.Command, args []string) {

Expand Down Expand Up @@ -90,6 +93,10 @@ var addCmd = &cobra.Command{
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}
if topP > 1.0 || topP < 0.0 {
color.Red("Error: topP ranges from 0 to 1.")
os.Exit(1)
}

if ai.NeedPassword(backend) && password == "" {
fmt.Printf("Enter %s Key: ", backend)
Expand All @@ -108,9 +115,12 @@ var addCmd = &cobra.Command{
Model: model,
Password: password,
BaseURL: baseURL,
EndpointName: endpointName,
Engine: engine,
Temperature: temperature,
ProviderRegion: providerRegion,
TopP: topP,
MaxTokens: maxTokens,
}

if providerIndex == -1 {
Expand Down Expand Up @@ -138,6 +148,12 @@ func init() {
addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
// add flag for url
addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for endpointName
addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, (e.g `endpoint-xxxxxxxxxxxx`)")
// add flag for topP
addCmd.Flags().Float32VarP(&topP, "topp", "c", 0.5, "Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability.")
// max tokens
addCmd.Flags().IntVarP(&maxTokens, "maxtokens", "l", 2048, "Specify a maximum output length. Adjust (1-...) to control text length. Higher values produce longer output, lower values limit length")
// add flag for temperature
addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// add flag for azure open ai engine/deployment name
Expand Down
3 changes: 3 additions & 0 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ var (
backend string
password string
baseURL string
endpointName string
model string
engine string
temperature float32
providerRegion string
topP float32
maxTokens int
)

var configAI ai.AIConfiguration
Expand Down
170 changes: 170 additions & 0 deletions pkg/ai/amazonsagemaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
Copyright 2023 The K8sGPT Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ai

import (
"context"
"encoding/base64"
"fmt"
"strings"

"encoding/json"

"github.com/fatih/color"
"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sagemakerruntime"
)

type SageMakerAIClient struct {
client *sagemakerruntime.SageMakerRuntime
language string
model string
temperature float32
endpoint string
topP float32
maxTokens int
}

type Generations []struct {
Generation struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"generation"`
}

type Request struct {
Inputs [][]Message `json:"inputs"`
Parameters Parameters `json:"parameters"`
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type Parameters struct {
MaxNewTokens int `json:"max_new_tokens"`
TopP float64 `json:"top_p"`
Temperature float64 `json:"temperature"`
}

func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {

// Create a new AWS session
sess := session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(config.GetProviderRegion())},
SharedConfigState: session.SharedConfigEnable,
}))

c.language = language
// Create a new SageMaker runtime client
c.client = sagemakerruntime.New(sess)
c.model = config.GetModel()
c.endpoint = config.GetEndpointName()
c.temperature = config.GetTemperature()
c.maxTokens = config.GetMaxTokens()
c.topP = config.GetTopP()
return nil
}

func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
// Create a completion request

if len(promptTmpl) == 0 {
promptTmpl = PromptMap["default"]
}

request := Request{
Inputs: [][]Message{
{
{Role: "system", Content: "DEFAULT_PROMPT"},
{Role: "user", Content: fmt.Sprintf(promptTmpl, c.language, prompt)},
},
},

Parameters: Parameters{
MaxNewTokens: int(c.maxTokens),
TopP: float64(c.topP),
Temperature: float64(c.temperature),
},
}

// Convert request to []byte
bytesData, err := json.Marshal(request)
if err != nil {
return "", err
}

// Create an input object
input := &sagemakerruntime.InvokeEndpointInput{
Body: bytesData,
EndpointName: aws.String(c.endpoint),
ContentType: aws.String("application/json"), // Set the content type as per your model's requirements
Accept: aws.String("application/json"), // Set the accept type as per your model's requirements
CustomAttributes: aws.String("accept_eula=true"),
}

// Call the InvokeEndpoint function
result, err := c.client.InvokeEndpoint(input)
if err != nil {
return "", err
}

// // Define a slice of Generations
var generations Generations

err = json.Unmarshal([]byte(string(result.Body)), &generations)
if err != nil {
return "", err
}
// Check for length of generations
if len(generations) != 1 {
return "", fmt.Errorf("Expected exactly one generation, but got %d", len(generations))
}

// Access the content
content := generations[0].Generation.Content
return content, nil
}

func (a *SageMakerAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
// parse the text with the AI backend
inputKey := strings.Join(prompt, " ")
// Check for cached data
sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)

response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil {
color.Red("error getting completion: %v", err)
return "", err
}

err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))

if err != nil {
color.Red("error storing value to cache: %v", err)
return "", err
}

return response, nil
}

func (a *SageMakerAIClient) GetName() string {
return "amazonsagemaker"
}
29 changes: 28 additions & 1 deletion pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var (
&NoOpAIClient{},
&CohereClient{},
&AmazonBedRockClient{},
&SageMakerAIClient{},
}
Backends = []string{
"openai",
Expand All @@ -35,6 +36,7 @@ var (
"noopai",
"cohere",
"amazonbedrock",
"amazonsagemaker",
}
)

Expand All @@ -49,9 +51,12 @@ type IAIConfig interface {
GetPassword() string
GetModel() string
GetBaseURL() string
GetEndpointName() string
GetEngine() string
GetTemperature() float32
GetProviderRegion() string
GetTopP() float32
GetMaxTokens() int
}

func NewClient(provider string) IAI {
Expand All @@ -74,15 +79,30 @@ type AIProvider struct {
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
}

func (p *AIProvider) GetBaseURL() string {
return p.BaseURL
}

func (p *AIProvider) GetEndpointName() string {
return p.EndpointName
}

func (p *AIProvider) GetTopP() float32 {
return p.TopP
}

func (p *AIProvider) GetMaxTokens() int {
return p.MaxTokens
}

func (p *AIProvider) GetPassword() string {
return p.Password
}
Expand All @@ -102,6 +122,13 @@ func (p *AIProvider) GetProviderRegion() string {
return p.ProviderRegion
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock"}

func NeedPassword(backend string) bool {
return backend != "localai"
for _, b := range passwordlessProviders {
if b == backend {
return false
}
}
return true
}

0 comments on commit ccef7f6

Please sign in to comment.