Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve handling of JSON Schema in OpenAI API Response Context #819

Merged
merged 21 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package openai_test

import (
"context"
"encoding/json"
"errors"
"io"
"os"
Expand Down Expand Up @@ -190,6 +189,14 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
c := openai.NewClient(apiToken)
ctx := context.Background()

type MyStructuredResponse struct {
PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"`
CamelCase string `json:"camel_case" required:"true" description:"CamelCase"`
KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"`
SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"`
Keywords []string `json:"keywords" description:"Keywords" required:"true"`
}
schema := jsonschema.Warp(MyStructuredResponse{})
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great stuff, I guess we should put a proper README example on how to use this. That should be one of the top examples for sure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added an example use case from OpenAI documentation to the README. For reference, please see: https://platform.openai.com/docs/guides/structured-outputs/examples

resp, err := c.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Expand All @@ -211,31 +218,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
},
ResponseFormat: &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: openai.ChatCompletionResponseFormatJSONSchema{
Name: "cases",
Schema: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"PascalCase": jsonschema.Definition{Type: jsonschema.String},
"CamelCase": jsonschema.Definition{Type: jsonschema.String},
"KebabCase": jsonschema.Definition{Type: jsonschema.String},
"SnakeCase": jsonschema.Definition{Type: jsonschema.String},
},
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
AdditionalProperties: false,
},
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
Name: "cases",
Schema: schema,
Strict: true,
},
},
},
)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error")
var result = make(map[string]string)
err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
if _, ok := result[key]; !ok {
t.Errorf("key:%s does not exist.", key)
}
if err == nil {
_, err = schema.Unmarshal(resp.Choices[0].Message.Content)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
}
}
10 changes: 4 additions & 6 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"encoding/json"
"errors"
"net/http"

"github.com/sashabaranov/go-openai/jsonschema"
)

// Chat message role defined by the OpenAI API.
Expand Down Expand Up @@ -187,10 +185,10 @@ type ChatCompletionResponseFormat struct {
}

type ChatCompletionResponseFormatJSONSchema struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema jsonschema.Definition `json:"schema"`
Strict bool `json:"strict"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema json.Marshaler `json:"schema"`
Strict bool `json:"strict"`
}

// ChatCompletionRequest represents a request structure for chat completion API.
Expand Down
87 changes: 86 additions & 1 deletion jsonschema/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
// and/or pass in the schema in []byte format.
package jsonschema

import "encoding/json"
import (
"encoding/json"
"reflect"
"strconv"
)

type DataType string

Expand Down Expand Up @@ -53,3 +57,84 @@
Alias: (Alias)(d),
})
}

type SchemaWrapper[T any] struct {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pardon my ignorance: do we even need this type?

Copy link
Contributor Author

@eiixy eiixy Aug 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SchemaWrapper type is intended to bind the Definition and type together to ensure consistency during unmarshalling. This way, the schema’s type remains aligned with the initial definition, reducing the risk of errors or mismatches during data handling.

sw, err := jsonschema.Wrap(MyStructuredResponse{})
result, err := sw.Unmarshal(`{...}`)

or

var result MyStructuredResponse{}
schema, err := jsonschema.GenerateSchemaForType(result)
schema.Unmarshal(`{...}`, &result)

data T
schema Definition
}

func (r SchemaWrapper[T]) Schema() Definition {
return r.schema
}

func (r SchemaWrapper[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(r.schema)
}

func (r SchemaWrapper[T]) Unmarshal(content string) (*T, error) {
var v T
err := Unmarshal(r.schema, []byte(content), &v)
if err != nil {
return nil, err
}
return &v, nil
}

func (r SchemaWrapper[T]) String() string {
bytes, _ := json.MarshalIndent(r.schema, "", " ")
return string(bytes)
}

func Warp[T any](v T) SchemaWrapper[T] {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eiixy heads up: I think you meant Wrap instead of Warp here.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or even better: GenerateSchemaForType

return SchemaWrapper[T]{
data: v,
schema: reflectSchema(reflect.TypeOf(v)),
}
}

func reflectSchema(t reflect.Type) Definition {
var d Definition
switch t.Kind() {

Check failure on line 97 in jsonschema/json.go

View workflow job for this annotation

GitHub Actions / Sanity check

missing cases in switch of type reflect.Kind: reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer|reflect.Ptr, reflect.UnsafePointer (exhaustive)
case reflect.String:
d.Type = String
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
d.Type = Integer
case reflect.Float32, reflect.Float64:
d.Type = Number
case reflect.Bool:
d.Type = Boolean
case reflect.Slice, reflect.Array:
d.Type = Array
items := reflectSchema(t.Elem())
d.Items = &items
case reflect.Struct:
d.Type = Object
d.AdditionalProperties = false
properties := make(map[string]Definition)
var requiredFields []string
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
jsonTag := field.Tag.Get("json")
if jsonTag == "" {
jsonTag = field.Name
}

item := reflectSchema(field.Type)
description := field.Tag.Get("description")
if description != "" {
item.Description = description
}
properties[jsonTag] = item

required, _ := strconv.ParseBool(field.Tag.Get("required"))
if required {
requiredFields = append(requiredFields, jsonTag)
}
}
d.Required = requiredFields
d.Properties = properties
default:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest we either return an error or panic here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestion!

}
return d
}
67 changes: 67 additions & 0 deletions jsonschema/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,70 @@ func structToMap(t *testing.T, v any) map[string]any {
}
return got
}

type MyStructuredResponse struct {
PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"`
CamelCase string `json:"camel_case" required:"true" description:"CamelCase"`
KebabCase string `json:"kebab_case" required:"false" description:"KebabCase"`
SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"`
Keywords []string `json:"keywords" description:"Keywords" required:"true"`
}

func TestWarp(t *testing.T) {
schemaStr := `{
"type": "object",
"properties": {
"camel_case": {
"type": "string",
"description": "CamelCase"
},
"kebab_case": {
"type": "string",
"description": "KebabCase"
},
"keywords": {
"type": "array",
"description": "Keywords",
"items": {
"type": "string"
}
},
"pascal_case": {
"type": "string",
"description": "PascalCase"
},
"snake_case": {
"type": "string",
"description": "SnakeCase"
}
},
"required": [
"pascal_case",
"camel_case",
"snake_case",
"keywords"
]
}`
schema := jsonschema.Warp(MyStructuredResponse{})
if schema.String() == schemaStr {
t.Errorf("Failed to Generate JSONSchema: schema = %s", schema)
}
}

func TestSchemaWrapper_Unmarshal(t *testing.T) {
schema := jsonschema.Warp(MyStructuredResponse{})
result, err := schema.Unmarshal(`{"pascal_case":"a","camel_case":"b","snake_case":"c","keywords":[]}`)
if err != nil {
t.Errorf("Failed to SchemaWrapper Unmarshal: error = %v", err)
} else {
var v = MyStructuredResponse{
PascalCase: "a",
CamelCase: "b",
SnakeCase: "c",
Keywords: []string{},
}
if !reflect.DeepEqual(*result, v) {
t.Errorf("Failed to SchemaWrapper Unmarshal: result = %v", *result)
}
}
}
Loading