
feat: amazonsagemaker AI provider #731

Merged (9 commits) on Nov 5, 2023
52 changes: 52 additions & 0 deletions README.md
@@ -398,6 +398,58 @@ k8sgpt analyze -e -b amazonbedrock

You're right, I don't have enough context to determine if a StatefulSet is correctly configured to use a non-existent service. A StatefulSet manages Pods with persistent storage, and the Pods are created from the same spec. The service name referenced in the StatefulSet configuration would need to match an existing Kubernetes service for the Pods to connect to. Without more details on the specific StatefulSet and environment, I can't confirm whether the configuration is valid or not.
```

</details>

<details>
<summary>Amazon SageMaker Provider</summary>

<em>Prerequisites</em>

1. **AWS CLI Configuration**: Make sure you have the AWS Command Line Interface (CLI) configured on your machine. If you haven't already configured the AWS CLI, you can follow the official AWS documentation for instructions on how to do it: [AWS CLI Configuration Guide](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html).
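If you prefer not to run `aws configure`, the AWS SDK (which k8sgpt uses with `SharedConfigEnable`) also reads credentials and the region from environment variables. A sketch with placeholder values, not real credentials:

```shell
# Placeholder credentials for illustration only; substitute your own.
export AWS_ACCESS_KEY_ID="AKIAEXAMPLE"
export AWS_SECRET_ACCESS_KEY="examplesecretkey"
# Region used when no --providerRegion flag or shared config applies.
export AWS_REGION="eu-west-1"
```

Shared config files under `~/.aws/` take part in credential resolution as well, so either mechanism works.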

2. **SageMaker Instance**: You need to have an Amazon SageMaker instance set up. If you don't have one already, you can follow the step-by-step instructions provided in this repository for creating a SageMaker instance: [llm-sagemaker-jumpstart-cdk](https://github.com/zaremb/llm-sagemaker-jumpstart-cdk).

3. **Backend Configuration**:
To add the amazonsagemaker backend, two parameters are required:

* **Endpoint name**: the name of the SageMaker endpoint to invoke.
* **Provider region**: the AWS region where your SageMaker endpoint was created.

To add amazonsagemaker as a backend, run:

```bash
k8sgpt auth add --backend amazonsagemaker --providerRegion eu-west-1 --endpointname endpoint-xxxxxxxxxx
```

**Note** (TODO): the CLI currently prompts for an access key; you can enter any value here.

To set amazonsagemaker as the default backend, run:

```bash
k8sgpt auth default -p amazonsagemaker
```

#### Amazon SageMaker Usage

```bash
./k8sgpt analyze -e -b amazonsagemaker
100% |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| (1/1, 14 it/min)
AI Provider: amazonsagemaker

0 default/nginx(nginx)
- Error: Back-off pulling image "nginxx"
Error: Back-off pulling image "nginxx"

Solution:

1. Check if the image exists in the registry by running `docker image ls nginxx`.
2. If the image is not found, try pulling it by running `docker pull nginxx`.
3. If the image is still not available, check if there are any network issues by running `docker network inspect` and `docker network list`.
4. If the issue persists, try restarting the Docker daemon by running `sudo service docker restart`.
```

</details>

<details>
8 changes: 7 additions & 1 deletion cmd/auth/add.go
@@ -41,7 +41,10 @@ var addCmd = &cobra.Command{
_ = cmd.MarkFlagRequired("engine")
_ = cmd.MarkFlagRequired("baseurl")
}

if strings.ToLower(backend) == "amazonsagemaker" {
_ = cmd.MarkFlagRequired("endpointname")
_ = cmd.MarkFlagRequired("providerRegion")
}
},
Run: func(cmd *cobra.Command, args []string) {

@@ -108,6 +111,7 @@ var addCmd = &cobra.Command{
Model: model,
Password: password,
BaseURL: baseURL,
EndpointName: endpointName,
Engine: engine,
Temperature: temperature,
ProviderRegion: providerRegion,
@@ -138,6 +142,8 @@ func init() {
addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
// add flag for url
addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for endpointName
addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint name (e.g. `endpoint-xxxxxxxxxxxx`)")
// add flag for temperature
addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, ranging from 0 (more deterministic output) to 1 (more random)")
// add flag for azure open ai engine/deployment name
1 change: 1 addition & 0 deletions cmd/auth/auth.go
@@ -22,6 +22,7 @@ var (
backend string
password string
baseURL string
endpointName string
model string
engine string
temperature float32
161 changes: 161 additions & 0 deletions pkg/ai/amazonsagemaker.go
@@ -0,0 +1,161 @@
/*
Copyright 2023 The K8sGPT Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ai

import (
"context"
"encoding/base64"
"fmt"
"strings"

"encoding/json"
"log"

"github.com/fatih/color"
"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sagemakerruntime"
)

type SageMakerAIClient struct {
client *sagemakerruntime.SageMakerRuntime
language string
model string
temperature float32
endpoint string
}

const (
// SageMaker completion parameters
maxNewTokens = 256
top_P = 0.9
)

type Generation struct {
Role string `json:"role"`
Content string `json:"content"`
}

func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {

// Create a new AWS session
sess := session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(config.GetProviderRegion())},
SharedConfigState: session.SharedConfigEnable,
}))

c.language = language
// Create a new SageMaker runtime client
c.client = sagemakerruntime.New(sess)
c.model = config.GetModel()
c.endpoint = config.GetEndpointName()
c.temperature = config.GetTemperature()
return nil
}

func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
// Create a completion request

if len(promptTmpl) == 0 {
promptTmpl = PromptMap["default"]
}
// TODO: extract all parameters to config
data := map[string]interface{}{
"inputs": []interface{}{
[]interface{}{
map[string]interface{}{
"role": "system",
"content": "DEFAULT_PROMPT",
},
map[string]interface{}{
"role": "user",
"content": fmt.Sprintf(promptTmpl, c.language, prompt),
},
},
},
"parameters": map[string]interface{}{
"max_new_tokens": maxNewTokens,
"top_p": top_P,
"temperature": c.temperature,
},
}
// Convert data to []byte
bytesData, err := json.Marshal(data)
if err != nil {
log.Printf("error marshalling completion request: %v", err)
return "", err
}

// Create an input object
input := &sagemakerruntime.InvokeEndpointInput{
Body: bytesData,
EndpointName: aws.String(c.endpoint),
ContentType: aws.String("application/json"), // Set the content type as per your model's requirements
Accept: aws.String("application/json"), // Set the accept type as per your model's requirements
CustomAttributes: aws.String("accept_eula=true"),
}

// Call the InvokeEndpoint function
result, err := c.client.InvokeEndpoint(input)
if err != nil {
log.Printf("error invoking SageMaker endpoint: %v", err)
return "", err
}

// Define a slice of Generations
var generations []struct {
Generation Generation `json:"generation"`
}

err = json.Unmarshal(result.Body, &generations)
if err != nil {
log.Printf("error unmarshalling SageMaker response: %v", err)
return "", err
}

// Access the content of the first generation, guarding against an empty response
if len(generations) == 0 {
return "", fmt.Errorf("no generations returned by the endpoint")
}
content := generations[0].Generation.Content
return content, nil
}
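The parsing above expects the endpoint to return a JSON array of objects that each wrap a `generation`, the shape produced by the Llama 2 chat models deployed through SageMaker JumpStart. A self-contained sketch of that parsing, using a hypothetical response body:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Generation mirrors the struct in amazonsagemaker.go.
type Generation struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// firstGeneration extracts the first generation's content from a
// response body, following the same steps as GetCompletion.
func firstGeneration(body []byte) (string, error) {
	var generations []struct {
		Generation Generation `json:"generation"`
	}
	if err := json.Unmarshal(body, &generations); err != nil {
		return "", err
	}
	if len(generations) == 0 {
		return "", fmt.Errorf("no generations in response")
	}
	return generations[0].Generation.Content, nil
}

func main() {
	// A hypothetical body; a real endpoint's output may differ by model.
	body := []byte(`[{"generation":{"role":"assistant","content":"Check the image name."}}]`)
	content, err := firstGeneration(body)
	if err != nil {
		panic(err)
	}
	fmt.Println(content) // prints "Check the image name."
}
```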

func (a *SageMakerAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
// parse the text with the AI backend
inputKey := strings.Join(prompt, " ")
// Build the cache key under which the response will be stored
sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)

response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil {
color.Red("error getting completion: %v", err)
return "", err
}

err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
if err != nil {
// a failed cache write should not discard the completion
color.Red("error storing value to cache: %v", err)
}

return response, nil
}
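Parse derives the value it feeds to `util.GetCacheKey` by joining the analyzer prompts and base64-encoding the result, so identical inputs map to a stable key. A minimal sketch of that encoding step (`util.GetCacheKey` itself is part of k8sgpt and not reproduced here):

```go
package main

import (
	"encoding/base64"
	"fmt"
	"strings"
)

// encodePrompt joins the prompt fragments and base64-encodes them,
// mirroring the sEnc computation in Parse.
func encodePrompt(prompt []string) string {
	inputKey := strings.Join(prompt, " ")
	return base64.StdEncoding.EncodeToString([]byte(inputKey))
}

func main() {
	enc := encodePrompt([]string{"Error:", `Back-off pulling image "nginxx"`})
	fmt.Println(enc)
}
```

Because the encoding is deterministic, a repeated analysis of the same failure reuses the cached completion rather than re-invoking the endpoint.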

func (a *SageMakerAIClient) GetName() string {
return "amazonsagemaker"
}
8 changes: 8 additions & 0 deletions pkg/ai/iai.go
@@ -27,6 +27,7 @@ var (
&NoOpAIClient{},
&CohereClient{},
&AmazonBedRockClient{},
&SageMakerAIClient{},
}
Backends = []string{
"openai",
@@ -35,6 +36,7 @@ var (
"noopai",
"cohere",
"amazonbedrock",
"amazonsagemaker",
}
)

@@ -49,6 +51,7 @@ type IAIConfig interface {
GetPassword() string
GetModel() string
GetBaseURL() string
GetEndpointName() string
GetEngine() string
GetTemperature() float32
GetProviderRegion() string
@@ -74,6 +77,7 @@ type AIProvider struct {
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
@@ -83,6 +87,10 @@ func (p *AIProvider) GetBaseURL() string {
return p.BaseURL
}

func (p *AIProvider) GetEndpointName() string {
return p.EndpointName
}

func (p *AIProvider) GetPassword() string {
return p.Password
}