From 2fd476e12624e30570c0819594f2668f720381d6 Mon Sep 17 00:00:00 2001 From: JuHyung Son Date: Mon, 29 Jan 2024 00:41:49 +0900 Subject: [PATCH] feat: add huggingface provider (#893) * feat: add huggingface ai provider Signed-off-by: JuHyung-Son * chore: update readme Signed-off-by: JuHyung-Son * fix: set huggingface maxtokens default to 500, use ptr instead of pointer Signed-off-by: JuHyung-Son --------- Signed-off-by: JuHyung-Son --- README.md | 1 + go.mod | 1 + go.sum | 2 ++ pkg/ai/huggingface.go | 59 +++++++++++++++++++++++++++++++++++++++++++ pkg/ai/iai.go | 2 ++ 5 files changed, 65 insertions(+) create mode 100644 pkg/ai/huggingface.go diff --git a/README.md b/README.md index dcac0990a7..58052828b2 100644 --- a/README.md +++ b/README.md @@ -309,6 +309,7 @@ Unused: > amazonbedrock > amazonsagemaker > google +> huggingface > noopai ``` diff --git a/go.mod b/go.mod index 68683ab967..d6eb853d16 100644 --- a/go.mod +++ b/go.mod @@ -73,6 +73,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.0 // indirect github.com/gookit/color v1.5.4 // indirect + github.com/hupe1980/go-huggingface v0.0.15 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jpillora/backoff v1.0.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect diff --git a/go.sum b/go.sum index 54127c7b0c..8650462b93 100644 --- a/go.sum +++ b/go.sum @@ -1672,6 +1672,8 @@ github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq github.com/huandu/xstrings v1.4.0 h1:D17IlohoQq4UcpqD7fDk80P7l+lwAmlFaBHgOipl2FU= github.com/huandu/xstrings v1.4.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= +github.com/hupe1980/go-huggingface v0.0.15 h1:tTWmUGGunC/BYz4hrwS8SSVtMYVYjceG2uhL8HxeXvw= +github.com/hupe1980/go-huggingface v0.0.15/go.mod h1:IRvsik3+b9BJyw9hCfw1arI6gDObcVto1UA8f3kt8mM= github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= diff --git a/pkg/ai/huggingface.go b/pkg/ai/huggingface.go new file mode 100644 index 0000000000..68dcc8ccb4 --- /dev/null +++ b/pkg/ai/huggingface.go @@ -0,0 +1,59 @@ +package ai + +import ( + "context" + "github.com/hupe1980/go-huggingface" + "k8s.io/utils/ptr" +) + +const huggingfaceAIClientName = "huggingface" + +type HuggingfaceClient struct { + nopCloser + + client *huggingface.InferenceClient + model string + topP float32 + temperature float32 + maxTokens int +} + +func (c *HuggingfaceClient) Configure(config IAIConfig) error { + token := config.GetPassword() + + client := huggingface.NewInferenceClient(token) + + c.client = client + c.model = config.GetModel() + c.topP = config.GetTopP() + c.temperature = config.GetTemperature() + if config.GetMaxTokens() > 500 { + c.maxTokens = 500 + } else { + c.maxTokens = config.GetMaxTokens() + } + return nil +} + +func (c *HuggingfaceClient) GetCompletion(ctx context.Context, prompt string) (string, error) { + resp, err := c.client.Conversational(ctx, &huggingface.ConversationalRequest{ + Inputs: huggingface.ConverstationalInputs{ + Text: prompt, + }, + Model: c.model, + Parameters: huggingface.ConversationalParameters{ + TopP: ptr.To[float64](float64(c.topP)), + Temperature: ptr.To[float64](float64(c.temperature)), + MaxLength: &c.maxTokens, + }, + Options: huggingface.Options{ + WaitForModel: ptr.To[bool](true), + }, + }) + if err != nil { + return "", err + } + return resp.GeneratedText, nil +} + +func (c *HuggingfaceClient) GetName() string { return huggingfaceAIClientName } diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 7ad14852cf..99de8e3a40 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -27,6 +27,7 @@ var ( &AmazonBedRockClient{}, &SageMakerAIClient{}, &GoogleGenAIClient{}, + &HuggingfaceClient{}, } Backends = []string{ openAIClientName, @@ -37,6 +38,7 @@ var ( amazonsagemakerAIClientName, googleAIClientName, noopAIClientName, + huggingfaceAIClientName, } )