From ed94de0067fa2a59d948fa73ed26799a48f3662e Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Tue, 27 Aug 2024 09:27:05 -0700 Subject: [PATCH] chore: dynamically update tools and other tool params on subsequent chats --- pkg/engine/engine.go | 49 ++++++---- pkg/tests/runner_test.go | 31 ++++++ .../TestToolsChange/call1-resp.golden | 9 ++ .../testdata/TestToolsChange/call1.golden | 70 ++++++++++++++ .../TestToolsChange/call2-resp.golden | 9 ++ .../testdata/TestToolsChange/call2.golden | 73 ++++++++++++++ .../testdata/TestToolsChange/step1.golden | 93 ++++++++++++++++++ .../testdata/TestToolsChange/step2.golden | 96 +++++++++++++++++++ 8 files changed, 410 insertions(+), 20 deletions(-) create mode 100644 pkg/tests/testdata/TestToolsChange/call1-resp.golden create mode 100644 pkg/tests/testdata/TestToolsChange/call1.golden create mode 100644 pkg/tests/testdata/TestToolsChange/call2-resp.golden create mode 100644 pkg/tests/testdata/TestToolsChange/call2.golden create mode 100644 pkg/tests/testdata/TestToolsChange/step1.golden create mode 100644 pkg/tests/testdata/TestToolsChange/step2.golden diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index f8fd8154..14b75e0a 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -258,6 +258,29 @@ func (c *Context) WrappedContext(e *Engine) context.Context { return context.WithValue(c.Ctx, engineContext{}, &cp) } +func populateMessageParams(ctx Context, completion *types.CompletionRequest, tool types.Tool) error { + completion.Model = tool.Parameters.ModelName + completion.MaxTokens = tool.Parameters.MaxTokens + completion.JSONResponse = tool.Parameters.JSONResponse + completion.Cache = tool.Parameters.Cache + completion.Chat = tool.Parameters.Chat + completion.Temperature = tool.Parameters.Temperature + completion.InternalSystemPrompt = tool.Parameters.InternalPrompt + + if tool.Chat && completion.InternalSystemPrompt == nil { + completion.InternalSystemPrompt = new(bool) + } + + var err error + completion.Tools, err = tool.GetCompletionTools(*ctx.Program, ctx.AgentGroup...) + if err != nil { + return err + } + + completion.Messages = addUpdateSystem(ctx, tool, completion.Messages) + return nil +} + func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) { tool := ctx.Tool @@ -290,28 +313,11 @@ func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) { return nil, fmt.Errorf("credential tools cannot make calls to the LLM") } - completion := types.CompletionRequest{ - Model: tool.Parameters.ModelName, - MaxTokens: tool.Parameters.MaxTokens, - JSONResponse: tool.Parameters.JSONResponse, - Cache: tool.Parameters.Cache, - Chat: tool.Parameters.Chat, - Temperature: tool.Parameters.Temperature, - InternalSystemPrompt: tool.Parameters.InternalPrompt, - } - - if tool.Chat && completion.InternalSystemPrompt == nil { - completion.InternalSystemPrompt = new(bool) - } - - var err error - completion.Tools, err = tool.GetCompletionTools(*ctx.Program, ctx.AgentGroup...) - if err != nil { + var completion types.CompletionRequest + if err := populateMessageParams(ctx, &completion, tool); err != nil { return nil, err } - completion.Messages = addUpdateSystem(ctx, tool, completion.Messages) - if tool.Chat && input == "{}" { input = "" } @@ -497,6 +503,9 @@ func (e *Engine) Continue(ctx Context, state *State, results ...CallResult) (*Re return nil, fmt.Errorf("invalid continue call, no completion needed") } - state.Completion.Messages = addUpdateSystem(ctx, ctx.Tool, state.Completion.Messages) + if err := populateMessageParams(ctx, &state.Completion, ctx.Tool); err != nil { + return nil, err + } + return e.complete(ctx.Ctx, state) } diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index 141e6aff..483b5b6f 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/tests/tester" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/hexops/autogold/v2" @@ -1041,3 +1042,33 @@ func TestRuntimesLocalDev(t *testing.T) { _ = os.RemoveAll("testdata/TestRuntimesLocalDev/node_modules") _ = os.RemoveAll("testdata/TestRuntimesLocalDev/package-lock.json") } + +func TestToolsChange(t *testing.T) { + r := tester.NewRunner(t) + prg, err := loader.ProgramFromSource(context.Background(), ` +chat: true +tools: sys.ls, sys.read, sys.write +`, "") + require.NoError(t, err) + + resp, err := r.Chat(context.Background(), nil, prg, nil, "input 1") + require.NoError(t, err) + r.AssertResponded(t) + assert.False(t, resp.Done) + autogold.Expect("TEST RESULT CALL: 1").Equal(t, resp.Content) + autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1")) + + prg, err = loader.ProgramFromSource(context.Background(), ` +chat: true +temperature: 0.6 +tools: sys.ls, sys.write +`, "") + require.NoError(t, err) + + resp, err = r.Chat(context.Background(), resp.State, prg, nil, "input 2") + require.NoError(t, err) + r.AssertResponded(t) + assert.False(t, resp.Done) + autogold.Expect("TEST RESULT CALL: 2").Equal(t, resp.Content) + autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step2")) +} diff --git a/pkg/tests/testdata/TestToolsChange/call1-resp.golden b/pkg/tests/testdata/TestToolsChange/call1-resp.golden new file mode 100644 index 00000000..2861a036 --- /dev/null +++ b/pkg/tests/testdata/TestToolsChange/call1-resp.golden @@ -0,0 +1,9 @@ +`{ + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 1" + } + ], + "usage": {} +}` diff --git a/pkg/tests/testdata/TestToolsChange/call1.golden b/pkg/tests/testdata/TestToolsChange/call1.golden new file mode 100644 index 00000000..6c7c2d55 --- /dev/null +++ b/pkg/tests/testdata/TestToolsChange/call1.golden @@ -0,0 +1,70 @@ +`{ + "model": "gpt-4o", + "internalSystemPrompt": false, + "tools": [ + { + "function": { + "toolID": "sys.ls", + "name": "ls", + "description": "Lists the contents of a directory", + "parameters": { + "properties": { + "dir": { + "description": "The directory to list", + "type": "string" + } + }, + "type": "object" + } + } + }, + { + "function": { + "toolID": "sys.read", + "name": "read", + "description": "Reads the contents of a file", + "parameters": { + "properties": { + "filename": { + "description": "The name of the file to read", + "type": "string" + } + }, + "type": "object" + } + } + }, + { + "function": { + "toolID": "sys.write", + "name": "write", + "description": "Write the contents to a file", + "parameters": { + "properties": { + "content": { + "description": "The content to write", + "type": "string" + }, + "filename": { + "description": "The name of the file to write to", + "type": "string" + } + }, + "type": "object" + } + } + } + ], + "messages": [ + { + "role": "user", + "content": [ + { + "text": "input 1" + } + ], + "usage": {} + } + ], + "chat": true +}` diff --git a/pkg/tests/testdata/TestToolsChange/call2-resp.golden b/pkg/tests/testdata/TestToolsChange/call2-resp.golden new file mode 100644 index 00000000..997ca1b9 --- /dev/null +++ b/pkg/tests/testdata/TestToolsChange/call2-resp.golden @@ -0,0 +1,9 @@ +`{ + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 2" + } + ], + "usage": {} +}` diff --git a/pkg/tests/testdata/TestToolsChange/call2.golden b/pkg/tests/testdata/TestToolsChange/call2.golden new file mode 100644 index 00000000..ad86b7ce --- /dev/null +++ b/pkg/tests/testdata/TestToolsChange/call2.golden @@ -0,0 +1,73 @@ +`{ + "model": "gpt-4o", + "internalSystemPrompt": false, + "tools": [ + { + "function": { + "toolID": "sys.ls", + "name": "ls", + "description": "Lists the contents of a directory", + "parameters": { + "properties": { + "dir": { + "description": "The directory to list", + "type": "string" + } + }, + "type": "object" + } + } + }, + { + "function": { + "toolID": "sys.write", + "name": "write", + "description": "Write the contents to a file", + "parameters": { + "properties": { + "content": { + "description": "The content to write", + "type": "string" + }, + "filename": { + "description": "The name of the file to write to", + "type": "string" + } + }, + "type": "object" + } + } + } + ], + "messages": [ + { + "role": "user", + "content": [ + { + "text": "input 1" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 1" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "input 2" + } + ], + "usage": {} + } + ], + "chat": true, + "temperature": 0.6 +}` diff --git a/pkg/tests/testdata/TestToolsChange/step1.golden b/pkg/tests/testdata/TestToolsChange/step1.golden new file mode 100644 index 00000000..1aae05d1 --- /dev/null +++ b/pkg/tests/testdata/TestToolsChange/step1.golden @@ -0,0 +1,93 @@ +`{ + "done": false, + "content": "TEST RESULT CALL: 1", + "toolID": "inline:", + "state": { + "continuation": { + "state": { + "input": "input 1", + "completion": { + "model": "gpt-4o", + "internalSystemPrompt": false, + "tools": [ + { + "function": { + "toolID": "sys.ls", + "name": "ls", + "description": "Lists the contents of a directory", + "parameters": { + "properties": { + "dir": { + "description": "The directory to list", + "type": "string" + } + }, + "type": "object" + } + } + }, + { + "function": { + "toolID": "sys.read", + "name": "read", + "description": "Reads the contents of a file", + "parameters": { + "properties": { + "filename": { + "description": "The name of the file to read", + "type": "string" + } + }, + "type": "object" + } + } + }, + { + "function": { + "toolID": "sys.write", + "name": "write", + "description": "Write the contents to a file", + "parameters": { + "properties": { + "content": { + "description": "The content to write", + "type": "string" + }, + "filename": { + "description": "The name of the file to write to", + "type": "string" + } + }, + "type": "object" + } + } + } + ], + "messages": [ + { + "role": "user", + "content": [ + { + "text": "input 1" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 1" + } + ], + "usage": {} + } + ], + "chat": true + } + }, + "result": "TEST RESULT CALL: 1" + }, + "continuationToolID": "inline:" + } +}` diff --git a/pkg/tests/testdata/TestToolsChange/step2.golden b/pkg/tests/testdata/TestToolsChange/step2.golden new file mode 100644 index 00000000..9c9dbad7 --- /dev/null +++ b/pkg/tests/testdata/TestToolsChange/step2.golden @@ -0,0 +1,96 @@ +`{ + "done": false, + "content": "TEST RESULT CALL: 2", + "toolID": "inline:", + "state": { + "continuation": { + "state": { + "input": "input 1", + "completion": { + "model": "gpt-4o", + "internalSystemPrompt": false, + "tools": [ + { + "function": { + "toolID": "sys.ls", + "name": "ls", + "description": "Lists the contents of a directory", + "parameters": { + "properties": { + "dir": { + "description": "The directory to list", + "type": "string" + } + }, + "type": "object" + } + } + }, + { + "function": { + "toolID": "sys.write", + "name": "write", + "description": "Write the contents to a file", + "parameters": { + "properties": { + "content": { + "description": "The content to write", + "type": "string" + }, + "filename": { + "description": "The name of the file to write to", + "type": "string" + } + }, + "type": "object" + } + } + } + ], + "messages": [ + { + "role": "user", + "content": [ + { + "text": "input 1" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 1" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "input 2" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 2" + } + ], + "usage": {} + } + ], + "chat": true, + "temperature": 0.6 + } + }, + "result": "TEST RESULT CALL: 2" + }, + "continuationToolID": "inline:" + } +}`