Skip to content

Commit

Permalink
fix(llm): sets request model name consistently across AI plugins (#13627
Browse files Browse the repository at this point in the history
) (#13633)

(cherry picked from commit ad3e19a)

Co-authored-by: Jack Tysoe <91137069+tysoekong@users.noreply.github.com>
  • Loading branch information
team-gateway-bot and tysoekong authored Sep 9, 2024
1 parent ef87a2b commit c983384
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 10 deletions.
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/fix-ai-semantic-cache-model.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
message: "Fixed an bug that AI semantic cache can't use request provided models"
type: bugfix
scope: Plugin

2 changes: 1 addition & 1 deletion kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ function _M.post_request(conf, response_object)
-- Set the model, response, and provider names in the current try context
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = kong.ctx.plugin.llm_model_requested or conf.model.name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = llm_state.get_request_model()
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name

-- Set the llm latency meta, and time per token usage
Expand Down
4 changes: 2 additions & 2 deletions kong/llm/proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ function _M:header_filter(conf)
end

if ngx.var.http_kong_debug or conf.model_name_header then
local name = conf.model.provider .. "/" .. (kong.ctx.plugin.llm_model_requested or conf.model.name)
local name = conf.model.provider .. "/" .. (llm_state.get_request_model())
kong.response.set_header("X-Kong-LLM-Model", name)
end

Expand Down Expand Up @@ -385,7 +385,7 @@ function _M:access(conf)
return bail(400, "model parameter not found in request, nor in gateway configuration")
end

kong_ctx_plugin.llm_model_requested = conf_m.model.name
llm_state.set_request_model(conf_m.model.name)

-- check the incoming format is the same as the configured LLM format
local compatible, err = llm.is_compatible(request_table, route_type)
Expand Down
8 changes: 8 additions & 0 deletions kong/llm/state.lua
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,12 @@ function _M.get_metrics(key)
return (kong.ctx.shared.llm_metrics or {})[key]
end

function _M.set_request_model(model)
kong.ctx.shared.llm_model_requested = model
end

function _M.get_request_model()
return kong.ctx.shared.llm_model_requested or "NOT_SPECIFIED"
end

return _M
1 change: 1 addition & 0 deletions kong/plugins/ai-request-transformer/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ local function create_http_opts(conf)
end

function _M:access(conf)
llm_state.set_request_model(conf.llm.model and conf.llm.model.name)
local kong_ctx_shared = kong.ctx.shared

kong.service.request.enable_buffering()
Expand Down
1 change: 1 addition & 0 deletions kong/plugins/ai-response-transformer/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ end


function _M:access(conf)
llm_state.set_request_model(conf.llm.model and conf.llm.model.name)
local kong_ctx_shared = kong.ctx.shared

kong.service.request.enable_buffering()
Expand Down
68 changes: 61 additions & 7 deletions spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
local client

lazy_setup(function()
local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME })
local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME, "ctx-checker-last", "ctx-checker" })

-- set up openai mock fixtures
local fixtures = {
Expand Down Expand Up @@ -274,6 +274,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
path = FILE_LOG_PATH_STATS_ONLY,
},
}
bp.plugins:insert {
name = "ctx-checker-last",
route = { id = chat_good.id },
config = {
ctx_kind = "kong.ctx.shared",
ctx_check_field = "llm_model_requested",
ctx_check_value = "gpt-3.5-turbo",
}
}

-- 200 chat good with one option
local chat_good_no_allow_override = assert(bp.routes:insert {
Expand Down Expand Up @@ -544,16 +553,16 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
--

-- 200 chat good but no model set
local chat_good = assert(bp.routes:insert {
-- 200 chat good but no model set in plugin config
local chat_good_no_model = assert(bp.routes:insert {
service = empty_service,
protocols = { "http" },
strip_path = true,
paths = { "/openai/llm/v1/chat/good-no-model-param" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = chat_good.id },
route = { id = chat_good_no_model.id },
config = {
route_type = "llm/v1/chat",
auth = {
Expand All @@ -572,11 +581,20 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
bp.plugins:insert {
name = "file-log",
route = { id = chat_good.id },
route = { id = chat_good_no_model.id },
config = {
path = "/dev/stdout",
},
}
bp.plugins:insert {
name = "ctx-checker-last",
route = { id = chat_good_no_model.id },
config = {
ctx_kind = "kong.ctx.shared",
ctx_check_field = "llm_model_requested",
ctx_check_value = "try-to-override-the-model",
}
}
--

-- 200 completions good using post body key
Expand Down Expand Up @@ -755,7 +773,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
},
},
}
--


-- start kong
assert(helpers.start_kong({
Expand All @@ -764,7 +782,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
-- use the custom test template to create a local mock server
nginx_conf = "spec/fixtures/custom_nginx.template",
-- make sure our plugin gets loaded
plugins = "bundled," .. PLUGIN_NAME,
plugins = "bundled,ctx-checker-last,ctx-checker," .. PLUGIN_NAME,
-- write & load declarative config, only if 'strategy=off'
declarative_config = strategy == "off" and helpers.make_yaml_file() or nil,
}, nil, nil, fixtures))
Expand Down Expand Up @@ -835,6 +853,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.same(first_expected, first_got)
assert.is_true(actual_llm_latency >= 0)
assert.same(actual_time_per_token, time_per_token)
assert.same(first_got.meta.request_model, "gpt-3.5-turbo")
end)

it("does not log statistics", function()
Expand Down Expand Up @@ -1030,6 +1049,9 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)

-- from ctx-checker-last plugin
assert.equals(r.headers["ctx-checker-last-llm-model-requested"], "gpt-3.5-turbo")
end)

it("good request, parses model of cjson.null", function()
Expand Down Expand Up @@ -1110,6 +1132,38 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.is_truthy(json.error)
assert.equals(json.error.message, "request format not recognised")
end)

-- check that kong.ctx.shared.llm_model_requested is set
it("good request setting model from client body", function()
local r = client:get("/openai/llm/v1/chat/good-no-model-param", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json"),
})

-- validate that the request succeeded, response status 200
local body = assert.res_status(200 , r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "gpt-3.5-turbo-0613")
assert.equals(json.object, "chat.completion")
assert.equals(r.headers["X-Kong-LLM-Model"], "openai/try-to-override-the-model")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
assert.same({
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)

-- from ctx-checker-last plugin
assert.equals(r.headers["ctx-checker-last-llm-model-requested"], "try-to-override-the-model")
end)

end)

describe("openai llm/v1/completions", function()
Expand Down

0 comments on commit c983384

Please sign in to comment.