diff --git a/billing.go b/billing.go new file mode 100644 index 000000000..33f3ce9bf --- /dev/null +++ b/billing.go @@ -0,0 +1,70 @@ +package openai + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" +) + +const billingUsageSuffix = "/billing/usage" + +type CostLineItemResponse struct { + Name string `json:"name"` + Cost float64 `json:"cost"` // in cents +} + +type DailyCostResponse struct { + TimestampRaw float64 `json:"timestamp"` + LineItems []CostLineItemResponse `json:"line_items"` + + Time time.Time `json:"-"` +} + +type BillingUsageResponse struct { + Object string `json:"object"` + DailyCosts []DailyCostResponse `json:"daily_costs"` + TotalUsage float64 `json:"total_usage"` // in cents + + httpHeader +} + +// currently the OpenAI usage API is not publicly documented and will explictly +// reject requests using an API key authorization. however, it can be utilized +// logging into https://platform.openai.com/usage and retrieving your session +// key from the browser console. session keys have the form 'sess-'. +var ( + BillingAPIKeyNotAllowedErrMsg = "Your request to GET /dashboard/billing/usage must be made with a session key (that is, it can only be made from the browser)." //nolint:lll + ErrSessKeyRequired = errors.New("an OpenAI API key cannot be used for this request; a session key is required instead") //nolint:lll +) + +// GetBillingUsage — API call to Get billing usage details. +func (c *Client) GetBillingUsage(ctx context.Context, startDate time.Time, + endDate time.Time) (response BillingUsageResponse, err error) { + startDateArg := fmt.Sprintf("start_date=%v", startDate.Format(time.DateOnly)) + endDateArg := fmt.Sprintf("end_date=%v", endDate.Format(time.DateOnly)) + queryParams := fmt.Sprintf("%v&%v", startDateArg, endDateArg) + urlSuffix := fmt.Sprintf("%v?%v", billingUsageSuffix, queryParams) + + req, err := c.newRequest(ctx, http.MethodGet, c.fullDashboardURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + if err != nil { + if strings.Contains(err.Error(), BillingAPIKeyNotAllowedErrMsg) { + err = ErrSessKeyRequired + } + return + } + + for idx, d := range response.DailyCosts { + dTime := time.Unix(int64(d.TimestampRaw), 0) + response.DailyCosts[idx].Time = dTime + } + + return +} diff --git a/billing_test.go b/billing_test.go new file mode 100644 index 000000000..a61decec8 --- /dev/null +++ b/billing_test.go @@ -0,0 +1,116 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "testing" + "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +const ( + TestTotCost = float64(126.234) + TestEndDate = "2023-11-30" + TestStartDate = "2023-11-01" + TestSessionKey = "sess-whatever" + TestAPIKey = "sk-whatever" +) + +func TestBillingUsageAPIKey(t *testing.T) { + client, server, teardown := setupOpenAITestServerWithAuth(TestAPIKey) + defer teardown() + server.RegisterHandler("/dashboard/billing/usage", handleBillingEndpoint) + + ctx := context.Background() + + endDate, err := time.Parse(time.DateOnly, TestEndDate) + checks.NoError(t, err) + startDate, err := time.Parse(time.DateOnly, TestStartDate) + checks.NoError(t, err) + + _, err = client.GetBillingUsage(ctx, startDate, endDate) + checks.HasError(t, err) +} + +func TestBillingUsageSessKey(t *testing.T) { + client, server, teardown := setupOpenAITestServerWithAuth(TestSessionKey) + defer teardown() + server.RegisterHandler("/dashboard/billing/usage", handleBillingEndpoint) + + ctx := context.Background() + endDate, err := time.Parse(time.DateOnly, TestEndDate) + checks.NoError(t, err) + startDate, err := time.Parse(time.DateOnly, TestStartDate) + checks.NoError(t, err) + + resp, err := client.GetBillingUsage(ctx, startDate, endDate) + checks.NoError(t, err) + + if resp.TotalUsage != TestTotCost { + t.Errorf("expected total cost %v but got %v", TestTotCost, + resp.TotalUsage) + } + for idx, dc := range resp.DailyCosts { + if dc.Time.Compare(startDate) < 0 { + t.Errorf("expected daily cost%v date(%v) before start date %v", idx, + dc.Time, TestStartDate) + } + if dc.Time.Compare(endDate) > 0 { + t.Errorf("expected daily cost%v date(%v) after end date %v", idx, + dc.Time, TestEndDate) + } + } +} + +// handleBillingEndpoint Handles the billing usage endpoint by the test server. +func handleBillingEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + if strings.Contains(r.Header.Get("Authorization"), TestAPIKey) { + http.Error(w, openai.BillingAPIKeyNotAllowedErrMsg, http.StatusUnauthorized) + return + } + + var resBytes []byte + + dailyCosts := make([]openai.DailyCostResponse, 0) + + d, _ := time.Parse(time.DateOnly, TestStartDate) + d = d.Add(24 * time.Hour) + dailyCosts = append(dailyCosts, openai.DailyCostResponse{ + TimestampRaw: float64(d.Unix()), + LineItems: []openai.CostLineItemResponse{ + {Name: "GPT-4 Turbo", Cost: 0.12}, + {Name: "Audio models", Cost: 0.24}, + }, + Time: time.Time{}, + }) + d = d.Add(24 * time.Hour) + dailyCosts = append(dailyCosts, openai.DailyCostResponse{ + TimestampRaw: float64(d.Unix()), + LineItems: []openai.CostLineItemResponse{ + {Name: "image models", Cost: 0.56}, + }, + Time: time.Time{}, + }) + res := &openai.BillingUsageResponse{ + Object: "list", + DailyCosts: dailyCosts, + TotalUsage: TestTotCost, + } + + resBytes, err := json.Marshal(res) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + fmt.Fprintln(w, string(resBytes)) +} diff --git a/client.go b/client.go index d5d555c3d..852036711 100644 --- a/client.go +++ b/client.go @@ -260,6 +260,13 @@ func (c *Client) fullURL(suffix string, args ...any) string { return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } +// fullDashboardURL returns full URL for a dashboard request. +func (c *Client) fullDashboardURL(suffix string, _ ...any) string { + // @todo this needs to be updated for c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD + + return fmt.Sprintf("%s%s", c.config.DashboardBaseURL, suffix) +} + func (c *Client) handleErrorResp(resp *http.Response) error { var errRes ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) diff --git a/config.go b/config.go index 1347567d7..79ac919cd 100644 --- a/config.go +++ b/config.go @@ -7,6 +7,7 @@ import ( const ( openaiAPIURLv1 = "https://api.openai.com/v1" + openaiAPIDashboardURL = "https://api.openai.com/dashboard" defaultEmptyMessagesLimit uint = 300 azureAPIPrefix = "openai" @@ -31,6 +32,7 @@ type ClientConfig struct { authToken string BaseURL string + DashboardBaseURL string OrgID string APIType APIType APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD @@ -45,6 +47,7 @@ func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ authToken: authToken, BaseURL: openaiAPIURLv1, + DashboardBaseURL: openaiAPIDashboardURL, APIType: APITypeOpenAI, AssistantVersion: defaultAssistantVersion, OrgID: "", diff --git a/internal/test/server.go b/internal/test/server.go index 127d4c16f..74ffb5d95 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -16,11 +16,15 @@ func GetTestToken() string { type ServerTest struct { handlers map[string]handler + authKey string } type handler func(w http.ResponseWriter, r *http.Request) -func NewTestServer() *ServerTest { - return &ServerTest{handlers: make(map[string]handler)} +func NewTestServer(authKeyIn string) *ServerTest { + return &ServerTest{ + handlers: make(map[string]handler), + authKey: authKeyIn, + } } func (ts *ServerTest) RegisterHandler(path string, handler handler) { @@ -36,7 +40,7 @@ func (ts *ServerTest) OpenAITestServer() *httptest.Server { log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path) // check auth - if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() { + if r.Header.Get("Authorization") != "Bearer "+ts.authKey && r.Header.Get("api-key") != ts.authKey { w.WriteHeader(http.StatusUnauthorized) return } diff --git a/openai_test.go b/openai_test.go index 729d8880c..70da7cd8a 100644 --- a/openai_test.go +++ b/openai_test.go @@ -6,18 +6,23 @@ import ( ) func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { - server = test.NewTestServer() + return setupOpenAITestServerWithAuth(test.GetTestToken()) +} + +func setupOpenAITestServerWithAuth(authKey string) (client *openai.Client, server *test.ServerTest, teardown func()) { + server = test.NewTestServer(authKey) ts := server.OpenAITestServer() ts.Start() teardown = ts.Close - config := openai.DefaultConfig(test.GetTestToken()) + config := openai.DefaultConfig(authKey) config.BaseURL = ts.URL + "/v1" + config.DashboardBaseURL = ts.URL + "/dashboard" client = openai.NewClientWithConfig(config) return } func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { - server = test.NewTestServer() + server = test.NewTestServer(test.GetTestToken()) ts := server.OpenAITestServer() ts.Start() teardown = ts.Close