diff --git a/README.md b/README.md index 799dc602b..0d6aafa40 100644 --- a/README.md +++ b/README.md @@ -743,6 +743,70 @@ func main() { } ``` + +
+Structured Outputs + +```go +package main + +import ( + "context" + "fmt" + "log" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + type Result struct { + Steps []struct { + Explanation string `json:"explanation"` + Output string `json:"output"` + } `json:"steps"` + FinalAnswer string `json:"final_answer"` + } + var result Result + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + log.Fatalf("GenerateSchemaForType error: %v", err) + } + resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "You are a helpful math tutor. Guide the user through the solution step by step.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "how can I solve 8x + 7 = -23", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "math_reasoning", + Schema: schema, + Strict: true, + }, + }, + }) + if err != nil { + log.Fatalf("CreateChatCompletion error: %v", err) + } + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + if err != nil { + log.Fatalf("Unmarshal schema error: %v", err) + } + fmt.Println(result) +} +``` +
See the `examples/` folder for more. ## Frequently Asked Questions diff --git a/api_integration_test.go b/api_integration_test.go index 57f7c40fb..8c9f3384f 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,7 +4,6 @@ package openai_test import ( "context" - "encoding/json" "errors" "io" "os" @@ -190,6 +189,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { c := openai.NewClient(apiToken) ctx := context.Background() + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + var result MyStructuredResponse + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error") + } resp, err := c.CreateChatCompletion( ctx, openai.ChatCompletionRequest{ @@ -212,31 +222,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { ResponseFormat: &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ - Name: "cases", - Schema: jsonschema.Definition{ - Type: jsonschema.Object, - Properties: map[string]jsonschema.Definition{ - "PascalCase": jsonschema.Definition{Type: jsonschema.String}, - "CamelCase": jsonschema.Definition{Type: jsonschema.String}, - "KebabCase": jsonschema.Definition{Type: jsonschema.String}, - "SnakeCase": jsonschema.Definition{Type: jsonschema.String}, - }, - Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, - AdditionalProperties: false, - }, + Name: "cases", + Schema: schema, Strict: true, }, }, }, ) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") - var result = make(map[string]string) - err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result) - checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") - for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { - if _, ok := result[key]; !ok { - t.Errorf("key:%s does not exist.", key) - } + if err == nil { + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") } } diff --git a/chat.go b/chat.go index 97c89a497..56e99a78b 100644 --- a/chat.go +++ b/chat.go @@ -5,8 +5,6 @@ import ( "encoding/json" "errors" "net/http" - - "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -187,10 +185,10 @@ type ChatCompletionResponseFormat struct { } type ChatCompletionResponseFormatJSONSchema struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Schema jsonschema.Definition `json:"schema"` - Strict bool `json:"strict"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.Marshaler `json:"schema"` + Strict bool `json:"strict"` } // ChatCompletionRequest represents a request structure for chat completion API. diff --git a/example_test.go b/example_test.go index 1bdb8496e..e5dbf44bf 100644 --- a/example_test.go +++ b/example_test.go @@ -59,7 +59,7 @@ func ExampleClient_CreateChatCompletionStream() { } defer stream.Close() - fmt.Printf("Stream response: ") + fmt.Print("Stream response: ") for { var response openai.ChatCompletionStreamResponse response, err = stream.Recv() diff --git a/jsonschema/json.go b/jsonschema/json.go index 7fd1e11bf..bcb253fae 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -4,7 +4,13 @@ // and/or pass in the schema in []byte format. package jsonschema -import "encoding/json" +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" +) type DataType string @@ -42,7 +48,7 @@ type Definition struct { AdditionalProperties any `json:"additionalProperties,omitempty"` } -func (d Definition) MarshalJSON() ([]byte, error) { +func (d *Definition) MarshalJSON() ([]byte, error) { if d.Properties == nil { d.Properties = make(map[string]Definition) } @@ -50,6 +56,99 @@ func (d Definition) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Alias }{ - Alias: (Alias)(d), + Alias: (Alias)(*d), }) } + +func (d *Definition) Unmarshal(content string, v any) error { + return VerifySchemaAndUnmarshal(*d, []byte(content), v) +} + +func GenerateSchemaForType(v any) (*Definition, error) { + return reflectSchema(reflect.TypeOf(v)) +} + +func reflectSchema(t reflect.Type) (*Definition, error) { + var d Definition + switch t.Kind() { + case reflect.String: + d.Type = String + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + d.Type = Integer + case reflect.Float32, reflect.Float64: + d.Type = Number + case reflect.Bool: + d.Type = Boolean + case reflect.Slice, reflect.Array: + d.Type = Array + items, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d.Items = items + case reflect.Struct: + d.Type = Object + d.AdditionalProperties = false + object, err := reflectSchemaObject(t) + if err != nil { + return nil, err + } + d = *object + case reflect.Ptr: + definition, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d = *definition + case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, + reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.UnsafePointer: + return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) + default: + } + return &d, nil +} + +func reflectSchemaObject(t reflect.Type) (*Definition, error) { + var d = Definition{ + Type: Object, + AdditionalProperties: false, + } + properties := make(map[string]Definition) + var requiredFields []string + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + jsonTag := field.Tag.Get("json") + var required = true + if jsonTag == "" { + jsonTag = field.Name + } else if strings.HasSuffix(jsonTag, ",omitempty") { + jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") + required = false + } + + item, err := reflectSchema(field.Type) + if err != nil { + return nil, err + } + description := field.Tag.Get("description") + if description != "" { + item.Description = description + } + properties[jsonTag] = *item + + if s := field.Tag.Get("required"); s != "" { + required, _ = strconv.ParseBool(s) + } + if required { + requiredFields = append(requiredFields, jsonTag) + } + } + d.Required = requiredFields + d.Properties = properties + return &d, nil +} diff --git a/jsonschema/validate.go b/jsonschema/validate.go new file mode 100644 index 000000000..f14ffd4c4 --- /dev/null +++ b/jsonschema/validate.go @@ -0,0 +1,89 @@ +package jsonschema + +import ( + "encoding/json" + "errors" +) + +func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { + var data any + err := json.Unmarshal(content, &data) + if err != nil { + return err + } + if !Validate(schema, data) { + return errors.New("data validation failed against the provided schema") + } + return json.Unmarshal(content, &v) +} + +func Validate(schema Definition, data any) bool { + switch schema.Type { + case Object: + return validateObject(schema, data) + case Array: + return validateArray(schema, data) + case String: + _, ok := data.(string) + return ok + case Number: // float64 and int + _, ok := data.(float64) + if !ok { + _, ok = data.(int) + } + return ok + case Boolean: + _, ok := data.(bool) + return ok + case Integer: + _, ok := data.(int) + return ok + case Null: + return data == nil + default: + return false + } +} + +func validateObject(schema Definition, data any) bool { + dataMap, ok := data.(map[string]any) + if !ok { + return false + } + for _, field := range schema.Required { + if _, exists := dataMap[field]; !exists { + return false + } + } + for key, valueSchema := range schema.Properties { + value, exists := dataMap[key] + if exists && !Validate(valueSchema, value) { + return false + } else if !exists && contains(schema.Required, key) { + return false + } + } + return true +} + +func validateArray(schema Definition, data any) bool { + dataArray, ok := data.([]any) + if !ok { + return false + } + for _, item := range dataArray { + if !Validate(*schema.Items, item) { + return false + } + } + return true +} + +func contains[S ~[]E, E comparable](s S, v E) bool { + for i := range s { + if v == s[i] { + return true + } + } + return false +} diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go new file mode 100644 index 000000000..c2c47a2ce --- /dev/null +++ b/jsonschema/validate_test.go @@ -0,0 +1,136 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +func Test_Validate(t *testing.T) { + type args struct { + data any + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want bool + }{ + // string integer number boolean + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.String}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.String}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Integer}}, true}, + {"", args{data: 123.4, schema: jsonschema.Definition{Type: jsonschema.Integer}}, false}, + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.Number}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Number}}, true}, + {"", args{data: false, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, false}, + {"", args{data: nil, schema: jsonschema.Definition{Type: jsonschema.Null}}, true}, + {"", args{data: 0, schema: jsonschema.Definition{Type: jsonschema.Null}}, false}, + // array + {"", args{data: []any{"a", "b", "c"}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, true}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, false}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, true}, + {"", args{data: []any{1, 2, 3.4}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, false}, + // object + {"", args{data: map[string]any{ + "string": "abc", + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, true}, + {"", args{data: map[string]any{ + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := jsonschema.Validate(tt.args.schema, tt.args.data); got != tt.want { + t.Errorf("Validate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnmarshal(t *testing.T) { + type args struct { + schema jsonschema.Definition + content []byte + v any + } + var result1 struct { + String string `json:"string"` + Number float64 `json:"number"` + } + var result2 struct { + String string `json:"string"` + Number float64 `json:"number"` + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + }, + content: []byte(`{"string":"abc","number":123.4}`), + v: &result1, + }, false}, + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + Required: []string{"string", "number"}, + }, + content: []byte(`{"string":"abc"}`), + v: result2, + }, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) + if (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } else if err == nil { + t.Logf("Unmarshal() v = %+v\n", tt.args.v) + } + }) + } +}