From 47856de229ca71edb9528ba2c91c183aa990b094 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 00:12:55 +0800 Subject: [PATCH 01/19] feat: add jsonschema.Validate and jsonschema.Unmarshal --- jsonschema/validate.go | 73 +++++++++++++++++++++ jsonschema/validate_test.go | 125 ++++++++++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 jsonschema/validate.go create mode 100644 jsonschema/validate_test.go diff --git a/jsonschema/validate.go b/jsonschema/validate.go new file mode 100644 index 000000000..1608375ae --- /dev/null +++ b/jsonschema/validate.go @@ -0,0 +1,73 @@ +package jsonschema + +import ( + "encoding/json" + "errors" + "slices" +) + +func Unmarshal(schema Definition, content []byte, v any) error { + var data any + err := json.Unmarshal(content, &data) + if err != nil { + return err + } + if !Validate(schema, data) { + return errors.New("validate failed") + } + return json.Unmarshal(content, &v) +} + +func Validate(schema Definition, data interface{}) bool { + switch schema.Type { + case Object: + dataMap, ok := data.(map[string]interface{}) + if !ok { + return false + } + for _, field := range schema.Required { + if _, exists := dataMap[field]; !exists { + return false + } + } + for key, valueSchema := range schema.Properties { + value, exists := dataMap[key] + if exists && !Validate(valueSchema, value) { + return false + } else if !exists && slices.Contains(schema.Required, key) { + return false + } + } + return true + case Array: + dataArray, ok := data.([]interface{}) + if !ok { + return false + } + for _, item := range dataArray { + if !Validate(*schema.Items, item) { + return false + } + } + return true + case String: + _, ok := data.(string) + return ok + case Number: // float64 and int + _, ok := data.(float64) + if !ok { + _, ok = data.(int) + } + return ok + case Boolean: + _, ok := data.(bool) + return ok + case Integer: + _, ok := data.(int) + return ok + case Null: + return data == nil + default: + return false + } +} diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go new file mode 100644 index 000000000..a7d18a093 --- /dev/null +++ b/jsonschema/validate_test.go @@ -0,0 +1,125 @@ +package jsonschema + +import ( + "testing" +) + +func Test_Validate(t *testing.T) { + type args struct { + data interface{} + schema Definition + } + tests := []struct { + name string + args args + want bool + }{ + // string integer number boolean + {"", args{data: "ABC", schema: Definition{Type: String}}, true}, + {"", args{data: 123, schema: Definition{Type: String}}, false}, + {"", args{data: 123, schema: Definition{Type: Integer}}, true}, + {"", args{data: 123.4, schema: Definition{Type: Integer}}, false}, + {"", args{data: "ABC", schema: Definition{Type: Number}}, false}, + {"", args{data: 123, schema: Definition{Type: Number}}, true}, + {"", args{data: false, schema: Definition{Type: Boolean}}, true}, + {"", args{data: 123, schema: Definition{Type: Boolean}}, false}, + {"", args{data: nil, schema: Definition{Type: Null}}, true}, + {"", args{data: 0, schema: Definition{Type: Null}}, false}, + // array + {"", args{data: []any{"a", "b", "c"}, schema: Definition{Type: Array, Items: &Definition{Type: String}}}, true}, + {"", args{data: []any{1, 2, 3}, schema: Definition{Type: Array, Items: &Definition{Type: String}}}, false}, + {"", args{data: []any{1, 2, 3}, schema: Definition{Type: Array, Items: &Definition{Type: Integer}}}, true}, + {"", args{data: []any{1, 2, 3.4}, schema: Definition{Type: Array, Items: &Definition{Type: Integer}}}, false}, + // object + {"", args{data: map[string]any{ + "string": "abc", + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: Definition{Type: Object, Properties: map[string]Definition{ + "string": {Type: String}, + "integer": {Type: Integer}, + "number": {Type: Number}, + "boolean": {Type: Boolean}, + "array": {Type: Array, Items: &Definition{Type: Number}}, + }, + Required: []string{"string"}, + }}, true}, + {"", args{data: map[string]any{ + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: Definition{Type: Object, Properties: map[string]Definition{ + "string": {Type: String}, + "integer": {Type: Integer}, + "number": {Type: Number}, + "boolean": {Type: Boolean}, + "array": {Type: Array, Items: &Definition{Type: Number}}, + }, + Required: []string{"string"}, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Validate(tt.args.schema, tt.args.data); got != tt.want { + t.Errorf("Validate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnmarshal(t *testing.T) { + type args struct { + schema Definition + content []byte + v any + } + var result1 struct { + String string `json:"string"` + Number float64 `json:"number"` + } + var result2 struct { + String string `json:"string"` + Number float64 `json:"number"` + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"", args{ + schema: Definition{ + Type: Object, + Properties: map[string]Definition{ + "string": {Type: String}, + "number": {Type: Number}, + }, + }, + content: []byte(`{"string":"abc","number":123.4}`), + v: &result1, + }, false}, + {"", args{ + schema: Definition{ + Type: Object, + Properties: map[string]Definition{ + "string": {Type: String}, + "number": {Type: Number}, + }, + Required: []string{"string", "number"}, + }, + content: []byte(`{"string":"abc"}`), + v: result2, + }, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Unmarshal(tt.args.schema, tt.args.content, tt.args.v); (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } else if err == nil { + t.Logf("Unmarshal() v = %+v\n", tt.args.v) + } + }) + } +} From 0e894072fce3b85f520e4662ad7a6665f8fdfb61 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 00:22:18 +0800 Subject: [PATCH 02/19] fix Sanity check --- jsonschema/validate.go | 64 ++++++++++++++------------ jsonschema/validate_test.go | 91 ++++++++++++++++++++----------------- 2 files changed, 86 insertions(+), 69 deletions(-) diff --git a/jsonschema/validate.go b/jsonschema/validate.go index 1608375ae..af884333b 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -21,35 +21,9 @@ func Unmarshal(schema Definition, content []byte, v any) error { func Validate(schema Definition, data interface{}) bool { switch schema.Type { case Object: - dataMap, ok := data.(map[string]interface{}) - if !ok { - return false - } - for _, field := range schema.Required { - if _, exists := dataMap[field]; !exists { - return false - } - } - for key, valueSchema := range schema.Properties { - value, exists := dataMap[key] - if exists && !Validate(valueSchema, value) { - return false - } else if !exists && slices.Contains(schema.Required, key) { - return false - } - } - return true + return validateObject(schema, data) case Array: - dataArray, ok := data.([]interface{}) - if !ok { - return false - } - for _, item := range dataArray { - if !Validate(*schema.Items, item) { - return false - } - } - return true + return validateArray(schema, data) case String: _, ok := data.(string) return ok @@ -71,3 +45,37 @@ func Validate(schema Definition, data interface{}) bool { return false } } + +func validateObject(schema Definition, data any) bool { + dataMap, ok := data.(map[string]any) + if !ok { + return false + } + for _, field := range schema.Required { + if _, exists := dataMap[field]; !exists { + return false + } + } + for key, valueSchema := range schema.Properties { + value, exists := dataMap[key] + if exists && !Validate(valueSchema, value) { + return false + } else if !exists && slices.Contains(schema.Required, key) { + return false + } + } + return true +} + +func validateArray(schema Definition, data any) bool { + dataArray, ok := data.([]interface{}) + if !ok { + return false + } + for _, item := range dataArray { + if !Validate(*schema.Items, item) { + return false + } + } + return true +} diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index a7d18a093..9f8928039 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -1,13 +1,14 @@ -package jsonschema +package jsonschema_test import ( + "github.com/sashabaranov/go-openai/jsonschema" "testing" ) func Test_Validate(t *testing.T) { type args struct { data interface{} - schema Definition + schema jsonschema.Definition } tests := []struct { name string @@ -15,21 +16,29 @@ func Test_Validate(t *testing.T) { want bool }{ // string integer number boolean - {"", args{data: "ABC", schema: Definition{Type: String}}, true}, - {"", args{data: 123, schema: Definition{Type: String}}, false}, - {"", args{data: 123, schema: Definition{Type: Integer}}, true}, - {"", args{data: 123.4, schema: Definition{Type: Integer}}, false}, - {"", args{data: "ABC", schema: Definition{Type: Number}}, false}, - {"", args{data: 123, schema: Definition{Type: Number}}, true}, - {"", args{data: false, schema: Definition{Type: Boolean}}, true}, - {"", args{data: 123, schema: Definition{Type: Boolean}}, false}, - {"", args{data: nil, schema: Definition{Type: Null}}, true}, - {"", args{data: 0, schema: Definition{Type: Null}}, false}, + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.String}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.String}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Integer}}, true}, + {"", args{data: 123.4, schema: jsonschema.Definition{Type: jsonschema.Integer}}, false}, + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.Number}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Number}}, true}, + {"", args{data: false, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, false}, + {"", args{data: nil, schema: jsonschema.Definition{Type: jsonschema.Null}}, true}, + {"", args{data: 0, schema: jsonschema.Definition{Type: jsonschema.Null}}, false}, // array - {"", args{data: []any{"a", "b", "c"}, schema: Definition{Type: Array, Items: &Definition{Type: String}}}, true}, - {"", args{data: []any{1, 2, 3}, schema: Definition{Type: Array, Items: &Definition{Type: String}}}, false}, - {"", args{data: []any{1, 2, 3}, schema: Definition{Type: Array, Items: &Definition{Type: Integer}}}, true}, - {"", args{data: []any{1, 2, 3.4}, schema: Definition{Type: Array, Items: &Definition{Type: Integer}}}, false}, + {"", args{data: []any{"a", "b", "c"}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, true}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, false}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, true}, + {"", args{data: []any{1, 2, 3.4}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, false}, // object {"", args{data: map[string]any{ "string": "abc", @@ -37,12 +46,12 @@ func Test_Validate(t *testing.T) { "number": 123.4, "boolean": false, "array": []any{1, 2, 3}, - }, schema: Definition{Type: Object, Properties: map[string]Definition{ - "string": {Type: String}, - "integer": {Type: Integer}, - "number": {Type: Number}, - "boolean": {Type: Boolean}, - "array": {Type: Array, Items: &Definition{Type: Number}}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, }, Required: []string{"string"}, }}, true}, @@ -51,19 +60,19 @@ func Test_Validate(t *testing.T) { "number": 123.4, "boolean": false, "array": []any{1, 2, 3}, - }, schema: Definition{Type: Object, Properties: map[string]Definition{ - "string": {Type: String}, - "integer": {Type: Integer}, - "number": {Type: Number}, - "boolean": {Type: Boolean}, - "array": {Type: Array, Items: &Definition{Type: Number}}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, }, Required: []string{"string"}, }}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := Validate(tt.args.schema, tt.args.data); got != tt.want { + if got := jsonschema.Validate(tt.args.schema, tt.args.data); got != tt.want { t.Errorf("Validate() = %v, want %v", got, tt.want) } }) @@ -72,7 +81,7 @@ func Test_Validate(t *testing.T) { func TestUnmarshal(t *testing.T) { type args struct { - schema Definition + schema jsonschema.Definition content []byte v any } @@ -90,22 +99,22 @@ func TestUnmarshal(t *testing.T) { wantErr bool }{ {"", args{ - schema: Definition{ - Type: Object, - Properties: map[string]Definition{ - "string": {Type: String}, - "number": {Type: Number}, + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, }, }, content: []byte(`{"string":"abc","number":123.4}`), v: &result1, }, false}, {"", args{ - schema: Definition{ - Type: Object, - Properties: map[string]Definition{ - "string": {Type: String}, - "number": {Type: Number}, + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, }, Required: []string{"string", "number"}, }, @@ -115,7 +124,7 @@ func TestUnmarshal(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := Unmarshal(tt.args.schema, tt.args.content, tt.args.v); (err != nil) != tt.wantErr { + if err := jsonschema.Unmarshal(tt.args.schema, tt.args.content, tt.args.v); (err != nil) != tt.wantErr { t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) } else if err == nil { t.Logf("Unmarshal() v = %+v\n", tt.args.v) From 545ee4b600299711ba47f75bde9272f28399e96b Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 00:25:52 +0800 Subject: [PATCH 03/19] remove slices.Contains --- jsonschema/validate.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/jsonschema/validate.go b/jsonschema/validate.go index af884333b..63902356b 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -3,7 +3,6 @@ package jsonschema import ( "encoding/json" "errors" - "slices" ) func Unmarshal(schema Definition, content []byte, v any) error { @@ -60,7 +59,7 @@ func validateObject(schema Definition, data any) bool { value, exists := dataMap[key] if exists && !Validate(valueSchema, value) { return false - } else if !exists && slices.Contains(schema.Required, key) { + } else if !exists && contains(schema.Required, key) { return false } } @@ -79,3 +78,12 @@ func validateArray(schema Definition, data any) bool { } return true } + +func contains[S ~[]E, E comparable](s S, v E) bool { + for i := range s { + if v == s[i] { + return true + } + } + return false +} From 8bb6204b38cdc0320ad050109062074f682b5373 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 00:27:23 +0800 Subject: [PATCH 04/19] fix Sanity check --- jsonschema/validate_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index 9f8928039..595fb3997 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -1,8 +1,9 @@ package jsonschema_test import ( - "github.com/sashabaranov/go-openai/jsonschema" "testing" + + "github.com/sashabaranov/go-openai/jsonschema" ) func Test_Validate(t *testing.T) { From 8b49a3652c807ccf7232865c296f2f0e2d56b290 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 14:59:18 +0800 Subject: [PATCH 05/19] add SchemaWrapper --- api_integration_test.go | 35 +++++++---------- chat.go | 10 ++--- jsonschema/json.go | 87 ++++++++++++++++++++++++++++++++++++++++- jsonschema/json_test.go | 67 +++++++++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 28 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index a487f588a..744c21a03 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,7 +4,6 @@ package openai_test import ( "context" - "encoding/json" "errors" "io" "os" @@ -190,6 +189,14 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { c := openai.NewClient(apiToken) ctx := context.Background() + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + Keywords []string `json:"keywords" description:"Keywords" required:"true"` + } + schema := jsonschema.Warp(MyStructuredResponse{}) resp, err := c.CreateChatCompletion( ctx, openai.ChatCompletionRequest{ @@ -211,31 +218,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { }, ResponseFormat: &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, - JSONSchema: openai.ChatCompletionResponseFormatJSONSchema{ - Name: "cases", - Schema: jsonschema.Definition{ - Type: jsonschema.Object, - Properties: map[string]jsonschema.Definition{ - "PascalCase": jsonschema.Definition{Type: jsonschema.String}, - "CamelCase": jsonschema.Definition{Type: jsonschema.String}, - "KebabCase": jsonschema.Definition{Type: jsonschema.String}, - "SnakeCase": jsonschema.Definition{Type: jsonschema.String}, - }, - Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, - AdditionalProperties: false, - }, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "cases", + Schema: schema, Strict: true, }, }, }, ) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") - var result = make(map[string]string) - err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result) - checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") - for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { - if _, ok := result[key]; !ok { - t.Errorf("key:%s does not exist.", key) - } + if err == nil { + _, err = schema.Unmarshal(resp.Choices[0].Message.Content) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") } } diff --git a/chat.go b/chat.go index 31fa887d6..83255b390 100644 --- a/chat.go +++ b/chat.go @@ -5,8 +5,6 @@ import ( "encoding/json" "errors" "net/http" - - "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -187,10 +185,10 @@ type ChatCompletionResponseFormat struct { } type ChatCompletionResponseFormatJSONSchema struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Schema jsonschema.Definition `json:"schema"` - Strict bool `json:"strict"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.Marshaler `json:"schema"` + Strict bool `json:"strict"` } // ChatCompletionRequest represents a request structure for chat completion API. diff --git a/jsonschema/json.go b/jsonschema/json.go index 7fd1e11bf..f41add9ee 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -4,7 +4,11 @@ // and/or pass in the schema in []byte format. package jsonschema -import "encoding/json" +import ( + "encoding/json" + "reflect" + "strconv" +) type DataType string @@ -53,3 +57,84 @@ func (d Definition) MarshalJSON() ([]byte, error) { Alias: (Alias)(d), }) } + +type SchemaWrapper[T any] struct { + data T + schema Definition +} + +func (r SchemaWrapper[T]) Schema() Definition { + return r.schema +} + +func (r SchemaWrapper[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(r.schema) +} + +func (r SchemaWrapper[T]) Unmarshal(content string) (*T, error) { + var v T + err := Unmarshal(r.schema, []byte(content), &v) + if err != nil { + return nil, err + } + return &v, nil +} + +func (r SchemaWrapper[T]) String() string { + bytes, _ := json.MarshalIndent(r.schema, "", " ") + return string(bytes) +} + +func Warp[T any](v T) SchemaWrapper[T] { + return SchemaWrapper[T]{ + data: v, + schema: reflectSchema(reflect.TypeOf(v)), + } +} + +func reflectSchema(t reflect.Type) Definition { + var d Definition + switch t.Kind() { + case reflect.String: + d.Type = String + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + d.Type = Integer + case reflect.Float32, reflect.Float64: + d.Type = Number + case reflect.Bool: + d.Type = Boolean + case reflect.Slice, reflect.Array: + d.Type = Array + items := reflectSchema(t.Elem()) + d.Items = &items + case reflect.Struct: + d.Type = Object + d.AdditionalProperties = false + properties := make(map[string]Definition) + var requiredFields []string + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag == "" { + jsonTag = field.Name + } + + item := reflectSchema(field.Type) + description := field.Tag.Get("description") + if description != "" { + item.Description = description + } + properties[jsonTag] = item + + required, _ := strconv.ParseBool(field.Tag.Get("required")) + if required { + requiredFields = append(requiredFields, jsonTag) + } + } + d.Required = requiredFields + d.Properties = properties + default: + } + return d +} diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 744706082..0b3b3805c 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -209,3 +209,70 @@ func structToMap(t *testing.T, v any) map[string]any { } return got } + +type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"false" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + Keywords []string `json:"keywords" description:"Keywords" required:"true"` +} + +func TestWarp(t *testing.T) { + schemaStr := `{ + "type": "object", + "properties": { + "camel_case": { + "type": "string", + "description": "CamelCase" + }, + "kebab_case": { + "type": "string", + "description": "KebabCase" + }, + "keywords": { + "type": "array", + "description": "Keywords", + "items": { + "type": "string" + } + }, + "pascal_case": { + "type": "string", + "description": "PascalCase" + }, + "snake_case": { + "type": "string", + "description": "SnakeCase" + } + }, + "required": [ + "pascal_case", + "camel_case", + "snake_case", + "keywords" + ] +}` + schema := jsonschema.Warp(MyStructuredResponse{}) + if schema.String() == schemaStr { + t.Errorf("Failed to Generate JSONSchema: schema = %s", schema) + } +} + +func TestSchemaWrapper_Unmarshal(t *testing.T) { + schema := jsonschema.Warp(MyStructuredResponse{}) + result, err := schema.Unmarshal(`{"pascal_case":"a","camel_case":"b","snake_case":"c","keywords":[]}`) + if err != nil { + t.Errorf("Failed to SchemaWrapper Unmarshal: error = %v", err) + } else { + var v = MyStructuredResponse{ + PascalCase: "a", + CamelCase: "b", + SnakeCase: "c", + Keywords: []string{}, + } + if !reflect.DeepEqual(*result, v) { + t.Errorf("Failed to SchemaWrapper Unmarshal: result = %v", *result) + } + } +} From cdb007508c7064b60aea0456a4cd4871a6f48711 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 15:01:45 +0800 Subject: [PATCH 06/19] update api_integration_test.go --- api_integration_test.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 744c21a03..3785ce6ad 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -190,11 +190,10 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { ctx := context.Background() type MyStructuredResponse struct { - PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` - CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` - KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` - SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` - Keywords []string `json:"keywords" description:"Keywords" required:"true"` + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` } schema := jsonschema.Warp(MyStructuredResponse{}) resp, err := c.CreateChatCompletion( From 1bf03a8b58e9f767d5b6579d3ce42d4fed9c18b7 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 16:03:23 +0800 Subject: [PATCH 07/19] update method 'reflectSchema' to support 'omitempty' in JSON tag --- jsonschema/json.go | 9 ++++++++- jsonschema/json_test.go | 18 +++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index f41add9ee..ccc560087 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -8,6 +8,7 @@ import ( "encoding/json" "reflect" "strconv" + "strings" ) type DataType string @@ -116,8 +117,12 @@ func reflectSchema(t reflect.Type) Definition { for i := 0; i < t.NumField(); i++ { field := t.Field(i) jsonTag := field.Tag.Get("json") + var required = true if jsonTag == "" { jsonTag = field.Name + } else if strings.HasSuffix(jsonTag, ",omitempty") { + jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") + required = false } item := reflectSchema(field.Type) @@ -127,7 +132,9 @@ func reflectSchema(t reflect.Type) Definition { } properties[jsonTag] = item - required, _ := strconv.ParseBool(field.Tag.Get("required")) + if s := field.Tag.Get("required"); s != "" { + required, _ = strconv.ParseBool(s) + } if required { requiredFields = append(requiredFields, jsonTag) } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 0b3b3805c..cdef53273 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -211,11 +211,12 @@ func structToMap(t *testing.T, v any) map[string]any { } type MyStructuredResponse struct { - PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + PascalCase string `json:"pascal_case,omitempty" required:"true" description:"PascalCase"` CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` - KebabCase string `json:"kebab_case" required:"false" description:"KebabCase"` + KebabCase string `json:"kebab_case,omitempty" required:"false" description:"KebabCase"` SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` - Keywords []string `json:"keywords" description:"Keywords" required:"true"` + Keywords []string `json:"keywords,omitempty" description:"Keywords"` + Optional bool `json:"optional,omitempty"` } func TestWarp(t *testing.T) { @@ -237,6 +238,9 @@ func TestWarp(t *testing.T) { "type": "string" } }, + "optional": { + "type": "boolean" + }, "pascal_case": { "type": "string", "description": "PascalCase" @@ -249,12 +253,12 @@ func TestWarp(t *testing.T) { "required": [ "pascal_case", "camel_case", - "snake_case", - "keywords" - ] + "snake_case" + ], + "additionalProperties": false }` schema := jsonschema.Warp(MyStructuredResponse{}) - if schema.String() == schemaStr { + if schema.String() != schemaStr { t.Errorf("Failed to Generate JSONSchema: schema = %s", schema) } } From d3fd653019dcdf897c0a8f2c335790b55203076d Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 10:39:38 +0800 Subject: [PATCH 08/19] add GenerateSchemaForType --- api_integration_test.go | 10 +++++++-- jsonschema/json.go | 46 +++++++++++++++++++++++++++-------------- jsonschema/validate.go | 2 +- 3 files changed, 39 insertions(+), 19 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 3785ce6ad..4fe476f95 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -195,7 +195,12 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` } - schema := jsonschema.Warp(MyStructuredResponse{}) + var result MyStructuredResponse + //schema := jsonschema.Warp(result) + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error") + } resp, err := c.CreateChatCompletion( ctx, openai.ChatCompletionRequest{ @@ -227,7 +232,8 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") if err == nil { - _, err = schema.Unmarshal(resp.Choices[0].Message.Content) + //_, err = schema.Unmarshal(resp.Choices[0].Message.Content) + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") } } diff --git a/jsonschema/json.go b/jsonschema/json.go index ccc560087..eeea8dd27 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -6,6 +6,7 @@ package jsonschema import ( "encoding/json" + "fmt" "reflect" "strconv" "strings" @@ -59,13 +60,12 @@ func (d Definition) MarshalJSON() ([]byte, error) { }) } -type SchemaWrapper[T any] struct { - data T - schema Definition +func (d Definition) Unmarshal(content string, v any) error { + return Unmarshal(d, []byte(content), v) } -func (r SchemaWrapper[T]) Schema() Definition { - return r.schema +type SchemaWrapper[T any] struct { + schema *Definition } func (r SchemaWrapper[T]) MarshalJSON() ([]byte, error) { @@ -74,7 +74,7 @@ func (r SchemaWrapper[T]) MarshalJSON() ([]byte, error) { func (r SchemaWrapper[T]) Unmarshal(content string) (*T, error) { var v T - err := Unmarshal(r.schema, []byte(content), &v) + err := Unmarshal(*r.schema, []byte(content), &v) if err != nil { return nil, err } @@ -86,14 +86,21 @@ func (r SchemaWrapper[T]) String() string { return string(bytes) } -func Warp[T any](v T) SchemaWrapper[T] { - return SchemaWrapper[T]{ - data: v, - schema: reflectSchema(reflect.TypeOf(v)), +func Warp[T any](v T) (*SchemaWrapper[T], error) { + schema, err := reflectSchema(reflect.TypeOf(v)) + if err != nil { + return nil, err } + return &SchemaWrapper[T]{ + schema: schema, + }, nil } -func reflectSchema(t reflect.Type) Definition { +func GenerateSchemaForType(v any) (*Definition, error) { + return reflectSchema(reflect.TypeOf(v)) +} + +func reflectSchema(t reflect.Type) (*Definition, error) { var d Definition switch t.Kind() { case reflect.String: @@ -107,8 +114,11 @@ func reflectSchema(t reflect.Type) Definition { d.Type = Boolean case reflect.Slice, reflect.Array: d.Type = Array - items := reflectSchema(t.Elem()) - d.Items = &items + items, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d.Items = items case reflect.Struct: d.Type = Object d.AdditionalProperties = false @@ -125,12 +135,15 @@ func reflectSchema(t reflect.Type) Definition { required = false } - item := reflectSchema(field.Type) + item, err := reflectSchema(field.Type) + if err != nil { + return nil, err + } description := field.Tag.Get("description") if description != "" { item.Description = description } - properties[jsonTag] = item + properties[jsonTag] = *item if s := field.Tag.Get("required"); s != "" { required, _ = strconv.ParseBool(s) @@ -142,6 +155,7 @@ func reflectSchema(t reflect.Type) Definition { d.Required = requiredFields d.Properties = properties default: + return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) } - return d + return &d, nil } diff --git a/jsonschema/validate.go b/jsonschema/validate.go index 63902356b..d5724a1ee 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -12,7 +12,7 @@ func Unmarshal(schema Definition, content []byte, v any) error { return err } if !Validate(schema, data) { - return errors.New("validate failed") + return errors.New("data validation failed against the provided schema") } return json.Unmarshal(content, &v) } From 162bb6a179ffc531af7c2ba745121db1dbf272a1 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 10:43:28 +0800 Subject: [PATCH 09/19] update json_test.go --- jsonschema/json_test.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index cdef53273..8415005f9 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -257,14 +257,20 @@ func TestWarp(t *testing.T) { ], "additionalProperties": false }` - schema := jsonschema.Warp(MyStructuredResponse{}) + schema, err := jsonschema.Warp(MyStructuredResponse{}) + if err != nil { + t.Fatal(err) + } if schema.String() != schemaStr { t.Errorf("Failed to Generate JSONSchema: schema = %s", schema) } } func TestSchemaWrapper_Unmarshal(t *testing.T) { - schema := jsonschema.Warp(MyStructuredResponse{}) + schema, err := jsonschema.Warp(MyStructuredResponse{}) + if err != nil { + t.Fatal(err) + } result, err := schema.Unmarshal(`{"pascal_case":"a","camel_case":"b","snake_case":"c","keywords":[]}`) if err != nil { t.Errorf("Failed to SchemaWrapper Unmarshal: error = %v", err) From 35d36caaff94b137be1179a33bde7f615c3b6d8c Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 10:56:24 +0800 Subject: [PATCH 10/19] update `Warp` to `Wrap` --- jsonschema/json.go | 2 +- jsonschema/json_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index eeea8dd27..e16854f61 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -86,7 +86,7 @@ func (r SchemaWrapper[T]) String() string { return string(bytes) } -func Warp[T any](v T) (*SchemaWrapper[T], error) { +func Wrap[T any](v T) (*SchemaWrapper[T], error) { schema, err := reflectSchema(reflect.TypeOf(v)) if err != nil { return nil, err diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 8415005f9..d9118a481 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -257,7 +257,7 @@ func TestWarp(t *testing.T) { ], "additionalProperties": false }` - schema, err := jsonschema.Warp(MyStructuredResponse{}) + schema, err := jsonschema.Wrap(MyStructuredResponse{}) if err != nil { t.Fatal(err) } @@ -267,7 +267,7 @@ func TestWarp(t *testing.T) { } func TestSchemaWrapper_Unmarshal(t *testing.T) { - schema, err := jsonschema.Warp(MyStructuredResponse{}) + schema, err := jsonschema.Wrap(MyStructuredResponse{}) if err != nil { t.Fatal(err) } From a80ea2f9d37b0ae14a33c5ccce384d564e7eb135 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 11:17:22 +0800 Subject: [PATCH 11/19] fix Sanity check --- jsonschema/json.go | 89 ++++++++++++++++++++++++++--------------- jsonschema/json_test.go | 63 ++++++++++++++++++++++++++++- 2 files changed, 119 insertions(+), 33 deletions(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index e16854f61..6e643e7d9 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -122,40 +122,65 @@ func reflectSchema(t reflect.Type) (*Definition, error) { case reflect.Struct: d.Type = Object d.AdditionalProperties = false - properties := make(map[string]Definition) - var requiredFields []string - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - jsonTag := field.Tag.Get("json") - var required = true - if jsonTag == "" { - jsonTag = field.Name - } else if strings.HasSuffix(jsonTag, ",omitempty") { - jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") - required = false - } - - item, err := reflectSchema(field.Type) - if err != nil { - return nil, err - } - description := field.Tag.Get("description") - if description != "" { - item.Description = description - } - properties[jsonTag] = *item - - if s := field.Tag.Get("required"); s != "" { - required, _ = strconv.ParseBool(s) - } - if required { - requiredFields = append(requiredFields, jsonTag) - } + object, err := reflectSchemaObject(t) + if err != nil { + return nil, err } - d.Required = requiredFields - d.Properties = properties - default: + d = *object + case reflect.Ptr: + definition, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d = *definition + case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, + reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.UnsafePointer: return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) + default: + } + return &d, nil +} + +func reflectSchemaObject(t reflect.Type) (*Definition, error) { + var d = Definition{ + Type: Object, + AdditionalProperties: false, + } + properties := make(map[string]Definition) + var requiredFields []string + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + jsonTag := field.Tag.Get("json") + var required = true + if jsonTag == "" { + jsonTag = field.Name + } else if strings.HasSuffix(jsonTag, ",omitempty") { + jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") + required = false + } + + item, err := reflectSchema(field.Type) + if err != nil { + return nil, err + } + description := field.Tag.Get("description") + if description != "" { + item.Description = description + } + properties[jsonTag] = *item + + if s := field.Tag.Get("required"); s != "" { + required, _ = strconv.ParseBool(s) + } + if required { + requiredFields = append(requiredFields, jsonTag) + } } + d.Required = requiredFields + d.Properties = properties return &d, nil } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index d9118a481..abc2dfa8e 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -219,7 +219,7 @@ type MyStructuredResponse struct { Optional bool `json:"optional,omitempty"` } -func TestWarp(t *testing.T) { +func TestWrap(t *testing.T) { schemaStr := `{ "type": "object", "properties": { @@ -264,6 +264,67 @@ func TestWarp(t *testing.T) { if schema.String() != schemaStr { t.Errorf("Failed to Generate JSONSchema: schema = %s", schema) } + type CustomStruct struct { + Title string `json:"title"` + Data *MyStructuredResponse `json:"data,omitempty"` + private string + } + schema2Str := `{ + "type": "object", + "properties": { + "data": { + "type": "object", + "properties": { + "camel_case": { + "type": "string", + "description": "CamelCase" + }, + "kebab_case": { + "type": "string", + "description": "KebabCase" + }, + "keywords": { + "type": "array", + "description": "Keywords", + "items": { + "type": "string" + } + }, + "optional": { + "type": "boolean" + }, + "pascal_case": { + "type": "string", + "description": "PascalCase" + }, + "snake_case": { + "type": "string", + "description": "SnakeCase" + } + }, + "required": [ + "pascal_case", + "camel_case", + "snake_case" + ], + "additionalProperties": false + }, + "title": { + "type": "string" + } + }, + "required": [ + "title" + ], + "additionalProperties": false +}` + schema2, err := jsonschema.Wrap(CustomStruct{}) + if err != nil { + t.Fatal(err) + } + if schema2.String() != schema2Str { + t.Errorf("Failed to Generate JSONSchema: schema = %s", schema) + } } func TestSchemaWrapper_Unmarshal(t *testing.T) { From 5018f63480e247e51297cc0be0846c4e03eabeef Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 11:23:57 +0800 Subject: [PATCH 12/19] fix Sanity check --- jsonschema/json_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index abc2dfa8e..bc71bc47d 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -265,9 +265,8 @@ func TestWrap(t *testing.T) { t.Errorf("Failed to Generate JSONSchema: schema = %s", schema) } type CustomStruct struct { - Title string `json:"title"` - Data *MyStructuredResponse `json:"data,omitempty"` - private string + Title string `json:"title"` + Data *MyStructuredResponse `json:"data,omitempty"` } schema2Str := `{ "type": "object", From 290bc29284bf65e879d76fee57f15c7d4aaaebef Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 11:46:17 +0800 Subject: [PATCH 13/19] update api_internal_test.go --- api_integration_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 4fe476f95..c1a6a58de 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -196,7 +196,7 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` } var result MyStructuredResponse - //schema := jsonschema.Warp(result) + //sw, err := jsonschema.Wrap(result) schema, err := jsonschema.GenerateSchemaForType(result) if err != nil { t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error") @@ -223,7 +223,8 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { ResponseFormat: &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ - Name: "cases", + Name: "cases", + //Schema: sw, Schema: schema, Strict: true, }, @@ -232,7 +233,7 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") if err == nil { - //_, err = schema.Unmarshal(resp.Choices[0].Message.Content) + //_, err = sw.Unmarshal(resp.Choices[0].Message.Content) err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") } From 25d8769d52f7e8d4f9e38621f3a10a4f71e5c13a Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Tue, 13 Aug 2024 15:58:30 +0800 Subject: [PATCH 14/19] update README.md --- README.md | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/README.md b/README.md index 799dc602b..59adffa3c 100644 --- a/README.md +++ b/README.md @@ -743,6 +743,70 @@ func main() { } ``` + +
+Structured Outputs + +```go +package main + +import ( + "context" + "fmt" + "log" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + type Result struct { + Steps []struct { + Explanation string `json:"explanation"` + Output string `json:"output"` + } + FinalAnswer string `json:"final_answer"` + } + var result Result + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + log.Fatalf("GenerateSchemaForType error: %v", err) + } + resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "You are a helpful math tutor. Guide the user through the solution step by step.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "how can I solve 8x + 7 = -23", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "math_reasoning", + Schema: schema, + Strict: true, + }, + }, + }) + if err != nil { + log.Fatalf("CreateChatCompletion error: %v", err) + } + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + if err != nil { + log.Fatalf("Unmarshal schema error: %v", err) + } + fmt.Println(result) +} +``` +
See the `examples/` folder for more. ## Frequently Asked Questions From e21015fe20ff3892082f60d47802852ec6adafef Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Tue, 13 Aug 2024 16:08:10 +0800 Subject: [PATCH 15/19] update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 59adffa3c..0d6aafa40 100644 --- a/README.md +++ b/README.md @@ -767,7 +767,7 @@ func main() { Steps []struct { Explanation string `json:"explanation"` Output string `json:"output"` - } + } `json:"steps"` FinalAnswer string `json:"final_answer"` } var result Result From a4c1156c4f65bd1015cc11d6960158732cbcd348 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 14 Aug 2024 23:55:08 +0800 Subject: [PATCH 16/19] remove jsonschema.SchemaWrapper --- api_integration_test.go | 2 -- jsonschema/json.go | 32 -------------------------------- 2 files changed, 34 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index c1a6a58de..9d58b8fa3 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -196,7 +196,6 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` } var result MyStructuredResponse - //sw, err := jsonschema.Wrap(result) schema, err := jsonschema.GenerateSchemaForType(result) if err != nil { t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error") @@ -233,7 +232,6 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") if err == nil { - //_, err = sw.Unmarshal(resp.Choices[0].Message.Content) err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") } diff --git a/jsonschema/json.go b/jsonschema/json.go index 6e643e7d9..d5d291023 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -64,38 +64,6 @@ func (d Definition) Unmarshal(content string, v any) error { return Unmarshal(d, []byte(content), v) } -type SchemaWrapper[T any] struct { - schema *Definition -} - -func (r SchemaWrapper[T]) MarshalJSON() ([]byte, error) { - return json.Marshal(r.schema) -} - -func (r SchemaWrapper[T]) Unmarshal(content string) (*T, error) { - var v T - err := Unmarshal(*r.schema, []byte(content), &v) - if err != nil { - return nil, err - } - return &v, nil -} - -func (r SchemaWrapper[T]) String() string { - bytes, _ := json.MarshalIndent(r.schema, "", " ") - return string(bytes) -} - -func Wrap[T any](v T) (*SchemaWrapper[T], error) { - schema, err := reflectSchema(reflect.TypeOf(v)) - if err != nil { - return nil, err - } - return &SchemaWrapper[T]{ - schema: schema, - }, nil -} - func GenerateSchemaForType(v any) (*Definition, error) { return reflectSchema(reflect.TypeOf(v)) } From 4d0750d2b84040a0a373e0fd907838a882c3ea9b Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 14 Aug 2024 23:57:30 +0800 Subject: [PATCH 17/19] remove jsonschema.SchemaWrapper --- jsonschema/json_test.go | 137 ---------------------------------------- 1 file changed, 137 deletions(-) diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index bc71bc47d..744706082 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -209,140 +209,3 @@ func structToMap(t *testing.T, v any) map[string]any { } return got } - -type MyStructuredResponse struct { - PascalCase string `json:"pascal_case,omitempty" required:"true" description:"PascalCase"` - CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` - KebabCase string `json:"kebab_case,omitempty" required:"false" description:"KebabCase"` - SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` - Keywords []string `json:"keywords,omitempty" description:"Keywords"` - Optional bool `json:"optional,omitempty"` -} - -func TestWrap(t *testing.T) { - schemaStr := `{ - "type": "object", - "properties": { - "camel_case": { - "type": "string", - "description": "CamelCase" - }, - "kebab_case": { - "type": "string", - "description": "KebabCase" - }, - "keywords": { - "type": "array", - "description": "Keywords", - "items": { - "type": "string" - } - }, - "optional": { - "type": "boolean" - }, - "pascal_case": { - "type": "string", - "description": "PascalCase" - }, - "snake_case": { - "type": "string", - "description": "SnakeCase" - } - }, - "required": [ - "pascal_case", - "camel_case", - "snake_case" - ], - "additionalProperties": false -}` - schema, err := jsonschema.Wrap(MyStructuredResponse{}) - if err != nil { - t.Fatal(err) - } - if schema.String() != schemaStr { - t.Errorf("Failed to Generate JSONSchema: schema = %s", schema) - } - type CustomStruct struct { - Title string `json:"title"` - Data *MyStructuredResponse `json:"data,omitempty"` - } - schema2Str := `{ - "type": "object", - "properties": { - "data": { - "type": "object", - "properties": { - "camel_case": { - "type": "string", - "description": "CamelCase" - }, - "kebab_case": { - "type": "string", - "description": "KebabCase" - }, - "keywords": { - "type": "array", - "description": "Keywords", - "items": { - "type": "string" - } - }, - "optional": { - "type": "boolean" - }, - "pascal_case": { - "type": "string", - "description": "PascalCase" - }, - "snake_case": { - "type": "string", - "description": "SnakeCase" - } - }, - "required": [ - "pascal_case", - "camel_case", - "snake_case" - ], - "additionalProperties": false - }, - "title": { - "type": "string" - } - }, - "required": [ - "title" - ], - "additionalProperties": false -}` - schema2, err := jsonschema.Wrap(CustomStruct{}) - if err != nil { - t.Fatal(err) - } - if schema2.String() != schema2Str { - t.Errorf("Failed to Generate JSONSchema: schema = %s", schema) - } -} - -func TestSchemaWrapper_Unmarshal(t *testing.T) { - schema, err := jsonschema.Wrap(MyStructuredResponse{}) - if err != nil { - t.Fatal(err) - } - result, err := schema.Unmarshal(`{"pascal_case":"a","camel_case":"b","snake_case":"c","keywords":[]}`) - if err != nil { - t.Errorf("Failed to SchemaWrapper Unmarshal: error = %v", err) - } else { - var v = MyStructuredResponse{ - PascalCase: "a", - CamelCase: "b", - SnakeCase: "c", - Keywords: []string{}, - } - if !reflect.DeepEqual(*result, v) { - t.Errorf("Failed to SchemaWrapper Unmarshal: result = %v", *result) - } - } -} From 9db4d84eb0e7cf46690d3f7fe26bf46d3fffffe4 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 15 Aug 2024 00:04:05 +0800 Subject: [PATCH 18/19] fix Sanity check --- example_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example_test.go b/example_test.go index de67c57cd..e5dbf44bf 100644 --- a/example_test.go +++ b/example_test.go @@ -59,7 +59,7 @@ func ExampleClient_CreateChatCompletionStream() { } defer stream.Close() - fmt.Printf("Stream response: ") + fmt.Print("Stream response: ") for { var response openai.ChatCompletionStreamResponse response, err = stream.Recv() @@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() { return } - fmt.Printf(response.Choices[0].Delta.Content) + fmt.Println(response.Choices[0].Delta.Content) } } From 8680e7b60a54bb60e1d7ee8423d5958f6396a0aa Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sat, 24 Aug 2024 11:46:33 +0800 Subject: [PATCH 19/19] optimize code formatting --- api_integration_test.go | 3 +-- jsonschema/json.go | 8 ++++---- jsonschema/validate.go | 6 +++--- jsonschema/validate_test.go | 5 +++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 9d58b8fa3..3ce1f0755 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -222,8 +222,7 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { ResponseFormat: &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ - Name: "cases", - //Schema: sw, + Name: "cases", Schema: schema, Strict: true, }, diff --git a/jsonschema/json.go b/jsonschema/json.go index d5d291023..bcb253fae 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -48,7 +48,7 @@ type Definition struct { AdditionalProperties any `json:"additionalProperties,omitempty"` } -func (d Definition) MarshalJSON() ([]byte, error) { +func (d *Definition) MarshalJSON() ([]byte, error) { if d.Properties == nil { d.Properties = make(map[string]Definition) } @@ -56,12 +56,12 @@ func (d Definition) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Alias }{ - Alias: (Alias)(d), + Alias: (Alias)(*d), }) } -func (d Definition) Unmarshal(content string, v any) error { - return Unmarshal(d, []byte(content), v) +func (d *Definition) Unmarshal(content string, v any) error { + return VerifySchemaAndUnmarshal(*d, []byte(content), v) } func GenerateSchemaForType(v any) (*Definition, error) { diff --git a/jsonschema/validate.go b/jsonschema/validate.go index d5724a1ee..f14ffd4c4 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -5,7 +5,7 @@ import ( "errors" ) -func Unmarshal(schema Definition, content []byte, v any) error { +func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { var data any err := json.Unmarshal(content, &data) if err != nil { @@ -17,7 +17,7 @@ func Unmarshal(schema Definition, content []byte, v any) error { return json.Unmarshal(content, &v) } -func Validate(schema Definition, data interface{}) bool { +func Validate(schema Definition, data any) bool { switch schema.Type { case Object: return validateObject(schema, data) @@ -67,7 +67,7 @@ func validateObject(schema Definition, data any) bool { } func validateArray(schema Definition, data any) bool { - dataArray, ok := data.([]interface{}) + dataArray, ok := data.([]any) if !ok { return false } diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index 595fb3997..c2c47a2ce 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -8,7 +8,7 @@ import ( func Test_Validate(t *testing.T) { type args struct { - data interface{} + data any schema jsonschema.Definition } tests := []struct { @@ -125,7 +125,8 @@ func TestUnmarshal(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := jsonschema.Unmarshal(tt.args.schema, tt.args.content, tt.args.v); (err != nil) != tt.wantErr { + err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) + if (err != nil) != tt.wantErr { t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) } else if err == nil { t.Logf("Unmarshal() v = %+v\n", tt.args.v)