diff --git a/go.mod b/go.mod index 49acb47b..e71473ac 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22.0 replace github.com/sashabaranov/go-openai => github.com/gptscript-ai/go-openai v0.0.0-20240206232711-45b6e096246a require ( + github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69 github.com/acorn-io/broadcaster v0.0.0-20240105011354-bfadd4a7b45d github.com/acorn-io/cmd v0.0.0-20240203032901-e9e631185ddb github.com/adrg/xdg v0.4.0 diff --git a/go.sum b/go.sum index f8d6702e..f55d6b1e 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69 h1:+tu3HOoMXB7RXEINRVIpxJCT+KdYiI7LAEAUrOw3dIU= +github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69/go.mod h1:L1AbZdiDllfyYH5l5OkAaZtk7VkWe89bPJFmnDBNHxg= github.com/acorn-io/baaah v0.0.0-20240119160309-2a58ee757bbd h1:Zbau2J6sEPl1H4gqnEx4/TI55eZncQR5cjfPOcG2lxE= github.com/acorn-io/baaah v0.0.0-20240119160309-2a58ee757bbd/go.mod h1:13nTO3svO8zTD3j9E5c86tCtK5YrKsK5sxca4Lwkbc0= github.com/acorn-io/broadcaster v0.0.0-20240105011354-bfadd4a7b45d h1:hfpNQkJ4I2b8+DbMr8m97gG67ku0uPsMzUfskVu3cHU= diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index be14bf73..8d6882bd 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -15,6 +15,7 @@ import ( "sort" "strings" + "github.com/BurntSushi/locker" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/jaytaylor/html2text" ) @@ -242,6 +243,10 @@ func SysRead(ctx context.Context, env []string, input string) (string, error) { return "", err } + // Lock the file to prevent concurrent writes from other tool calls. + locker.RLock(params.Filename) + defer locker.RUnlock(params.Filename) + log.Debugf("Reading file %s", params.Filename) data, err := os.ReadFile(params.Filename) if err != nil { @@ -260,6 +265,10 @@ func SysWrite(ctx context.Context, env []string, input string) (string, error) { return "", err } + // Lock the file to prevent concurrent writes from other tool calls. + locker.Lock(params.Filename) + defer locker.Unlock(params.Filename) + data := []byte(params.Content) msg := fmt.Sprintf("Wrote %d bytes to file %s", len(data), params.Filename) log.Debugf(msg) @@ -276,6 +285,10 @@ func SysAppend(ctx context.Context, env []string, input string) (string, error) return "", err } + // Lock the file to prevent concurrent writes from other tool calls. + locker.Lock(params.Filename) + defer locker.Unlock(params.Filename) + f, err := os.OpenFile(params.Filename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) if err != nil { return "", err @@ -410,6 +423,10 @@ func SysRemove(ctx context.Context, env []string, input string) (string, error) return "", err } + // Lock the file to prevent concurrent writes from other tool calls. + locker.Lock(params.Location) + defer locker.Unlock(params.Location) + return fmt.Sprintf("Removed file: %s", params.Location), os.Remove(params.Location) }