diff --git a/chat.go b/chat.go index 9adf2808..08abdeab 100644 --- a/chat.go +++ b/chat.go @@ -209,6 +209,7 @@ type ChatCompletionRequest struct { MaxTokens int `json:"max_tokens,omitempty"` // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` Temperature float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` @@ -219,7 +220,8 @@ type ChatCompletionRequest struct { ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` Seed *int `json:"seed,omitempty"` FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. + + // LogitBias must be a token id string (specified by their token ID in the tokenizer), not a word string. 
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias LogitBias map[string]int `json:"logit_bias,omitempty"` diff --git a/client.go b/client.go index 9f547e7c..583244fe 100644 --- a/client.go +++ b/client.go @@ -285,10 +285,18 @@ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newB } func (c *Client) handleErrorResp(resp *http.Response) error { + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, reading response body: %w", err) + } + return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body) + } var errRes ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) if err != nil || errRes.Error == nil { reqErr := &RequestError{ + HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: err, } @@ -298,6 +306,7 @@ func (c *Client) handleErrorResp(resp *http.Response) error { return reqErr } + errRes.Error.HTTPStatus = resp.Status errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } diff --git a/client_test.go b/client_test.go index 3f27b9dd..18da787a 100644 --- a/client_test.go +++ b/client_test.go @@ -134,14 +134,17 @@ func TestHandleErrorResp(t *testing.T) { client := NewClient(mockToken) testCases := []struct { - name string - httpCode int - body io.Reader - expected string + name string + httpCode int + httpStatus string + contentType string + body io.Reader + expected string }{ { - name: "401 Invalid Authentication", - httpCode: http.StatusUnauthorized, + name: "401 Invalid Authentication", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -152,11 +155,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: You didn't provide an API 
key. ....", + expected: "error, status code: 401, status: , message: You didn't provide an API key. ....", }, { - name: "401 Azure Access Denied", - httpCode: http.StatusUnauthorized, + name: "401 Azure Access Denied", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -165,11 +169,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.", + expected: "error, status code: 401, status: , message: Access denied due to Virtual Network/Firewall rules.", }, { - name: "503 Model Overloaded", - httpCode: http.StatusServiceUnavailable, + name: "503 Model Overloaded", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{ @@ -179,22 +184,53 @@ func TestHandleErrorResp(t *testing.T) { "code":null } }`)), - expected: "error, status code: 503, message: That model...", + expected: "error, status code: 503, status: , message: That model...", }, { - name: "503 no message (Unknown response)", - httpCode: http.StatusServiceUnavailable, + name: "503 no message (Unknown response)", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{} }`)), - expected: "error, status code: 503, message: ", + expected: "error, status code: 503, status: , message: ", + }, + { + name: "413 Request Entity Too Large", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: bytes.NewReader([]byte(` +413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ +`)), + expected: `error, status code: 413, status: , body: +413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ +`, + }, + { + name: "errorReader", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: &errorReader{err: errors.New("errorReader")}, + expected: "error, reading response body: errorReader", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - testCase := &http.Response{} + testCase := &http.Response{ + Header: map[string][]string{ + "Content-Type": {tc.contentType}, + }, + } testCase.StatusCode = tc.httpCode testCase.Body = io.NopCloser(tc.body) err := client.handleErrorResp(testCase) @@ -203,12 +239,6 @@ func TestHandleErrorResp(t *testing.T) { t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected) t.Fail() } - - e := &APIError{} - if !errors.As(err, &e) { - t.Errorf("(%s) Expected error to be of type APIError", tc.name) - t.Fail() - } }) } } diff --git a/common.go b/common.go index cbfda4e3..cde14154 100644 --- a/common.go +++ b/common.go @@ -4,7 +4,13 @@ package openai // Usage Represents the total token usage per request to OpenAI. type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` +} + +// CompletionTokensDetails Breakdown of tokens used in a completion. 
+type CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` } diff --git a/error.go b/error.go index 37959a27..1f6a8971 100644 --- a/error.go +++ b/error.go @@ -13,6 +13,7 @@ type APIError struct { Message string `json:"message"` Param *string `json:"param,omitempty"` Type string `json:"type"` + HTTPStatus string `json:"-"` HTTPStatusCode int `json:"-"` InnerError *InnerError `json:"innererror,omitempty"` } @@ -25,6 +26,7 @@ type InnerError struct { // RequestError provides information about generic request errors. type RequestError struct { + HTTPStatus string HTTPStatusCode int Err error } @@ -35,7 +37,7 @@ type ErrorResponse struct { func (e *APIError) Error() string { if e.HTTPStatusCode > 0 { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message) + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Message) } return e.Message @@ -101,7 +103,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } func (e *RequestError) Error() string { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err) + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Err) } func (e *RequestError) Unwrap() error { diff --git a/files_api_test.go b/files_api_test.go index c92162a8..aa4fda45 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -152,6 +152,7 @@ func TestGetFileContentReturnError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, wantErrorResp) }) diff --git a/run.go b/run.go index 0cdec2bd..d3e755f0 100644 --- a/run.go +++ b/run.go @@ -37,8 +37,6 @@ type Run struct { MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` 
// ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` - // Disable the default behavior of parallel tool calls by setting it: false. - ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` httpHeader } @@ -112,6 +110,8 @@ type RunRequest struct { ToolChoice any `json:"tool_choice,omitempty"` // This can be either a string or a ResponseFormat object. ResponseFormat any `json:"response_format,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` } // ThreadTruncationStrategy defines the truncation strategy to use for the thread.