diff --git a/changelog/unreleased/kong/fix-ai-chunking.yml b/changelog/unreleased/kong/fix-ai-chunking.yml
new file mode 100644
index 00000000000..65d66dc4d21
--- /dev/null
+++ b/changelog/unreleased/kong/fix-ai-chunking.yml
@@ -0,0 +1,3 @@
+message: "**ai-proxy**: Fixed a bug where response streaming for the Gemini and Bedrock providers returned the whole chat response in a single chunk."
+type: bugfix
+scope: Plugin
diff --git a/kong-3.10.0-0.rockspec b/kong-3.10.0-0.rockspec
index 1d17a970c9c..31f68741f8c 100644
--- a/kong-3.10.0-0.rockspec
+++ b/kong-3.10.0-0.rockspec
@@ -658,6 +658,7 @@ build = {
     ["kong.llm.plugin.shared-filters.enable-buffering"] = "kong/llm/plugin/shared-filters/enable-buffering.lua",
     ["kong.llm.plugin.shared-filters.normalize-json-response"] = "kong/llm/plugin/shared-filters/normalize-json-response.lua",
     ["kong.llm.plugin.shared-filters.normalize-request"] = "kong/llm/plugin/shared-filters/normalize-request.lua",
+    ["kong.llm.plugin.shared-filters.normalize-response-header"] = "kong/llm/plugin/shared-filters/normalize-response-header.lua",
     ["kong.llm.plugin.shared-filters.normalize-sse-chunk"] = "kong/llm/plugin/shared-filters/normalize-sse-chunk.lua",
     ["kong.llm.plugin.shared-filters.parse-json-response"] = "kong/llm/plugin/shared-filters/parse-json-response.lua",
     ["kong.llm.plugin.shared-filters.parse-request"] = "kong/llm/plugin/shared-filters/parse-request.lua",
diff --git a/kong/llm/plugin/base.lua b/kong/llm/plugin/base.lua
index c678ec43d0a..531c45583db 100644
--- a/kong/llm/plugin/base.lua
+++ b/kong/llm/plugin/base.lua
@@ -100,28 +100,6 @@ function MetaPlugin:rewrite(sub_plugin, conf)
 end
 
 
 function MetaPlugin:header_filter(sub_plugin, conf)
-  -- for error and exit response, just use plaintext headers
-  if kong.response.get_source() == "service" then
-    -- we use openai's streaming mode (SSE)
-    if get_global_ctx("stream_mode") then
-      -- we are going to send plaintext event-stream frames for ALL models
-      kong.response.set_header("Content-Type", "text/event-stream")
-      -- TODO: disable gzip for SSE because it needs immediate flush for each chunk
-      -- and seems nginx doesn't support it
-
-    elseif get_global_ctx("accept_gzip") then
-      -- for gzip response, don't set content-length at all to align with upstream
-      kong.response.clear_header("Content-Length")
-      kong.response.set_header("Content-Encoding", "gzip")
-
-    else
-      kong.response.clear_header("Content-Encoding")
-    end
-
-  else
-    kong.response.clear_header("Content-Encoding")
-  end
-
   run_stage(STAGES.REQ_POST_PROCESSING, sub_plugin, conf) -- TODO: order this in better place
   run_stage(STAGES.RES_INTROSPECTION, sub_plugin, conf)
diff --git a/kong/llm/plugin/shared-filters/enable-buffering.lua b/kong/llm/plugin/shared-filters/enable-buffering.lua
index eee9757fce1..055713c2e00 100644
--- a/kong/llm/plugin/shared-filters/enable-buffering.lua
+++ b/kong/llm/plugin/shared-filters/enable-buffering.lua
@@ -4,12 +4,15 @@ local _M = {
   DESCRIPTION = "set the response to buffering mode",
 }
 
+local ai_plugin_ctx = require("kong.llm.plugin.ctx")
+local get_global_ctx, _ = ai_plugin_ctx.get_global_accessors(_M.NAME)
+
 function _M:run(_)
-  if ngx.get_phase() == "access" then
+  if ngx.get_phase() == "access" and (not get_global_ctx("stream_mode")) then
     kong.service.request.enable_buffering()
   end
 
   return true
 end
 
-return _M
\ No newline at end of file
+return _M
diff --git a/kong/llm/plugin/shared-filters/normalize-response-header.lua b/kong/llm/plugin/shared-filters/normalize-response-header.lua
new file mode 100644
index 00000000000..3ab240a1527
--- /dev/null
+++ b/kong/llm/plugin/shared-filters/normalize-response-header.lua
@@ -0,0 +1,41 @@
+local _M = {
+  NAME = "normalize-response-header",
+  STAGE = "REQ_POST_PROCESSING",
+  DESCRIPTION = "normalize upstream response headers",
+}
+
+local ai_plugin_ctx = require("kong.llm.plugin.ctx")
+local get_global_ctx, _ = ai_plugin_ctx.get_global_accessors(_M.NAME)
+
+local FILTER_OUTPUT_SCHEMA = {
+  stream_content_type = "string",
+}
+
+local _, set_ctx = ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_SCHEMA)
+
+function _M:run(_)
+  -- for error and exit response, just use plaintext headers
+  if kong.response.get_source() == "service" then
+    -- we use openai's streaming mode (SSE)
+    if get_global_ctx("stream_mode") then
+      -- we are going to send plaintext event-stream frames for ALL models,
+      -- but we capture the original incoming content-type for the chunk-parser later.
+      set_ctx("stream_content_type", kong.service.response.get_header("Content-Type"))
+      kong.response.set_header("Content-Type", "text/event-stream")
+
+      -- TODO: disable gzip for SSE because it needs immediate flush for each chunk
+      -- and seems nginx doesn't support it
+    elseif get_global_ctx("accept_gzip") then
+      -- for gzip response, don't set content-length at all to align with upstream
+      kong.response.clear_header("Content-Length")
+      kong.response.set_header("Content-Encoding", "gzip")
+    else
+      kong.response.clear_header("Content-Encoding")
+    end
+  else
+    kong.response.clear_header("Content-Encoding")
+  end
+
+  return true
+end
+
+return _M
diff --git a/kong/llm/plugin/shared-filters/parse-sse-chunk.lua b/kong/llm/plugin/shared-filters/parse-sse-chunk.lua
index b0c9195a8e2..ac1dd0da521 100644
--- a/kong/llm/plugin/shared-filters/parse-sse-chunk.lua
+++ b/kong/llm/plugin/shared-filters/parse-sse-chunk.lua
@@ -15,10 +15,10 @@ local FILTER_OUTPUT_SCHEMA = {
 
 local get_global_ctx, _ = ai_plugin_ctx.get_global_accessors(_M.NAME)
 local _, set_ctx = ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_SCHEMA)
 
-
 local function handle_streaming_frame(conf, chunk, finished)
-  local content_type = kong.service.response.get_header("Content-Type")
+  local content_type = ai_plugin_ctx.get_namespaced_ctx("normalize-response-header", "stream_content_type")
+
   local normalized_content_type = content_type and content_type:sub(1, (content_type:find(";") or 0) - 1)
   if normalized_content_type and (not ai_shared._SUPPORTED_STREAMING_CONTENT_TYPES[normalized_content_type]) then
     return true
diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua
index 6c0320411e7..4a7eaff055d 100644
--- a/kong/plugins/ai-proxy/handler.lua
+++ b/kong/plugins/ai-proxy/handler.lua
@@ -7,7 +7,7 @@ local AIPlugin = ai_plugin_base.define(NAME, PRIORITY)
 
 local SHARED_FILTERS = {
   "parse-request", "normalize-request", "enable-buffering",
-  "parse-sse-chunk", "normalize-sse-chunk",
+  "normalize-response-header", "parse-sse-chunk", "normalize-sse-chunk",
   "parse-json-response", "normalize-json-response",
   "serialize-analytics",
 }
diff --git a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua
index 6ba0d7bdf97..e9cb676a0e8 100644
--- a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua
+++ b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua
@@ -640,6 +640,8 @@ for _, strategy in helpers.all_strategies() do
 
       assert.equal(#events, 8)
       assert.equal(buf:tostring(), "The answer to 1 + 1 is 2.")
+      -- verify that `kong.service.request.enable_buffering()` is not enabled
+      assert.logfile().has.no.line("/kong_buffered_http", true, 10)
     end)
 
     it("good stream request openai with partial split chunks", function()
@@ -728,6 +730,8 @@
       assert.same(tonumber(string.format("%.3f", actual_time_per_token)), tonumber(string.format("%.3f", time_per_token)))
       assert.match_re(actual_request_log, [[.*content.*What is 1 \+ 1.*]])
       assert.match_re(actual_response_log, [[.*content.*The answer.*]])
+      -- verify that `kong.service.request.enable_buffering()` is not enabled
+      assert.logfile().has.no.line("/kong_buffered_http", true, 10)
     end)
 
     it("good stream request cohere", function()
@@ -790,6 +794,8 @@
 
       assert.equal(#events, 17)
      assert.equal(buf:tostring(), "1 + 1 = 2. This is the most basic example of addition.")
+      -- verify that `kong.service.request.enable_buffering()` is not enabled
+      assert.logfile().has.no.line("/kong_buffered_http", true, 10)
     end)
 
     it("good stream request anthropic", function()
@@ -852,6 +858,8 @@
 
       assert.equal(#events, 8)
       assert.equal(buf:tostring(), "1 + 1 = 2")
+      -- verify that `kong.service.request.enable_buffering()` is not enabled
+      assert.logfile().has.no.line("/kong_buffered_http", true, 10)
     end)
 
     it("bad request is returned to the client not-streamed", function()
@@ -902,6 +910,8 @@
 
       assert.equal(#events, 1)
      assert.equal(res.status, 400)
+      -- verify that `kong.service.request.enable_buffering()` is not enabled
+      assert.logfile().has.no.line("/kong_buffered_http", true, 10)
     end)
 
   end)
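
Note on the mechanism: the Content-Type handling moves out of `MetaPlugin:header_filter` into the new `normalize-response-header` shared filter, which stores the upstream Content-Type in its namespaced context before forcing `text/event-stream`; `parse-sse-chunk` then reads the stored value instead of the already-rewritten response header, and `enable-buffering` skips `kong.service.request.enable_buffering()` in stream mode so frames are forwarded as they arrive. Below is a minimal sketch of how another filter could consume the captured value; the filter name and STAGE value are hypothetical, and only the ctx and PDK calls already shown in the diff above are assumed.

-- hypothetical example filter, not part of this patch
local ai_plugin_ctx = require("kong.llm.plugin.ctx")

local _M = {
  NAME = "inspect-stream-content-type",  -- hypothetical name
  STAGE = "RES_INTROSPECTION",           -- assumed stage; must run after normalize-response-header
  DESCRIPTION = "log the upstream streaming content type",
}

function _M:run(_)
  -- read the value stored by normalize-response-header before it forced text/event-stream
  local content_type = ai_plugin_ctx.get_namespaced_ctx("normalize-response-header", "stream_content_type")

  if content_type then
    -- strip parameters, e.g. "application/json; charset=utf-8" -> "application/json"
    local normalized = content_type:sub(1, (content_type:find(";") or 0) - 1)
    kong.log.debug("upstream streaming content type: ", normalized)
  end

  return true
end

return _M

The consumer keys the lookup off the producing filter's NAME ("normalize-response-header"), so it needs no FILTER_OUTPUT_SCHEMA entry of its own for this value, mirroring how `parse-sse-chunk` reads it in the patch.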