Skip to content

Commit

Permalink
Add support for multi part chat messages (and gpt-4-vision-preview mo…
Browse files Browse the repository at this point in the history
…del) (sashabaranov#580)

* Add support for multi part chat messages

OpenAI has recently introduced a new model called gpt-4-visual-preview,
which now supports images as input. The chat completion endpoint accepts
multi-part chat messages, where the content can be an array of structs
in addition to the usual string format.

This commit introduces new structures and constants to represent
different types of content parts. It also implements the json.Marshaler
and json.Unmarshaler interfaces on ChatCompletionMessage.

* Add ImageURLDetail and ChatMessagePartType types

* Optimize ChatCompletionMessage deserialization

* Add ErrContentFieldsMisused error
  • Loading branch information
rkintzi authored Nov 24, 2023
1 parent 7260991 commit 03caea8
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 2 deletions.
91 changes: 89 additions & 2 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai

import (
"context"
"encoding/json"
"errors"
"net/http"
)
Expand All @@ -20,6 +21,7 @@ const chatCompletionsSuffix = "/chat/completions"
var (
ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent properties simultaneously")
)

type Hate struct {
Expand Down Expand Up @@ -51,9 +53,36 @@ type PromptAnnotation struct {
ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"`
}

type ImageURLDetail string

const (
ImageURLDetailHigh ImageURLDetail = "high"
ImageURLDetailLow ImageURLDetail = "low"
ImageURLDetailAuto ImageURLDetail = "auto"
)

type ChatMessageImageURL struct {
URL string `json:"url,omitempty"`
Detail ImageURLDetail `json:"detail,omitempty"`
}

type ChatMessagePartType string

const (
ChatMessagePartTypeText ChatMessagePartType = "text"
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
)

type ChatMessagePart struct {
Type ChatMessagePartType `json:"type,omitempty"`
Text string `json:"text,omitempty"`
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
}

type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
MultiContent []ChatMessagePart

// This property isn't in the official documentation, but it's in
// the documentation for the official library for python:
Expand All @@ -70,6 +99,64 @@ type ChatCompletionMessage struct {
ToolCallID string `json:"tool_call_id,omitempty"`
}

func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
if m.Content != "" && m.MultiContent != nil {
return nil, ErrContentFieldsMisused
}
if len(m.MultiContent) > 0 {
msg := struct {
Role string `json:"role"`
Content string `json:"-"`
MultiContent []ChatMessagePart `json:"content,omitempty"`
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}(m)
return json.Marshal(msg)
}
msg := struct {
Role string `json:"role"`
Content string `json:"content"`
MultiContent []ChatMessagePart `json:"-"`
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}(m)
return json.Marshal(msg)
}

func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
msg := struct {
Role string `json:"role"`
Content string `json:"content"`
MultiContent []ChatMessagePart
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}{}
if err := json.Unmarshal(bs, &msg); err == nil {
*m = ChatCompletionMessage(msg)
return nil
}
multiMsg := struct {
Role string `json:"role"`
Content string
MultiContent []ChatMessagePart `json:"content"`
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}{}
if err := json.Unmarshal(bs, &multiMsg); err != nil {
return err
}
*m = ChatCompletionMessage(multiMsg)
return nil
}

type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
Expand Down
103 changes: 103 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -296,6 +297,108 @@ func TestAzureChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateAzureChatCompletion error")
}

func TestMultipartChatCompletions(t *testing.T) {
client, server, teardown := setupAzureTestServer()
defer teardown()
server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint)

_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5,
Model: openai.GPT3Dot5Turbo,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
MultiContent: []openai.ChatMessagePart{
{
Type: openai.ChatMessagePartTypeText,
Text: "Hello!",
},
{
Type: openai.ChatMessagePartTypeImageURL,
ImageURL: &openai.ChatMessageImageURL{
URL: "URL",
Detail: openai.ImageURLDetailLow,
},
},
},
},
},
})
checks.NoError(t, err, "CreateAzureChatCompletion error")
}

func TestMultipartChatMessageSerialization(t *testing.T) {
jsonText := `[{"role":"system","content":"system-message"},` +
`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`

var msgs []openai.ChatCompletionMessage
err := json.Unmarshal([]byte(jsonText), &msgs)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}
if len(msgs) != 2 {
t.Errorf("unexpected number of messages")
}
if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
t.Errorf("invalid user message: %v", msgs[0])
}
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
t.Errorf("invalid user message")
}
parts := msgs[1].MultiContent
if parts[0].Type != "text" || parts[0].Text != "nice-text" {
t.Errorf("invalid text part: %v", parts[0])
}
if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" {
t.Errorf("invalid image_url part")
}

s, err := json.Marshal(msgs)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}
res := strings.ReplaceAll(string(s), " ", "")
if res != jsonText {
t.Fatalf("invalid message: %s", string(s))
}

invalidMsg := []openai.ChatCompletionMessage{
{
Role: "user",
Content: "some-text",
MultiContent: []openai.ChatMessagePart{
{
Type: "text",
Text: "nice-text",
},
},
},
}
_, err = json.Marshal(invalidMsg)
if !errors.Is(err, openai.ErrContentFieldsMisused) {
t.Fatalf("Expected error: %s", err)
}

err = json.Unmarshal([]byte(`["not-a-message"]`), &msgs)
if err == nil {
t.Fatalf("Expected error")
}

emptyMultiContentMsg := openai.ChatCompletionMessage{
Role: "user",
MultiContent: []openai.ChatMessagePart{},
}
s, err = json.Marshal(emptyMultiContentMsg)
if err != nil {
t.Fatalf("Unexpected error")
}
res = strings.ReplaceAll(string(s), " ", "")
if res != `{"role":"user","content":""}` {
t.Fatalf("invalid message: %s", string(s))
}
}

// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
Expand Down

0 comments on commit 03caea8

Please sign in to comment.