Skip to content

Commit

Permalink
chore: Refactor CLI root and tool call sorting for code base improvem…
Browse files Browse the repository at this point in the history
…ents

- Renamed Root struct to GPTScript and added Output field to streamline CLI usage.
- Enhanced tool call sorting in the engine package for consistent ordering.
- Implemented seed generation method for API call consistency in the OpenAI client module.

This commit enhances the overall code quality and user experience by introducing better struct naming conventions, ensuring deterministic sorting of tool calls, and improving cache key generation for API responses.
  • Loading branch information
ibuildthecloud committed Jan 31, 2024
1 parent 6887aeb commit c0b8139
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 16 deletions.
25 changes: 17 additions & 8 deletions pkg/cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,22 @@ import (
"golang.org/x/term"
)

type Root struct {
type GPTScript struct {
runner.Options
Output string `usage:"Save output to a file" short:"o"`
}

func New() *cobra.Command {
return cmd.Command(&Root{})
return cmd.Command(&GPTScript{})
}

func (r *Root) Customize(cmd *cobra.Command) {
func (r *GPTScript) Customize(cmd *cobra.Command) {
cmd.Use = version.ProgramName
cmd.Args = cobra.MinimumNArgs(1)
cmd.Flags().SetInterspersed(false)
}

func (r *Root) Run(cmd *cobra.Command, args []string) error {
func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
in, err := os.Open(args[0])
if err != nil {
return err
Expand Down Expand Up @@ -55,9 +56,17 @@ func (r *Root) Run(cmd *cobra.Command, args []string) error {
return err
}

fmt.Print(s)
if !strings.HasSuffix(s, "\n") {
fmt.Println()
if r.Output != "" {
err = os.WriteFile(r.Output, []byte(s), 0644)
if err != nil {
return err
}
} else {
fmt.Print(s)
if !strings.HasSuffix(s, "\n") {
fmt.Println()
}
}
return err

return nil
}
21 changes: 14 additions & 7 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,12 @@ func (e *Engine) Continue(ctx context.Context, state *State, results ...CallResu
}

var (
added bool
pendingIDs []string
added bool
pendingToolCalls []types.CompletionToolCall
)

for id, pending := range state.Pending {
pendingIDs = append(pendingIDs, id)
pendingToolCalls = append(pendingToolCalls, pending)
if _, ok := state.Results[id]; !ok {
ret.Calls[id] = Call{
ToolName: pending.Function.Name,
Expand All @@ -285,11 +285,18 @@ func (e *Engine) Continue(ctx context.Context, state *State, results ...CallResu
return &ret, nil
}

sort.Strings(pendingIDs)
sort.Slice(pendingToolCalls, func(i, j int) bool {
left := pendingToolCalls[i].Function.Name + pendingToolCalls[i].Function.Arguments
right := pendingToolCalls[j].Function.Name + pendingToolCalls[j].Function.Arguments
if left == right {
return pendingToolCalls[i].ID < pendingToolCalls[j].ID
}
return left < right
})

for _, id := range pendingIDs {
pending := state.Pending[id]
if result, ok := state.Results[id]; ok {
for _, pending := range pendingToolCalls {
pending := pending
if result, ok := state.Results[pending.ID]; ok {
added = true
state.Completion.Messages = append(state.Completion.Messages, types.CompletionMessage{
Role: types.CompletionMessageRoleTypeTool,
Expand Down
21 changes: 20 additions & 1 deletion pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,25 @@ func (c *Client) cacheKey(request openai.ChatCompletionRequest) string {
return hash.Encode(request)
}

func (c *Client) seed(request openai.ChatCompletionRequest) int {
newRequest := request
newRequest.Messages = nil

for _, msg := range request.Messages {
newMsg := msg
newMsg.ToolCalls = nil
newMsg.ToolCallID = ""

for _, tool := range msg.ToolCalls {
tool.ID = ""
newMsg.ToolCalls = append(newMsg.ToolCalls, tool)
}

newRequest.Messages = append(newRequest.Messages, newMsg)
}
return hash.Seed(newRequest)
}

func (c *Client) fromCache(ctx context.Context, messageRequest types.CompletionRequest, request openai.ChatCompletionRequest) (result []openai.ChatCompletionStreamResponse, _ bool, _ error) {
if messageRequest.Cache != nil && !*messageRequest.Cache {
return nil, false, nil
Expand Down Expand Up @@ -210,7 +229,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
}
}

request.Seed = ptr(hash.Seed(request))
request.Seed = ptr(c.seed(request))
response, ok, err := c.fromCache(ctx, messageRequest, request)
if err != nil {
return nil, err
Expand Down

0 comments on commit c0b8139

Please sign in to comment.