diff --git a/chat.go b/chat.go index 9adf2808..08abdeab 100644 --- a/chat.go +++ b/chat.go @@ -209,6 +209,7 @@ type ChatCompletionRequest struct { MaxTokens int `json:"max_tokens,omitempty"` // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` Temperature float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` @@ -219,7 +220,8 @@ type ChatCompletionRequest struct { ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` Seed *int `json:"seed,omitempty"` FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. + + // LogitBias i s must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias LogitBias map[string]int `json:"logit_bias,omitempty"` diff --git a/client.go b/client.go index 9f547e7c..583244fe 100644 --- a/client.go +++ b/client.go @@ -285,10 +285,18 @@ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newB } func (c *Client) handleErrorResp(resp *http.Response) error { + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, reading response body: %w", err) + } + return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body) + } var errRes ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) if err != nil || errRes.Error == nil { reqErr := &RequestError{ + HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: err, } @@ -298,6 +306,7 @@ func (c *Client) handleErrorResp(resp *http.Response) error { return reqErr } + errRes.Error.HTTPStatus = resp.Status errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } diff --git a/client_test.go b/client_test.go index 3f27b9dd..18da787a 100644 --- a/client_test.go +++ b/client_test.go @@ -134,14 +134,17 @@ func TestHandleErrorResp(t *testing.T) { client := NewClient(mockToken) testCases := []struct { - name string - httpCode int - body io.Reader - expected string + name string + httpCode int + httpStatus string + contentType string + body io.Reader + expected string }{ { - name: "401 Invalid Authentication", - httpCode: http.StatusUnauthorized, + name: "401 Invalid Authentication", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -152,11 +155,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: You didn't provide an API key. ....", + expected: "error, status code: 401, status: , message: You didn't provide an API key. ....", }, { - name: "401 Azure Access Denied", - httpCode: http.StatusUnauthorized, + name: "401 Azure Access Denied", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -165,11 +169,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.", + expected: "error, status code: 401, status: , message: Access denied due to Virtual Network/Firewall rules.", }, { - name: "503 Model Overloaded", - httpCode: http.StatusServiceUnavailable, + name: "503 Model Overloaded", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{ @@ -179,22 +184,53 @@ func TestHandleErrorResp(t *testing.T) { "code":null } }`)), - expected: "error, status code: 503, message: That model...", + expected: "error, status code: 503, status: , message: That model...", }, { - name: "503 no message (Unknown response)", - httpCode: http.StatusServiceUnavailable, + name: "503 no message (Unknown response)", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{} }`)), - expected: "error, status code: 503, message: ", + expected: "error, status code: 503, status: , message: ", + }, + { + name: "413 Request Entity Too Large", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: bytes.NewReader([]byte(` +