Skip to content

Commit

Permalink
Workaround openai issue with temperature: 0 being omitted from request
Browse files Browse the repository at this point in the history
This PR adds a workaround for [this GitHub discussion][gh], describing
how chat completion requests which explicitly set temperature = 0 are
marshalled to JSON incorrectly (due to the 0 being indistinguishable
from the zero value of the field).

To do so, we add a custom unmarshal method which checks whether the field
was defined and set to zero, and explicitly sets the temperature to
`math.SmallestNonzeroFloat32` in such cases. This isn't perfect but is
likely to get very similar results in practice.

[gh]: sashabaranov/go-openai#9 (comment)
  • Loading branch information
sd2k committed May 10, 2024
1 parent d1c6507 commit 291497b
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 1 deletion.
27 changes: 27 additions & 0 deletions packages/grafana-llm-app/pkg/plugin/llm_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"math/rand"
"strings"

Expand Down Expand Up @@ -67,6 +68,32 @@ type ChatCompletionRequest struct {
Model Model `json:"model"`
}

// UnmarshalJSON implements json.Unmarshaler.
// We have a custom implementation here to check whether temperature is being
// explicitly set to `0` in the incoming request, because the `openai.ChatCompletionRequest`
// struct has `omitempty` on the Temperature field and would omit it when marshaling.
// If there is an explicit 0 value in the request, we set it to `math.SmallestNonzeroFloat32`,
// a workaround mentioned in https://github.com/sashabaranov/go-openai/issues/9#issuecomment-894845206.
// UnmarshalJSON implements json.Unmarshaler.
//
// The embedded openai.ChatCompletionRequest declares its Temperature field
// with `omitempty`, so a request that explicitly asks for temperature 0
// would later be marshaled with the field dropped entirely (0 is
// indistinguishable from the unset zero value). To preserve the caller's
// intent we detect an explicit zero in the raw JSON and substitute
// math.SmallestNonzeroFloat32, the workaround described in
// https://github.com/sashabaranov/go-openai/issues/9#issuecomment-894845206.
func (c *ChatCompletionRequest) UnmarshalJSON(data []byte) error {
	// Decode into a local alias type so the default decoder runs; calling
	// json.Unmarshal on *ChatCompletionRequest directly would re-enter
	// this method and recurse forever.
	type plain ChatCompletionRequest
	var decoded plain
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}
	// A second decode into a generic map reveals whether "temperature"
	// appeared in the payload at all — something the typed decode cannot
	// distinguish from the field's zero value.
	fields := map[string]any{}
	if err := json.Unmarshal(data, &fields); err != nil {
		return err
	}
	if temp, ok := fields["temperature"].(float64); ok && temp == 0 {
		decoded.ChatCompletionRequest.Temperature = math.SmallestNonzeroFloat32
	}
	*c = ChatCompletionRequest(decoded)
	return nil
}

type ChatCompletionStreamResponse struct {
openai.ChatCompletionStreamResponse
// Random padding used to mitigate side channel attacks.
Expand Down
46 changes: 45 additions & 1 deletion packages/grafana-llm-app/pkg/plugin/llm_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package plugin

import (
"encoding/json"
"math"
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
)

func TestModelFromString(t *testing.T) {
Expand Down Expand Up @@ -83,7 +85,7 @@ func TestModelFromString(t *testing.T) {
}
}

func TestUnmarshalJSON(t *testing.T) {
func TestModelUnmarshalJSON(t *testing.T) {
tests := []struct {
input []byte
expected Model
Expand Down Expand Up @@ -164,6 +166,48 @@ func TestUnmarshalJSON(t *testing.T) {
}
}

// TestChatCompletionRequestUnmarshalJSON verifies the custom temperature
// handling: an absent temperature stays 0, a non-zero value passes through
// unchanged, and an explicit 0 is replaced by math.SmallestNonzeroFloat32.
func TestChatCompletionRequestUnmarshalJSON(t *testing.T) {
	testCases := []struct {
		input    []byte
		expected ChatCompletionRequest
	}{
		{
			// Temperature omitted entirely: the zero value is kept.
			input: []byte(`{"model":"base"}`),
			expected: ChatCompletionRequest{
				Model: ModelBase,
				ChatCompletionRequest: openai.ChatCompletionRequest{
					Temperature: 0,
				},
			},
		},
		{
			// A non-zero temperature passes through untouched.
			input: []byte(`{"model":"base", "temperature":0.5}`),
			expected: ChatCompletionRequest{
				Model: ModelBase,
				ChatCompletionRequest: openai.ChatCompletionRequest{
					Temperature: 0.5,
				},
			},
		},
		{
			// An explicit zero is mapped to the smallest positive float32
			// so it survives a later marshal despite `omitempty`.
			input: []byte(`{"model":"base", "temperature":0}`),
			expected: ChatCompletionRequest{
				Model: ModelBase,
				ChatCompletionRequest: openai.ChatCompletionRequest{
					Temperature: math.SmallestNonzeroFloat32,
				},
			},
		},
	}
	for _, tc := range testCases {
		t.Run(string(tc.input), func(t *testing.T) {
			var got ChatCompletionRequest
			assert.NoError(t, json.Unmarshal(tc.input, &got))
			assert.Equal(t, tc.expected, got)
		})
	}
}

func TestChatCompletionStreamResponseMarshalJSON(t *testing.T) {
resp := ChatCompletionStreamResponse{
ChatCompletionStreamResponse: openai.ChatCompletionStreamResponse{
Expand Down

0 comments on commit 291497b

Please sign in to comment.