Skip to content

Commit

Permalink
Print status of subcalls
Browse files Browse the repository at this point in the history
  • Loading branch information
ibuildthecloud committed Feb 10, 2024
1 parent a136e12 commit e28e7b7
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 37 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ require (
github.com/go-logr/logr v1.4.1 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/go-containerregistry v0.16.1 // indirect
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/hexops/gotextdiff v1.0.3 // indirect
github.com/hexops/valast v1.4.3 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-containerregistry v0.16.1 h1:rUEt426sR6nyrL3gt+18ibRcvYpKYdpsa5ZW7MA08dQ=
github.com/google/go-containerregistry v0.16.1/go.mod h1:u0qB2l7mvtWVR5kNcbFIhFY1hLbf8eeGapA+vbFDCtQ=
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec=
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
Expand Down
14 changes: 7 additions & 7 deletions pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type GPTScript struct {
runner.Options
DisplayOptions
Debug bool `usage:"Enable debug logging"`
Quiet bool `usage:"No output logging" short:"q"`
Quiet *bool `usage:"No output logging" short:"q"`
Output string `usage:"Save output to a file, or - for stdout" short:"o"`
Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f"`
SubTool string `usage:"Use tool of this name, not the first tool in file"`
Expand Down Expand Up @@ -80,19 +80,19 @@ func (r *GPTScript) listModels(ctx context.Context) error {
}

func (r *GPTScript) Pre(cmd *cobra.Command, args []string) error {
if r.Quiet {
if r.Quiet == nil {
if term.IsTerminal(int(os.Stdout.Fd())) {
r.Quiet = false
r.Quiet = new(bool)
} else {
r.Quiet = true
r.Quiet = &[]bool{true}[0]
}
}

if r.Debug {
mvl.SetDebug()
} else {
mvl.SetSimpleFormat()
if r.Quiet {
if *r.Quiet {
mvl.SetError()
}
}
Expand Down Expand Up @@ -165,7 +165,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
CacheOptions: r.CacheOptions,
OpenAIOptions: r.OpenAIOptions,
MonitorFactory: monitor.NewConsole(monitor.Options(r.DisplayOptions), monitor.Options{
DisplayProgress: !r.Quiet,
DisplayProgress: !*r.Quiet,
}),
})
if err != nil {
Expand All @@ -188,7 +188,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
return err
}
} else {
if !r.Quiet {
if !*r.Quiet {
if toolInput != "" {
_, _ = fmt.Fprint(os.Stderr, "\nINPUT:\n\n")
_, _ = fmt.Fprintln(os.Stderr, toolInput)
Expand Down
108 changes: 98 additions & 10 deletions pkg/monitor/display.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,11 @@ type display struct {
}

type livePrinter struct {
lastLines map[string]string
needsNewline bool
lastContent map[string]string
callIDMap map[string]string
activePrinters []string
toPrint []string
needsNewline bool
}

func (l *livePrinter) end() {
Expand All @@ -71,21 +74,100 @@ func (l *livePrinter) end() {
_, _ = fmt.Fprintln(os.Stderr)
}
l.needsNewline = false
if len(l.activePrinters) > 0 {
delete(l.lastContent, l.activePrinters[0])
}
}

func (l *livePrinter) print(event runner.Event, c call) {
func (l *livePrinter) progressStart(event runner.Event, c call) {
if l == nil {
return
}
if !slices.Contains(l.activePrinters, c.ID) {
l.activePrinters = append(l.activePrinters, c.ID)
}
l.toPrint = slices.DeleteFunc(l.toPrint, func(s string) bool {
return s == c.ID
})
}

func (l *livePrinter) progressEnd(event runner.Event, c call) {
if l == nil {
return
}
if c.ParentID != "" {
var result []string
for i, id := range l.activePrinters {
if id != c.ID {
result = append(result, id)
continue
}

if i != 0 {
if !slices.Contains(l.toPrint, id) {
l.toPrint = append(l.toPrint, id)
}
continue
}

for _, toPrintID := range l.toPrint {
content := l.lastContent[toPrintID]
delete(l.lastContent, toPrintID)
if content != "" {
_, _ = fmt.Fprint(os.Stderr, content)
if !strings.HasSuffix(content, "\n") {
_, _ = fmt.Fprintln(os.Stderr)
}
}
}

l.toPrint = nil
result = l.activePrinters[1:]
if len(result) > 0 {
content := l.lastContent[result[0]]
if content != "" {
_, _ = fmt.Fprint(os.Stderr, content)
l.needsNewline = !strings.HasSuffix(content, "\n")
}
}
break
}
l.activePrinters = result
}

func (l *livePrinter) formatContent(event runner.Event, c call) string {
if event.Content == "" {
return event.Content
}
prefix := fmt.Sprintf(" content [%s] content | ", l.callIDMap[c.ID])
var lines []string
for _, line := range strings.Split(event.Content, "\n") {
if len(line) > 100 {
line = line[:100] + " ..."
}
lines = append(lines, prefix+line)
}
return strings.Join(lines, "\n")
}

func (l *livePrinter) print(event runner.Event, c call) {
if l == nil {
return
}

last := l.lastLines[c.ID]
line := strings.TrimPrefix(event.Content, last)
_, _ = fmt.Fprint(os.Stderr, line)
l.needsNewline = !strings.HasSuffix(line, "\n")
l.lastLines[c.ID] = event.Content
content := l.formatContent(event, c)
last := l.lastContent[c.ID]
l.lastContent[c.ID] = content

if len(l.activePrinters) > 0 && l.activePrinters[0] == c.ID && content != "" {
line, ok := strings.CutPrefix(content, last)
if !ok && last != "" {
_, _ = fmt.Fprintln(os.Stderr)
}
if line != "" {
_, _ = fmt.Fprint(os.Stderr, line)
l.needsNewline = !strings.HasSuffix(line, "\n")
}
}
}

func (d *display) Event(event runner.Event) {
Expand Down Expand Up @@ -135,13 +217,17 @@ func (d *display) Event(event runner.Event) {

switch event.Type {
case runner.EventTypeCallStart:
d.livePrinter.progressStart(event, currentCall)
d.livePrinter.end()
currentCall.Start = event.Time
currentCall.Input = event.Content
log.Fields("input", event.Content).Infof("started [%s]", callName)
case runner.EventTypeCallSubCalls:
d.livePrinter.progressEnd(event, currentCall)
case runner.EventTypeCallProgress:
d.livePrinter.print(event, currentCall)
case runner.EventTypeCallContinue:
d.livePrinter.progressStart(event, currentCall)
d.livePrinter.end()
log.Fields("toolResults", event.ToolResults).Infof("continue [%s]", callName)
case runner.EventTypeChat:
Expand All @@ -167,6 +253,7 @@ func (d *display) Event(event runner.Event) {
Cached: event.ChatResponseCached,
})
case runner.EventTypeCallFinish:
d.livePrinter.progressEnd(event, currentCall)
d.livePrinter.end()
currentCall.End = event.Time
currentCall.Output = event.Content
Expand Down Expand Up @@ -204,7 +291,8 @@ func newDisplay(dumpState string, progress bool) *display {
}
if progress {
display.livePrinter = &livePrinter{
lastLines: map[string]string{},
lastContent: map[string]string{},
callIDMap: display.callIDMap,
}
}
return display
Expand Down
10 changes: 1 addition & 9 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,19 +407,11 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
cacheKey := c.cacheKey(request)
request.Stream = true

msg := ""
if len(request.Messages) > 0 {
msg = request.Messages[len(request.Messages)-1].Content
if msg != "" {
msg = "Sent content:\n\n" + msg + "\n"
}
}

partial <- Status{
CompletionID: transactionID,
PartialResponse: &types.CompletionMessage{
Role: types.CompletionMessageRoleTypeAssistant,
Content: types.Text(msg + "Waiting for model response...\n"),
Content: types.Text("Waiting for model response..."),
},
}

Expand Down
27 changes: 18 additions & 9 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,24 @@ func (r *Runner) Run(ctx context.Context, prg types.Program, env []string, input
}

type Event struct {
Time time.Time `json:"time,omitempty"`
CallContext *engine.Context `json:"callContext,omitempty"`
ToolResults int `json:"toolResults,omitempty"`
Type EventType `json:"type,omitempty"`
ChatCompletionID string `json:"chatCompletionId,omitempty"`
ChatRequest any `json:"chatRequest,omitempty"`
ChatResponse any `json:"chatResponse,omitempty"`
ChatResponseCached bool `json:"chatResponseCached,omitempty"`
Content string `json:"content,omitempty"`
Time time.Time `json:"time,omitempty"`
CallContext *engine.Context `json:"callContext,omitempty"`
ToolSubCalls map[string]engine.Call `json:"toolSubCalls,omitempty"`
ToolResults int `json:"toolResults,omitempty"`
Type EventType `json:"type,omitempty"`
ChatCompletionID string `json:"chatCompletionId,omitempty"`
ChatRequest any `json:"chatRequest,omitempty"`
ChatResponse any `json:"chatResponse,omitempty"`
ChatResponseCached bool `json:"chatResponseCached,omitempty"`
Content string `json:"content,omitempty"`
}

type EventType string

var (
EventTypeCallStart = EventType("callStart")
EventTypeCallContinue = EventType("callContinue")
EventTypeCallSubCalls = EventType("callSubCalls")
EventTypeCallProgress = EventType("callProgress")
EventTypeChat = EventType("callChat")
EventTypeCallFinish = EventType("callFinish")
Expand Down Expand Up @@ -138,6 +140,13 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp
return *result.Result, nil
}

monitor.Event(Event{
Time: time.Now(),
CallContext: &callCtx,
Type: EventTypeCallSubCalls,
ToolSubCalls: result.Calls,
})

callResults, err := r.subCalls(callCtx, monitor, env, result)
if err != nil {
return "", err
Expand Down

0 comments on commit e28e7b7

Please sign in to comment.