Skip to content

Commit

Permalink
Refactor client code to make way for additional AI API backends
Browse files Browse the repository at this point in the history
  • Loading branch information
ibuildthecloud committed Feb 27, 2024
1 parent d2054f3 commit 046d340
Show file tree
Hide file tree
Showing 15 changed files with 341 additions and 206 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/gptscript-ai/gptscript

go 1.22.0

replace github.com/sashabaranov/go-openai => github.com/gptscript-ai/go-openai v0.0.0-20240206232711-45b6e096246a
replace github.com/sashabaranov/go-openai => github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185

require (
github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gptscript-ai/go-openai v0.0.0-20240206232711-45b6e096246a h1:AdBbQ1ODOYK5AwCey4VFEmKeu9gG4PCzuO80pQmgupE=
github.com/gptscript-ai/go-openai v0.0.0-20240206232711-45b6e096246a/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185 h1:+TfC9DYtWuexdL7x1lIdD1HP61IStb3ZTj/byBdiWs0=
github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/hexops/autogold v0.8.1/go.mod h1:97HLDXyG23akzAoRYJh/2OBs3kd80eHyKPvZw0S5ZBY=
github.com/hexops/autogold v1.3.1 h1:YgxF9OHWbEIUjhDbpnLhgVsjUDsiHDTyDfy2lrfdlzo=
github.com/hexops/autogold v1.3.1/go.mod h1:sQO+mQUCVfxOKPht+ipDSkJ2SCJ7BNJVHZexsXqWMx4=
Expand Down
12 changes: 10 additions & 2 deletions pkg/builtin/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@ import (
)

var (
DefaultModel = openai.DefaultModel
defaultModel = openai.DefaultModel
)

// GetDefaultModel returns the model name applied to tools that do not
// set one explicitly (see SetDefaults).
func GetDefaultModel() string {
return defaultModel
}

// SetDefaultModel overrides the package-wide default model name
// returned by GetDefaultModel.
func SetDefaultModel(model string) {
defaultModel = model
}

// SetDefaults fills in defaults on a tool: when no model name is set,
// it assigns the package-wide default from GetDefaultModel.
//
// Fix: the stripped diff left both the old assignment
// (`= DefaultModel`) and its replacement (`= GetDefaultModel()`) in
// the body; only the new assignment is kept.
func SetDefaults(tool types.Tool) types.Tool {
	if tool.Parameters.ModelName == "" {
		tool.Parameters.ModelName = GetDefaultModel()
	}
	return tool
}
12 changes: 12 additions & 0 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cache

import (
"context"
"errors"
"io/fs"
"os"
Expand Down Expand Up @@ -35,6 +36,17 @@ func complete(opts ...Options) (result Options) {
return
}

type noCacheKey struct{}

func IsNoCache(ctx context.Context) bool {
v, _ := ctx.Value(noCacheKey{}).(bool)
return v
}

func WithNoCache(ctx context.Context) context.Context {
return context.WithValue(ctx, noCacheKey{}, true)
}

func New(opts ...Options) (*Client, error) {
opt := complete(opts...)
if err := os.MkdirAll(opt.CacheDir, 0755); err != nil {
Expand Down
61 changes: 52 additions & 9 deletions pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
"github.com/acorn-io/cmd"
"github.com/gptscript-ai/gptscript/pkg/assemble"
"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/input"
"github.com/gptscript-ai/gptscript/pkg/llm"
"github.com/gptscript-ai/gptscript/pkg/loader"
"github.com/gptscript-ai/gptscript/pkg/monitor"
"github.com/gptscript-ai/gptscript/pkg/mvl"
Expand All @@ -26,10 +28,13 @@ import (

type (
DisplayOptions monitor.Options
CacheOptions cache.Options
OpenAIOptions openai.Options
)

type GPTScript struct {
runner.Options
CacheOptions
OpenAIOptions
DisplayOptions
Debug bool `usage:"Enable debug logging"`
Quiet *bool `usage:"No output logging" short:"q"`
Expand All @@ -41,6 +46,8 @@ type GPTScript struct {
ListTools bool `usage:"List built-in tools and exit"`
Server bool `usage:"Start server"`
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:9090"`

_client llm.Client `usage:"-"`
}

func New() *cobra.Command {
Expand All @@ -67,6 +74,33 @@ func (r *GPTScript) Customize(cmd *cobra.Command) {
}
}

// getClient lazily constructs the LLM client registry (cache-backed
// OpenAI client registered into an llm.Registry) and memoizes it on
// the receiver, so repeated calls return the same instance.
func (r *GPTScript) getClient(ctx context.Context) (llm.Client, error) {
	if r._client != nil {
		return r._client, nil
	}

	cacheBackend, err := cache.New(cache.Options(r.CacheOptions))
	if err != nil {
		return nil, err
	}

	openAIClient, err := openai.NewClient(openai.Options(r.OpenAIOptions), openai.Options{
		Cache: cacheBackend,
	})
	if err != nil {
		return nil, err
	}

	reg := llm.NewRegistry()
	if err := reg.AddClient(ctx, openAIClient); err != nil {
		return nil, err
	}

	r._client = reg
	return reg, nil
}

func (r *GPTScript) listTools() error {
var lines []string
for _, tool := range builtin.ListTools() {
Expand All @@ -77,12 +111,12 @@ func (r *GPTScript) listTools() error {
}

func (r *GPTScript) listModels(ctx context.Context) error {
c, err := openai.NewClient(openai.Options(r.OpenAIOptions))
c, err := r.getClient(ctx)
if err != nil {
return err
}

models, err := c.ListModules(ctx)
models, err := c.ListModels(ctx)
if err != nil {
return err
}
Expand All @@ -95,6 +129,10 @@ func (r *GPTScript) listModels(ctx context.Context) error {
}

func (r *GPTScript) Pre(*cobra.Command, []string) error {
if r.DefaultModel != "" {
builtin.SetDefaultModel(r.DefaultModel)
}

if r.Quiet == nil {
if term.IsTerminal(int(os.Stdout.Fd())) {
r.Quiet = new(bool)
Expand Down Expand Up @@ -126,9 +164,11 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
}

if r.Server {
s, err := server.New(server.Options{
CacheOptions: r.CacheOptions,
OpenAIOptions: r.OpenAIOptions,
c, err := r.getClient(cmd.Context())
if err != nil {
return err
}
s, err := server.New(c, server.Options{
ListenAddress: r.ListenAddress,
})
if err != nil {
Expand Down Expand Up @@ -176,9 +216,12 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
return assemble.Assemble(prg, out)
}

runner, err := runner.New(r.Options, runner.Options{
CacheOptions: r.CacheOptions,
OpenAIOptions: r.OpenAIOptions,
client, err := r.getClient(cmd.Context())
if err != nil {
return err
}

runner, err := runner.New(client, runner.Options{
MonitorFactory: monitor.NewConsole(monitor.Options(r.DisplayOptions), monitor.Options{
DisplayProgress: !*r.Quiet,
}),
Expand Down
7 changes: 3 additions & 4 deletions pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"sync/atomic"

"github.com/google/shlex"
"github.com/gptscript-ai/gptscript/pkg/openai"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/gptscript-ai/gptscript/pkg/version"
)
Expand All @@ -21,7 +20,7 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string)
id := fmt.Sprint(atomic.AddInt64(&completionID, 1))

defer func() {
e.Progress <- openai.Status{
e.Progress <- types.CompletionStatus{
CompletionID: id,
Response: map[string]any{
"output": cmdOut,
Expand All @@ -31,7 +30,7 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string)
}()

if tool.BuiltinFunc != nil {
e.Progress <- openai.Status{
e.Progress <- types.CompletionStatus{
CompletionID: id,
Request: map[string]any{
"command": []string{tool.ID},
Expand All @@ -47,7 +46,7 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string)
}
defer stop()

e.Progress <- openai.Status{
e.Progress <- types.CompletionStatus{
CompletionID: id,
Request: map[string]any{
"command": cmd.Args,
Expand Down
75 changes: 19 additions & 56 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,16 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"sync"
"sync/atomic"

"github.com/gptscript-ai/gptscript/pkg/openai"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/gptscript-ai/gptscript/pkg/version"
)

// InternalSystemPrompt is added to all threads. Changing this is very dangerous as it has a
// terrible global effect and changes the behavior of all scripts.
var InternalSystemPrompt = `
You are task oriented system.
You receive input from a user, process the input from the given instructions, and then output the result.
Your objective is to provide consistent and correct results.
You do not need to explain the steps taken, only provide the result to the given instructions.
You are referred to as a tool.
`

var DefaultToolSchema = types.JSONSchema{
Property: types.Property{
Type: "object",
},
Properties: map[string]types.Property{
openai.DefaultPromptParameter: {
Description: "Prompt to send to the tool or assistant. This may be instructions or question.",
Type: "string",
},
},
Required: []string{openai.DefaultPromptParameter},
}

var completionID int64

func init() {
if p := os.Getenv("GPTSCRIPT_INTERNAL_SYSTEM_PROMPT"); p != "" {
InternalSystemPrompt = p
}
}

// ErrToolNotFound is returned when a referenced tool name cannot be
// resolved.
type ErrToolNotFound struct {
ToolName string
}
Expand All @@ -52,10 +22,14 @@ func (e *ErrToolNotFound) Error() string {
return fmt.Sprintf("tool not found: %s", e.ToolName)
}

// Model is the interface the engine requires from an AI backend: Call
// executes a completion request, streaming status updates to the
// status channel and returning the final completion message.
type Model interface {
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
}

type Engine struct {
Client *openai.Client
Model Model
Env []string
Progress chan<- openai.Status
Progress chan<- types.CompletionStatus
}

type State struct {
Expand Down Expand Up @@ -172,18 +146,12 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
}

completion := types.CompletionRequest{
Model: tool.Parameters.ModelName,
MaxToken: tool.Parameters.MaxTokens,
JSONResponse: tool.Parameters.JSONResponse,
Cache: tool.Parameters.Cache,
Temperature: tool.Parameters.Temperature,
}

if InternalSystemPrompt != "" && (tool.Parameters.InternalPrompt == nil || *tool.Parameters.InternalPrompt) {
completion.Messages = append(completion.Messages, types.CompletionMessage{
Role: types.CompletionMessageRoleTypeSystem,
Content: types.Text(InternalSystemPrompt),
})
Model: tool.Parameters.ModelName,
MaxTokens: tool.Parameters.MaxTokens,
JSONResponse: tool.Parameters.JSONResponse,
Cache: tool.Parameters.Cache,
Temperature: tool.Parameters.Temperature,
InternalSystemPrompt: tool.Parameters.InternalPrompt,
}

for _, subToolName := range tool.Parameters.Tools {
Expand All @@ -193,10 +161,9 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
}
args := subTool.Parameters.Arguments
if args == nil && !subTool.IsCommand() {
args = &DefaultToolSchema
args = &system.DefaultToolSchema
}
completion.Tools = append(completion.Tools, types.CompletionTool{
Type: types.CompletionToolTypeFunction,
Function: types.CompletionFunctionDefinition{
Name: subToolName,
Description: subTool.Parameters.Description,
Expand All @@ -207,12 +174,8 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {

if tool.Instructions != "" {
completion.Messages = append(completion.Messages, types.CompletionMessage{
Role: types.CompletionMessageRoleTypeSystem,
Content: []types.ContentPart{
{
Text: tool.Instructions,
},
},
Role: types.CompletionMessageRoleTypeSystem,
Content: types.Text(tool.Instructions),
})
}

Expand All @@ -230,7 +193,7 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {

func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
var (
progress = make(chan openai.Status)
progress = make(chan types.CompletionStatus)
ret = Return{
State: state,
Calls: map[string]Call{},
Expand All @@ -241,6 +204,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
// ensure we aren't writing to the channel anymore on exit
wg.Add(1)
defer wg.Wait()
defer close(progress)

go func() {
defer wg.Done()
Expand All @@ -251,8 +215,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
}
}()

resp, err := e.Client.Call(ctx, state.Completion, progress)
close(progress)
resp, err := e.Model.Call(ctx, state.Completion, progress)
if err != nil {
return nil, err
}
Expand Down
10 changes: 10 additions & 0 deletions pkg/hash/sha256.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@ import (
"encoding/json"
)

// Digest returns the hex-encoded SHA-224 digest of obj's JSON
// serialization. It panics when obj cannot be marshaled to JSON.
func Digest(obj any) string {
	encoded, err := json.Marshal(obj)
	if err != nil {
		panic(err)
	}

	hasher := sha256.New224()
	hasher.Write(encoded)
	return hex.EncodeToString(hasher.Sum(nil))
}

func Encode(obj any) string {
data, err := json.Marshal(obj)
if err != nil {
Expand Down
Loading

0 comments on commit 046d340

Please sign in to comment.