Skip to content

Commit

Permalink
[azopenai] Fixing some issues with incorrect/incomplete types in gene…
Browse files Browse the repository at this point in the history
…ration (#22119)

Fixes:
- ToolChoice was unmodeled.
- ResponseFormat for ChatCompletions wasn't settable using the swagger as we had it (it's an object, not a string)
  • Loading branch information
richardpark-msft committed Dec 8, 2023
1 parent e96bba7 commit 24e7b76
Show file tree
Hide file tree
Showing 11 changed files with 319 additions and 32 deletions.
2 changes: 1 addition & 1 deletion sdk/ai/azopenai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Release History

## 0.4.0 (2023-12-07)
## 0.4.0 (2023-12-11)

Support for many of the features mentioned in OpenAI's November Dev Day and Microsoft's 2023 Ignite conference

Expand Down
2 changes: 1 addition & 1 deletion sdk/ai/azopenai/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/ai/azopenai",
"Tag": "go/ai/azopenai_9ed7d01267"
"Tag": "go/ai/azopenai_d4fd4783ec"
}
35 changes: 33 additions & 2 deletions sdk/ai/azopenai/autorest.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ These settings apply only when `--go` is specified on the command line.

``` yaml
input-file:
- https://github.com/Azure/azure-rest-api-specs/blob/d402f685809d6d08be9c0b45065cadd7d78ab870/specification/cognitiveservices/data-plane/AzureOpenAI/inference/preview/2023-12-01-preview/generated.json

- https://github.com/Azure/azure-rest-api-specs/blob/3e0e2a93ddb3c9c44ff1baf4952baa24ca98e9db/specification/cognitiveservices/data-plane/AzureOpenAI/inference/preview/2023-12-01-preview/generated.json
output-folder: ../azopenai
clear-output-folder: false
module: github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai
Expand Down Expand Up @@ -98,6 +97,20 @@ directive:
transform: return $.replace(/InternalOYDAuthTypeRename/g, "configType")
```

`ChatCompletionsResponseFormat.Type`

```yaml
directive:
- from: swagger-document
where: $.definitions.ChatCompletionsResponseFormat
transform: $.properties.type["x-ms-client-name"] = "InternalChatCompletionsResponseFormat"
- from:
- models.go
- models_serde.go
where: $
transform: return $.replace(/InternalChatCompletionsResponseFormat/g, "respType")
```

## Model -> DeploymentName

```yaml
Expand Down Expand Up @@ -571,3 +584,21 @@ directive:
return $.replace(/(func \(c ChatCompletionsOptions\) MarshalJSON\(\).+?populate\(objectMap, "frequency_penalty", c.FrequencyPenalty\))/s, "$1\n" + populateLines)
```

Fix ToolChoice discriminated union

```yaml
directive:
- from: swagger-document
where: $.definitions.ChatCompletionsOptions.properties
transform: $["tool_choice"]["x-ms-client-name"] = "ToolChoiceRenameMe"
- from:
- models.go
- models_serde.go
where: $
transform: |
return $
.replace(/^\s+ToolChoiceRenameMe.+$/m, "ToolChoice *ChatCompletionsToolChoice") // update the name _and_ type for the field
.replace(/ToolChoiceRenameMe/g, "ToolChoice") // rename all other references
.replace(/populateAny\(objectMap, "tool_choice", c\.ToolChoice\)/, 'populate(objectMap, "tool_choice", c.ToolChoice)'); // treat field as typed so nil means omit.
```
38 changes: 38 additions & 0 deletions sdk/ai/azopenai/client_chat_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package azopenai_test

import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
Expand Down Expand Up @@ -262,3 +263,40 @@ func TestClient_OpenAI_GetChatCompletions_Vision(t *testing.T) {

t.Logf(*resp.Choices[0].Message.Content)
}

func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
testFn := func(t *testing.T, chatClient *azopenai.Client, deploymentName string) {
body := azopenai.ChatCompletionsOptions{
DeploymentName: &deploymentName,
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestSystemMessage{Content: to.Ptr("You are a helpful assistant designed to output JSON.")},
&azopenai.ChatRequestUserMessage{
Content: azopenai.NewChatRequestUserMessageContent("List capital cities and their states"),
},
},
// Without this format directive you end up getting JSON, but with a non-JSON preamble, like this:
// "I'm happy to help! Here are some examples of capital cities and their corresponding states:\n\n```json\n{\n" (etc)
ResponseFormat: &azopenai.ChatCompletionsJSONResponseFormat{},
Temperature: to.Ptr[float32](0.0),
}

resp, err := chatClient.GetChatCompletions(context.Background(), body, nil)
require.NoError(t, err)

// validate that it came back as JSON data
var v any
err = json.Unmarshal([]byte(*resp.Choices[0].Message.Content), &v)
require.NoError(t, err)
require.NotEmpty(t, v)
}

