Skip to content

Commit

Permalink
feat: add huggingface provider (#893)
Browse files Browse the repository at this point in the history
* feat: add huggingface ai provider

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>

* chore: update readme

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>

* fix: set huggingface maxtokens default to 500, use ptr instead of pointer

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>

---------

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>
  • Loading branch information
JuHyung-Son committed Jan 28, 2024
1 parent 483a9da commit 2fd476e
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ Unused:
> amazonbedrock
> amazonsagemaker
> google
> huggingface
> noopai
```

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ require (
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
github.com/gookit/color v1.5.4 // indirect
github.com/hupe1980/go-huggingface v0.0.15 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jpillora/backoff v1.0.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1672,6 +1672,8 @@ github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq
github.com/huandu/xstrings v1.4.0 h1:D17IlohoQq4UcpqD7fDk80P7l+lwAmlFaBHgOipl2FU=
github.com/huandu/xstrings v1.4.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg=
github.com/hupe1980/go-huggingface v0.0.15 h1:tTWmUGGunC/BYz4hrwS8SSVtMYVYjceG2uhL8HxeXvw=
github.com/hupe1980/go-huggingface v0.0.15/go.mod h1:IRvsik3+b9BJyw9hCfw1arI6gDObcVto1UA8f3kt8mM=
github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
Expand Down
59 changes: 59 additions & 0 deletions pkg/ai/huggingface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package ai

import (
"context"
"github.com/hupe1980/go-huggingface"
"k8s.io/utils/ptr"
)

const huggingfaceAIClientName = "huggingface"

type HuggingfaceClient struct {
nopCloser

client *huggingface.InferenceClient
model string
topP float32
temperature float32
maxTokens int
}

func (c *HuggingfaceClient) Configure(config IAIConfig) error {
token := config.GetPassword()

client := huggingface.NewInferenceClient(token)

c.client = client
c.model = config.GetModel()
c.topP = config.GetTopP()
c.temperature = config.GetTemperature()
if config.GetMaxTokens() > 500 {
c.maxTokens = 500
} else {
c.maxTokens = config.GetMaxTokens()
}
return nil
}

func (c *HuggingfaceClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
resp, err := c.client.Conversational(ctx, &huggingface.ConversationalRequest{
Inputs: huggingface.ConverstationalInputs{
Text: prompt,
},
Model: c.model,
Parameters: huggingface.ConversationalParameters{
TopP: ptr.To[float64](float64(c.topP)),
Temperature: ptr.To[float64](float64(c.temperature)),
MaxLength: &c.maxTokens,
},
Options: huggingface.Options{
WaitForModel: ptr.To[bool](true),
},
})
if err != nil {
return "", err
}
return resp.GeneratedText, nil
}

func (c *HuggingfaceClient) GetName() string { return huggingfaceAIClientName }
2 changes: 2 additions & 0 deletions pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var (
&AmazonBedRockClient{},
&SageMakerAIClient{},
&GoogleGenAIClient{},
&HuggingfaceClient{},
}
Backends = []string{
openAIClientName,
Expand All @@ -37,6 +38,7 @@ var (
amazonsagemakerAIClientName,
googleAIClientName,
noopAIClientName,
huggingfaceAIClientName,
}
)

Expand Down

0 comments on commit 2fd476e

Please sign in to comment.