From 38bdc812df391bcec3d7defda2a456ea00bb54e5 Mon Sep 17 00:00:00 2001
From: eiixy <990656271@qq.com>
Date: Thu, 26 Sep 2024 18:25:56 +0800
Subject: [PATCH 1/4] Optimize Client Error Return (#856)
* update client error return
* update client_test.go
* update client_test.go
* update file_api_test.go
* update client_test.go
* update client_test.go
---
client.go | 9 ++++++
client_test.go | 76 +++++++++++++++++++++++++++++++++--------------
error.go | 6 ++--
files_api_test.go | 1 +
4 files changed, 67 insertions(+), 25 deletions(-)
diff --git a/client.go b/client.go
index 9f547e7cb..583244fe1 100644
--- a/client.go
+++ b/client.go
@@ -285,10 +285,18 @@ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newB
}
func (c *Client) handleErrorResp(resp *http.Response) error {
+ if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("error, reading response body: %w", err)
+ }
+ return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body)
+ }
var errRes ErrorResponse
err := json.NewDecoder(resp.Body).Decode(&errRes)
if err != nil || errRes.Error == nil {
reqErr := &RequestError{
+ HTTPStatus: resp.Status,
HTTPStatusCode: resp.StatusCode,
Err: err,
}
@@ -298,6 +306,7 @@ func (c *Client) handleErrorResp(resp *http.Response) error {
return reqErr
}
+ errRes.Error.HTTPStatus = resp.Status
errRes.Error.HTTPStatusCode = resp.StatusCode
return errRes.Error
}
diff --git a/client_test.go b/client_test.go
index 3f27b9dd7..18da787a0 100644
--- a/client_test.go
+++ b/client_test.go
@@ -134,14 +134,17 @@ func TestHandleErrorResp(t *testing.T) {
client := NewClient(mockToken)
testCases := []struct {
- name string
- httpCode int
- body io.Reader
- expected string
+ name string
+ httpCode int
+ httpStatus string
+ contentType string
+ body io.Reader
+ expected string
}{
{
- name: "401 Invalid Authentication",
- httpCode: http.StatusUnauthorized,
+ name: "401 Invalid Authentication",
+ httpCode: http.StatusUnauthorized,
+ contentType: "application/json",
body: bytes.NewReader([]byte(
`{
"error":{
@@ -152,11 +155,12 @@ func TestHandleErrorResp(t *testing.T) {
}
}`,
)),
- expected: "error, status code: 401, message: You didn't provide an API key. ....",
+ expected: "error, status code: 401, status: , message: You didn't provide an API key. ....",
},
{
- name: "401 Azure Access Denied",
- httpCode: http.StatusUnauthorized,
+ name: "401 Azure Access Denied",
+ httpCode: http.StatusUnauthorized,
+ contentType: "application/json",
body: bytes.NewReader([]byte(
`{
"error":{
@@ -165,11 +169,12 @@ func TestHandleErrorResp(t *testing.T) {
}
}`,
)),
- expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.",
+ expected: "error, status code: 401, status: , message: Access denied due to Virtual Network/Firewall rules.",
},
{
- name: "503 Model Overloaded",
- httpCode: http.StatusServiceUnavailable,
+ name: "503 Model Overloaded",
+ httpCode: http.StatusServiceUnavailable,
+ contentType: "application/json",
body: bytes.NewReader([]byte(`
{
"error":{
@@ -179,22 +184,53 @@ func TestHandleErrorResp(t *testing.T) {
"code":null
}
}`)),
- expected: "error, status code: 503, message: That model...",
+ expected: "error, status code: 503, status: , message: That model...",
},
{
- name: "503 no message (Unknown response)",
- httpCode: http.StatusServiceUnavailable,
+ name: "503 no message (Unknown response)",
+ httpCode: http.StatusServiceUnavailable,
+ contentType: "application/json",
body: bytes.NewReader([]byte(`
{
"error":{}
}`)),
- expected: "error, status code: 503, message: ",
+ expected: "error, status code: 503, status: , message: ",
+ },
+ {
+ name: "413 Request Entity Too Large",
+ httpCode: http.StatusRequestEntityTooLarge,
+ contentType: "text/html",
+ body: bytes.NewReader([]byte(`
+
413 Request Entity Too Large
+
+413 Request Entity Too Large
+
nginx
+
+`)),
+ expected: `error, status code: 413, status: , body:
+413 Request Entity Too Large
+
+413 Request Entity Too Large
+
nginx
+
+`,
+ },
+ {
+ name: "errorReader",
+ httpCode: http.StatusRequestEntityTooLarge,
+ contentType: "text/html",
+ body: &errorReader{err: errors.New("errorReader")},
+ expected: "error, reading response body: errorReader",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- testCase := &http.Response{}
+ testCase := &http.Response{
+ Header: map[string][]string{
+ "Content-Type": {tc.contentType},
+ },
+ }
testCase.StatusCode = tc.httpCode
testCase.Body = io.NopCloser(tc.body)
err := client.handleErrorResp(testCase)
@@ -203,12 +239,6 @@ func TestHandleErrorResp(t *testing.T) {
t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected)
t.Fail()
}
-
- e := &APIError{}
- if !errors.As(err, &e) {
- t.Errorf("(%s) Expected error to be of type APIError", tc.name)
- t.Fail()
- }
})
}
}
diff --git a/error.go b/error.go
index 37959a272..1f6a8971d 100644
--- a/error.go
+++ b/error.go
@@ -13,6 +13,7 @@ type APIError struct {
Message string `json:"message"`
Param *string `json:"param,omitempty"`
Type string `json:"type"`
+ HTTPStatus string `json:"-"`
HTTPStatusCode int `json:"-"`
InnerError *InnerError `json:"innererror,omitempty"`
}
@@ -25,6 +26,7 @@ type InnerError struct {
// RequestError provides information about generic request errors.
type RequestError struct {
+ HTTPStatus string
HTTPStatusCode int
Err error
}
@@ -35,7 +37,7 @@ type ErrorResponse struct {
func (e *APIError) Error() string {
if e.HTTPStatusCode > 0 {
- return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message)
+ return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Message)
}
return e.Message
@@ -101,7 +103,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
}
func (e *RequestError) Error() string {
- return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err)
+ return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Err)
}
func (e *RequestError) Unwrap() error {
diff --git a/files_api_test.go b/files_api_test.go
index c92162a84..aa4fda458 100644
--- a/files_api_test.go
+++ b/files_api_test.go
@@ -152,6 +152,7 @@ func TestGetFileContentReturnError(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, wantErrorResp)
})
From 7f80303cc393edf2f6806ca37668346f8fa6247e Mon Sep 17 00:00:00 2001
From: Alex Philipp
Date: Thu, 26 Sep 2024 05:26:22 -0500
Subject: [PATCH 2/4] Fix max_completion_tokens (#860)
The JSON tag was incorrect (`max_completions_tokens` instead of `max_completion_tokens`), which results in an error from the API when using the o1 model.
The struct field name is left unchanged to maintain compatibility for anyone who had already started using it, although the field could not have worked for them with the old tag either.
---
chat.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/chat.go b/chat.go
index d47c95e4f..dd99c530e 100644
--- a/chat.go
+++ b/chat.go
@@ -209,7 +209,7 @@ type ChatCompletionRequest struct {
MaxTokens int `json:"max_tokens,omitempty"`
// MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion,
// including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning
- MaxCompletionsTokens int `json:"max_completions_tokens,omitempty"`
+ MaxCompletionsTokens int `json:"max_completion_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
From e9d8485e90092b8adcce82fdd0dcd7cf10327e8d Mon Sep 17 00:00:00 2001
From: Jialin Tian
Date: Thu, 26 Sep 2024 18:26:54 +0800
Subject: [PATCH 3/4] fix: ParallelToolCalls should be added to RunRequest
(#861)
---
run.go | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/run.go b/run.go
index 0cdec2bdc..d3e755f05 100644
--- a/run.go
+++ b/run.go
@@ -37,8 +37,6 @@ type Run struct {
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"`
- // Disable the default behavior of parallel tool calls by setting it: false.
- ParallelToolCalls any `json:"parallel_tool_calls,omitempty"`
httpHeader
}
@@ -112,6 +110,8 @@ type RunRequest struct {
ToolChoice any `json:"tool_choice,omitempty"`
// This can be either a string or a ResponseFormat object.
ResponseFormat any `json:"response_format,omitempty"`
+ // Disable the default behavior of parallel tool calls by setting it: false.
+ ParallelToolCalls any `json:"parallel_tool_calls,omitempty"`
}
// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
From fdd59d93413154cd07b2e46a428b15eda40b26e2 Mon Sep 17 00:00:00 2001
From: Liu Shuang
Date: Thu, 26 Sep 2024 18:30:56 +0800
Subject: [PATCH 4/4] feat: usage struct add CompletionTokensDetails (#863)
---
common.go | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/common.go b/common.go
index cbfda4e3c..cde14154a 100644
--- a/common.go
+++ b/common.go
@@ -4,7 +4,13 @@ package openai
// Usage Represents the total token usage per request to OpenAI.
type Usage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"`
+}
+
+// CompletionTokensDetails Breakdown of tokens used in a completion.
+type CompletionTokensDetails struct {
+ ReasoningTokens int `json:"reasoning_tokens"`
}