Skip to content

Commit

Permalink
add SchemaRaw option for strict json response format
Browse files Browse the repository at this point in the history
  • Loading branch information
h0rv committed Aug 10, 2024
1 parent 1880333 commit 03508ae
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 1 deletion.
62 changes: 62 additions & 0 deletions api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,65 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
}
}
}

func TestChatCompletionResponseFormat_JSONSchemaRaw(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := openai.NewClient(apiToken)
ctx := context.Background()

schema := []byte(`{
"type": "object",
"properties": {
"CamelCase": {"type": "string"},
"KebabCase": {"type": "string"},
"PascalCase": {"type": "string"},
"SnakeCase": {"type": "string"}
},
"required": ["PascalCase", "CamelCase", "KebabCase", "SnakeCase"],
"additionalProperties": false
}`)

resp, err := c.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "Please enter a string, and we will convert it into the following naming conventions:" +
"1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." +
"2. CamelCase: The first word starts with a lowercase letter, " +
"and subsequent words start with an uppercase letter, with no spaces or separators." +
"3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." +
"4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.",
},
{
Role: openai.ChatMessageRoleUser,
Content: "Hello World",
},
},
ResponseFormat: &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
Name: "cases",
SchemaRaw: &schema,
Strict: true,
},
},
},
)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error")
var result = make(map[string]string)
err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
if _, ok := result[key]; !ok {
t.Errorf("key:%s does not exist.", key)
}
}
}
25 changes: 24 additions & 1 deletion chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,33 @@ type ChatCompletionResponseFormat struct {
type ChatCompletionResponseFormatJSONSchema struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema jsonschema.Definition `json:"schema"`
Schema jsonschema.Definition `json:"-"`
SchemaRaw *[]byte `json:"-"`
Strict bool `json:"strict"`
}

func (c *ChatCompletionResponseFormatJSONSchema) MarshalJSON() ([]byte, error) {
type Alias ChatCompletionResponseFormatJSONSchema // prevent recursive marshalling
var data struct {
*Alias
Schema interface{} `json:"schema,omitempty"`
}

data.Alias = (*Alias)(c)

if c.SchemaRaw != nil {
var rawSchema interface{}
if err := json.Unmarshal(*c.SchemaRaw, &rawSchema); err != nil {
return nil, err
}
data.Schema = rawSchema
} else {
data.Schema = c.Schema
}

return json.Marshal(data)
}

// ChatCompletionRequest represents a request structure for chat completion API.
type ChatCompletionRequest struct {
Model string `json:"model"`
Expand Down
62 changes: 62 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,3 +527,65 @@ func TestFinishReason(t *testing.T) {
}
}
}

func TestChatCompletionResponseFormatJSONSchemaMarshalJSON(t *testing.T) {
tests := []struct {
name string
input openai.ChatCompletionResponseFormatJSONSchema
expected string
wantErr bool
}{
{
name: "Empty Schema and SchemaRaw",
input: openai.ChatCompletionResponseFormatJSONSchema{
Name: "TestName",
SchemaRaw: nil,
Strict: false,
},
expected: `{"name":"TestName","strict":false,"schema":{}}`,
wantErr: false,
},
{
name: "Non-empty SchemaRaw",
input: openai.ChatCompletionResponseFormatJSONSchema{
Name: "TestName",
SchemaRaw: func() *[]byte { b := []byte(`{"key":"value"}`); return &b }(),
Strict: true,
},
expected: `{"name":"TestName","strict":true,"schema":{"key":"value"}}`,
wantErr: false,
},
{
name: "Invalid SchemaRaw JSON",
input: openai.ChatCompletionResponseFormatJSONSchema{
Name: "TestName",
SchemaRaw: func() *[]byte { b := []byte(`{key:value}`); return &b }(),
Strict: true,
},
expected: "",
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.input.MarshalJSON()

if (err != nil) != tt.wantErr {
t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}

if tt.wantErr {
if len(got) != 0 {
t.Errorf("Expected empty output on error, got: %s", string(got))
}
return
}

if string(got) != tt.expected {
t.Errorf("MarshalJSON() got = %s, expected %s", string(got), tt.expected)
}
})
}
}

0 comments on commit 03508ae

Please sign in to comment.