Skip to content

Commit

Permalink
fix: Fix possible type-casting related panics in ai-proxy plugin (#1127)
Browse files Browse the repository at this point in the history
  • Loading branch information
CH3CHO committed Jul 16, 2024
1 parent f069ad5 commit d5a9ff3
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 19 deletions.
8 changes: 4 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())

if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
apiName := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnRequestBody(ctx, apiName, body, log)
if err == nil {
return action
Expand Down Expand Up @@ -139,7 +139,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
}

if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
apiName := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
if err == nil {
return action
Expand Down Expand Up @@ -171,7 +171,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))

if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
apiName := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
if err == nil && modifiedChunk != nil {
return modifiedChunk
Expand All @@ -193,7 +193,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
//log.Debugf("response body: %s", string(body))

if handler, ok := activeProvider.(provider.ResponseBodyHandler); ok {
apiName := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseBody(ctx, apiName, body, log)
if err == nil {
return action
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (b *baiduProvider) responseBaidu2OpenAI(ctx wrapper.HttpContext, response *
return &chatCompletionResponse{
Id: response.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletion,
Choices: []chatCompletionChoice{choice},
Expand All @@ -321,7 +321,7 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp
return &chatCompletionResponse{
Id: response.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletion,
Choices: []chatCompletionChoice{choice},
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResp
return &chatCompletionResponse{
Id: origResponse.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletion,
Choices: []chatCompletionChoice{choice},
Expand Down Expand Up @@ -356,7 +356,7 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG
return &chatCompletionResponse{
Id: response.Message.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
Object: objectChatCompletionChunk,
Choices: []chatCompletionChoice{choice},
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
openAIFormattedChunk := &chatCompletionResponse{
Id: hunyuanFormattedChunk.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletionChunk,
Usage: usage{
Expand Down Expand Up @@ -470,7 +470,7 @@ func (m *hunyuanProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, h
return &chatCompletionResponse{
Id: hunyuanResponse.Response.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletion,
Choices: choices,
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,9 @@ func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName Api
return types.ActionContinue, nil
}
// 模型对应接口为ChatCompletion v2,跳过OnStreamingResponseBody()和OnResponseBody()
model := ctx.GetContext(ctxKeyFinalRequestModel)
if model != nil {
_, ok := chatCompletionProModels[model.(string)]
model := ctx.GetStringContext(ctxKeyFinalRequestModel, "")
if model != "" {
_, ok := chatCompletionProModels[model]
if !ok {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
Expand Down
9 changes: 3 additions & 6 deletions plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,7 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
receivedBody = append(bufferedStreamingBody, chunk...)
}

incrementalStreaming, err := ctx.GetContext(ctxKeyIncrementalStreaming).(bool)
if !err {
incrementalStreaming = false
}
incrementalStreaming := ctx.GetBoolContext(ctxKeyIncrementalStreaming, false)

eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1

Expand Down Expand Up @@ -387,7 +384,7 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen
return &chatCompletionResponse{
Id: qwenResponse.RequestId,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletion,
Choices: choices,
Expand All @@ -403,7 +400,7 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
baseMessage := chatCompletionResponse{
Id: qwenResponse.RequestId,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
Choices: make([]chatCompletionChoice, 0),
SystemFingerprint: "",
Object: objectChatCompletionChunk,
Expand Down
16 changes: 16 additions & 0 deletions plugins/wasm-go/pkg/wrapper/plugin_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type HttpContext interface {
Method() string
SetContext(key string, value interface{})
GetContext(key string) interface{}
GetBoolContext(key string, defaultValue bool) bool
GetStringContext(key, defaultValue string) string
// If the onHttpRequestBody handle is not set, the request body will not be read by default
DontReadRequestBody()
// If the onHttpResponseBody handle is not set, the request body will not be read by default
Expand Down Expand Up @@ -297,6 +299,20 @@ func (ctx *CommonHttpCtx[PluginConfig]) GetContext(key string) interface{} {
return ctx.userContext[key]
}

func (ctx *CommonHttpCtx[PluginConfig]) GetBoolContext(key string, defaultValue bool) bool {
if b, ok := ctx.userContext[key].(bool); ok {
return b
}
return defaultValue
}

func (ctx *CommonHttpCtx[PluginConfig]) GetStringContext(key, defaultValue string) string {
if s, ok := ctx.userContext[key].(string); ok {
return s
}
return defaultValue
}

func (ctx *CommonHttpCtx[PluginConfig]) Scheme() string {
proxywasm.SetEffectiveContext(ctx.contextID)
return GetRequestScheme()
Expand Down

0 comments on commit d5a9ff3

Please sign in to comment.