Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
h0rv committed Aug 10, 2024
1 parent 440a5e0 commit 03b6379
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
12 changes: 11 additions & 1 deletion api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,17 @@ func TestChatCompletionResponseFormat_JSONSchemaRaw(t *testing.T) {
c := openai.NewClient(apiToken)
ctx := context.Background()

schema := []byte(`{"type":"object","properties":{"CamelCase":{"type":"string"},"KebabCase":{"type":"string"},"PascalCase":{"type":"string"},"SnakeCase":{"type":"string"}},"required":["PascalCase","CamelCase","KebabCase","SnakeCase"],"additionalProperties":false}`)
schema := []byte(`{
"type": "object",
"properties": {
"CamelCase": {"type": "string"},
"KebabCase": {"type": "string"},
"PascalCase": {"type": "string"},
"SnakeCase": {"type": "string"}
},
"required": ["PascalCase", "CamelCase", "KebabCase", "SnakeCase"],
"additionalProperties": false
}`)

resp, err := c.CreateChatCompletion(
ctx,
Expand Down
5 changes: 3 additions & 2 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,21 +195,22 @@ type ChatCompletionResponseFormatJSONSchema struct {
}

func (c *ChatCompletionResponseFormatJSONSchema) MarshalJSON() ([]byte, error) {
type Alias ChatCompletionResponseFormatJSONSchema
type Alias ChatCompletionResponseFormatJSONSchema // prevent recursive marshalling
var data struct {
*Alias
Schema interface{} `json:"schema,omitempty"`
}

data.Alias = (*Alias)(c)

data.Schema = c.Schema
if c.SchemaRaw != nil {
var rawSchema interface{}
if err := json.Unmarshal(*c.SchemaRaw, &rawSchema); err != nil {
return nil, err
}
data.Schema = rawSchema
} else {
data.Schema = c.Schema
}

return json.Marshal(data)
Expand Down
17 changes: 13 additions & 4 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func TestFinishReason(t *testing.T) {
}
}

func TestChatCompletionResponseFormat_JSONSchema_MarshalJSON(t *testing.T) {
func TestChatCompletionResponseFormatJSONSchemaMarshalJSON(t *testing.T) {
tests := []struct {
name string
input openai.ChatCompletionResponseFormatJSONSchema
Expand All @@ -542,7 +542,7 @@ func TestChatCompletionResponseFormat_JSONSchema_MarshalJSON(t *testing.T) {
SchemaRaw: nil,
Strict: false,
},
expected: `{"name":"TestName","strict":false}`,
expected: `{"name":"TestName","strict":false,"schema":{}}`,
wantErr: false,
},
{
Expand All @@ -552,7 +552,7 @@ func TestChatCompletionResponseFormat_JSONSchema_MarshalJSON(t *testing.T) {
SchemaRaw: func() *[]byte { b := []byte(`{"key":"value"}`); return &b }(),
Strict: true,
},
expected: `{"name":"TestName","schema":{"key":"value"},"strict":true}`,
expected: `{"name":"TestName","strict":true,"schema":{"key":"value"}}`,
wantErr: false,
},
{
Expand All @@ -570,12 +570,21 @@ func TestChatCompletionResponseFormat_JSONSchema_MarshalJSON(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.input.MarshalJSON()

if (err != nil) != tt.wantErr {
t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}

if tt.wantErr {
if len(got) != 0 {
t.Errorf("Expected empty output on error, got: %s", string(got))
}
return
}

if string(got) != tt.expected {
t.Errorf("MarshalJSON() got = %v, expected %v", string(got), tt.expected)
t.Errorf("MarshalJSON() got = %s, expected %s", string(got), tt.expected)
}
})
}
Expand Down

0 comments on commit 03b6379

Please sign in to comment.