t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testFn(t, chatClient, "gpt-3.5-turbo-1106")
})

t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.DallE.Endpoint)
testFn(t, chatClient, "gpt-4-1106-preview")
})
}
40 changes: 36 additions & 4 deletions sdk/ai/azopenai/client_functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,46 @@ type ParamProperty struct {
func TestGetChatCompletions_usingFunctions(t *testing.T) {
// https://platform.openai.com/docs/guides/gpt/function-calling

useSpecificTool := azopenai.NewChatCompletionsToolChoice(
azopenai.ChatCompletionsToolChoiceFunction{Name: "get_current_weather"},
)

t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testChatCompletionsFunctions(t, chatClient, openAI.ChatCompletions)
testChatCompletionsFunctions(t, chatClient, openAI.ChatCompletionsLegacyFunctions)

testData := []struct {
Model string
ToolChoice *azopenai.ChatCompletionsToolChoice
}{
// all of these variants use the tool provided - auto just also works since we did provide
// a tool reference and ask a question to use it.
{Model: openAI.ChatCompletions, ToolChoice: nil},
{Model: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
{Model: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool},
}

for _, td := range testData {
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
}
})

t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
testChatCompletionsFunctions(t, chatClient, azureOpenAI.ChatCompletions)

testData := []struct {
Model string
ToolChoice *azopenai.ChatCompletionsToolChoice
}{
// all of these variants use the tool provided - auto just also works since we did provide
// a tool reference and ask a question to use it.
{Model: azureOpenAI.ChatCompletions, ToolChoice: nil},
{Model: azureOpenAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
{Model: azureOpenAI.ChatCompletions, ToolChoice: useSpecificTool},
}

for _, td := range testData {
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
}
})
}

Expand Down Expand Up @@ -120,7 +151,7 @@ func testChatCompletionsFunctionsOlderStyle(t *testing.T, client *azopenai.Clien
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}

func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, deploymentName string) {
func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) {
body := azopenai.ChatCompletionsOptions{
DeploymentName: &deploymentName,
Messages: []azopenai.ChatRequestMessageClassification{
Expand Down Expand Up @@ -150,6 +181,7 @@ func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, dep
},
},
},
ToolChoice: toolChoice,
Temperature: to.Ptr[float32](0.0),
}

Expand Down
20 changes: 0 additions & 20 deletions sdk/ai/azopenai/constants.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 53 additions & 0 deletions sdk/ai/azopenai/custom_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,56 @@ func (e *Error) Error() string {

return *e.message
}

// ChatCompletionsToolChoice controls which tool is used for this ChatCompletions call.
// You can choose between:
// - [ChatCompletionsToolChoiceAuto] means the model can pick between generating a message or calling a function.
// - [ChatCompletionsToolChoiceNone] means the model will not call a function and instead generates a message
// - Use the [NewChatCompletionsToolChoice] function to specify a specific tool.
type ChatCompletionsToolChoice struct {
value any
}

var (
// ChatCompletionsToolChoiceAuto means the model can pick between generating a message or calling a function.
ChatCompletionsToolChoiceAuto *ChatCompletionsToolChoice = &ChatCompletionsToolChoice{value: "auto"}

// ChatCompletionsToolChoiceNone means the model will not call a function and instead generates a message.
ChatCompletionsToolChoiceNone *ChatCompletionsToolChoice = &ChatCompletionsToolChoice{value: "none"}
)

// NewChatCompletionsToolChoice creates a ChatCompletionsToolChoice for a specific tool.
func NewChatCompletionsToolChoice[T ChatCompletionsToolChoiceFunction](v T) *ChatCompletionsToolChoice {
return &ChatCompletionsToolChoice{value: v}
}

// ChatCompletionsToolChoiceFunction can be used to force the model to call a particular function.
type ChatCompletionsToolChoiceFunction struct {
// Name is the name of the function to call.
Name string
}

// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsToolChoiceFunction.
func (tf ChatCompletionsToolChoiceFunction) MarshalJSON() ([]byte, error) {
type jsonInnerFunc struct {
Name string `json:"name"`
}

type jsonFormat struct {
Type string `json:"type"`
Function jsonInnerFunc `json:"function"`
}

return json.Marshal(jsonFormat{
Type: "function",
//nolint:gosimple,can't use the ChatCompletionsToolChoiceFunction here or marshalling will be circular!
Function: jsonInnerFunc{
Name: tf.Name,
},
})
}

// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsToolChoice.
func (tc ChatCompletionsToolChoice) MarshalJSON() ([]byte, error) {
return json.Marshal(tc.value)
}
9 changes: 9 additions & 0 deletions sdk/ai/azopenai/interfaces.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 42 additions & 2 deletions sdk/ai/azopenai/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 24e7b76

Please sign in to comment.