diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 0406478ff6..6023b4abe8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -89,7 +89,8 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam return types.ActionContinue, replaceJsonRequestBody(request, log) } - err := m.getContextContent(func(content string, err error) { + apiKey := m.config.GetOrSetTokenWithContext(ctx) + err := m.getContextContent(apiKey, func(content string, err error) { defer func() { _ = proxywasm.ResumeHttpRequest() }() @@ -114,13 +115,13 @@ func (m *moonshotProvider) performChatCompletion(ctx wrapper.HttpContext, fileCo return replaceJsonRequestBody(request, log) } -func (m *moonshotProvider) getContextContent(callback func(string, error), log wrapper.Log) error { +func (m *moonshotProvider) getContextContent(apiKey string, callback func(string, error), log wrapper.Log) error { if m.config.moonshotFileId != "" { if m.fileContent != "" { callback(m.fileContent, nil) return nil } - return m.sendRequest(http.MethodGet, "/v1/files/"+m.config.moonshotFileId+"/content", "", + return m.sendRequest(http.MethodGet, "/v1/files/"+m.config.moonshotFileId+"/content", "", apiKey, func(statusCode int, responseHeaders http.Header, responseBody []byte) { responseString := string(responseBody) if statusCode != http.StatusOK { @@ -141,13 +142,13 @@ func (m *moonshotProvider) getContextContent(callback func(string, error), log w return errors.New("both moonshotFileId and context are not configured") } -func (m *moonshotProvider) sendRequest(method, path string, body string, callback wrapper.ResponseCallback) error { +func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callback wrapper.ResponseCallback) error { switch method { case http.MethodGet: - headers := util.CreateHeaders("Authorization", "Bearer "+m.config.GetRandomToken()) + headers := util.CreateHeaders("Authorization", "Bearer "+apiKey) return m.client.Get(path, headers, callback, m.config.timeout) case http.MethodPost: - headers := util.CreateHeaders("Authorization", "Bearer "+m.config.GetRandomToken(), "Content-Type", "application/json") + headers := util.CreateHeaders("Authorization", "Bearer "+apiKey, "Content-Type", "application/json") return m.client.Post(path, headers, []byte(body), callback, m.config.timeout) default: return errors.New("unsupported method: " + method) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index f446874eee..b3f29feda5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -290,6 +290,15 @@ func (c *ProviderConfig) Validate() error { return nil } +func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string { + ctxApiKey := ctx.GetContext(ctxKeyApiName) + if ctxApiKey == nil { + ctxApiKey = c.GetRandomToken() + ctx.SetContext(ctxKeyApiName, ctxApiKey) + } + return ctxApiKey.(string) +} + func (c *ProviderConfig) GetRandomToken() string { apiTokens := c.apiTokens count := len(apiTokens)