Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/system prompts when empty #392

Merged
merged 2 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions api/pkg/controller/inference.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ func (c *Controller) ChatCompletion(ctx context.Context, user *types.User, req o
}
}

if assistant.SystemPrompt != "" && len(req.Messages) >= 1 && req.Messages[0].Role == openai.ChatMessageRoleSystem {
req.Messages[0].Content = assistant.SystemPrompt
}
req = setSystemPrompt(&req, assistant.SystemPrompt)

if assistant.Model != "" {
req.Model = assistant.Model
Expand Down Expand Up @@ -104,9 +102,7 @@ func (c *Controller) ChatCompletionStream(ctx context.Context, user *types.User,
}
}

if assistant.SystemPrompt != "" && len(req.Messages) >= 1 && req.Messages[0].Role == openai.ChatMessageRoleSystem {
req.Messages[0].Content = assistant.SystemPrompt
}
req = setSystemPrompt(&req, assistant.SystemPrompt)

if assistant.Model != "" {
req.Model = assistant.Model
Expand Down Expand Up @@ -315,3 +311,35 @@ func extendMessageWithRAGResults(req *openai.ChatCompletionRequest, ragResults [

return nil
}

// setSystemPrompt if the assistant has a system prompt, set it in the request. If there is already
// provided system prompt, overwrite it and if there is no system prompt, set it as the first message
func setSystemPrompt(req *openai.ChatCompletionRequest, systemPrompt string) openai.ChatCompletionRequest {
if systemPrompt == "" {
// Nothing to do
return *req
}

if len(req.Messages) == 0 {
req.Messages = append(req.Messages, openai.ChatCompletionMessage{
Role: "system",
Content: systemPrompt,
})
}

if len(req.Messages) >= 1 && req.Messages[0].Role == openai.ChatMessageRoleSystem {
req.Messages[0].Content = systemPrompt
}

// If first message is not a system message, add it as the first message
if len(req.Messages) >= 1 && req.Messages[0].Role != openai.ChatMessageRoleSystem {
req.Messages = append([]openai.ChatCompletionMessage{
{
Role: "system",
Content: systemPrompt,
},
}, req.Messages...)
}

return *req
}
92 changes: 92 additions & 0 deletions api/pkg/controller/inference_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package controller

import (
"reflect"
"testing"

openai "github.com/lukemarsden/go-openai2"
)

func Test_setSystemPrompt(t *testing.T) {
type args struct {
req *openai.ChatCompletionRequest
systemPrompt string
}
tests := []struct {
name string
args args
want openai.ChatCompletionRequest
}{
{
name: "No system prompt set and no messages",
args: args{
req: &openai.ChatCompletionRequest{},
systemPrompt: "",
},
want: openai.ChatCompletionRequest{},
},
{
name: "System prompt set and message user only",
args: args{
req: &openai.ChatCompletionRequest{
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello",
},
},
},
systemPrompt: "You are a helpful assistant.",
},
want: openai.ChatCompletionRequest{
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "You are a helpful assistant.",
},
{
Role: openai.ChatMessageRoleUser,
Content: "Hello",
},
},
},
},
{
name: "System prompt is set and request messages has system prompt",
args: args{
req: &openai.ChatCompletionRequest{
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "Original system prompt",
},
{
Role: openai.ChatMessageRoleUser,
Content: "Hello",
},
},
},
systemPrompt: "New system prompt",
},
want: openai.ChatCompletionRequest{
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "New system prompt",
},
{
Role: openai.ChatMessageRoleUser,
Content: "Hello",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := setSystemPrompt(tt.args.req, tt.args.systemPrompt); !reflect.DeepEqual(got, tt.want) {
t.Errorf("setSystemPrompt() = %v, want %v", got, tt.want)
}
})
}
}