diff --git a/plugins/wasm-cpp/extensions/model_mapper/BUILD b/plugins/wasm-cpp/extensions/model_mapper/BUILD
new file mode 100644
index 0000000000..da489367c7
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/BUILD
@@ -0,0 +1,70 @@
+# Copyright (c) 2022 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
+load("//bazel:wasm.bzl", "declare_wasm_image_targets")
+
+proxy_wasm_cc_binary(
+    name = "model_mapper.wasm",
+    srcs = [
+        "plugin.cc",
+        "plugin.h",
+    ],
+    deps = [
+        "@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics_higress",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/time",
+        "//common:json_util",
+        "//common:http_util",
+        "//common:rule_util",
+    ],
+)
+
+cc_library(
+    name = "model_mapper_lib",
+    srcs = [
+        "plugin.cc",
+    ],
+    hdrs = [
+        "plugin.h",
+    ],
+    copts = ["-DNULL_PLUGIN"],
+    deps = [
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/time",
+        "//common:json_util",
+        "@proxy_wasm_cpp_host//:lib",
+        "//common:http_util_nullvm",
+        "//common:rule_util_nullvm",
+    ],
+)
+
+cc_test(
+    name = "model_mapper_test",
+    srcs = [
+        "plugin_test.cc",
+    ],
+    copts = ["-DNULL_PLUGIN"],
+    deps = [
+        ":model_mapper_lib",
+        "@com_google_googletest//:gtest",
+        "@com_google_googletest//:gtest_main",
+        "@proxy_wasm_cpp_host//:lib",
+    ],
+)
+
+declare_wasm_image_targets(
+    name = "model_mapper",
+    wasm_file = ":model_mapper.wasm",
+)
diff --git a/plugins/wasm-cpp/extensions/model_mapper/README.md b/plugins/wasm-cpp/extensions/model_mapper/README.md
new file mode 100644
index 0000000000..1e9fc747f7
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/README.md
@@ -0,0 +1,63 @@
+## Function Description
+The `model-mapper` plugin rewrites the model parameter in LLM protocol requests according to a configurable mapping table.
+
+## Configuration Fields
+
+| Name | Data Type | Required | Default | Description |
+| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
+| `modelKey` | string | Optional | model | The location of the model parameter in the request body |
+| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map model names in requests to model names supported by the service provider.<br/>1. Prefix matching is supported, e.g. "gpt-3-*" matches every model whose name starts with "gpt-3-";<br/>2. "*" can be used as a key to configure a generic fallback mapping;<br/>3. If the mapped target name is the empty string "", the original model name is kept. |
+| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only applies to requests whose path ends with one of these suffixes |
+
+## Runtime Properties
+
+Plugin execution phase: Authentication phase
+Plugin execution priority: 800
+
+## Effect Description
+
+With the following configuration:
+
+```yaml
+modelMapping:
+  'gpt-4-*': "qwen-max"
+  'gpt-4o': "qwen-vl-plus"
+  '*': "qwen-turbo"
+```
+
+Once enabled, model parameters starting with `gpt-4-` are rewritten to `qwen-max`, `gpt-4o` is rewritten to `qwen-vl-plus`, and all other models are rewritten to `qwen-turbo`.
+
+For example, given the original request:
+
+```json
+{
+    "model": "gpt-4o",
+    "frequency_penalty": 0,
+    "max_tokens": 800,
+    "stream": false,
+    "messages": [{
+        "role": "user",
+        "content": "What is the GitHub address of the main repository for the higress project?"
+    }],
+    "presence_penalty": 0,
+    "temperature": 0.7,
+    "top_p": 0.95
+}
+```
+
+After this plugin processes it, the original LLM request body is rewritten to:
+
+```json
+{
+    "model": "qwen-vl-plus",
+    "frequency_penalty": 0,
+    "max_tokens": 800,
+    "stream": false,
+    "messages": [{
+        "role": "user",
+        "content": "What is the GitHub address of the main repository for the higress project?"
+    }],
+    "presence_penalty": 0,
+    "temperature": 0.7,
+    "top_p": 0.95
+}
+```
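+
+A configuration exercising every field might look like the sketch below (values are illustrative; `modelKey` only needs to be overridden when the model field has a non-default name):
+
+```yaml
+modelKey: model
+enableOnPathSuffix:
+- /v1/chat/completions
+modelMapping:
+  'gpt-4o': "qwen-vl-plus"
+  'gpt-4-*': "qwen-max"
+  'text-embedding-v1': ""
+  '*': "qwen-turbo"
+```
+
+Exact keys are consulted before prefix rules, and the `*` fallback applies last; the empty-string target keeps `text-embedding-v1` unchanged.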
diff --git a/plugins/wasm-cpp/extensions/model_mapper/README_EN.md b/plugins/wasm-cpp/extensions/model_mapper/README_EN.md
new file mode 100644
index 0000000000..f38cc84f68
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/README_EN.md
@@ -0,0 +1,65 @@
+## Function Description
+The `model-mapper` plugin rewrites the model parameter in LLM protocol requests according to a configurable mapping table.
+
+## Configuration Fields
+
+| Name | Data Type | Required | Default | Description |
+| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
+| `modelKey` | string | Optional | model | The location of the model parameter in the request body. |
+| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map the model names in the request to the model names supported by the service provider.<br/>1. Supports prefix matching, e.g. "gpt-3-*" matches all models whose names start with "gpt-3-";<br/>2. Supports using "*" as the key to configure a generic fallback mapping;<br/>3. If the target name in the mapping is an empty string "", the original model name is kept. |
+| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only applies to requests whose path ends with one of these suffixes. |
+
+## Runtime Properties
+
+Plugin execution phase: Authentication phase
+Plugin execution priority: 800
+
+## Effect Description
+
+With the following configuration:
+
+```yaml
+modelMapping:
+  'gpt-4-*': "qwen-max"
+  'gpt-4o': "qwen-vl-plus"
+  '*': "qwen-turbo"
+```
+
+After enabling, model parameters starting with `gpt-4-` will be rewritten to `qwen-max`, `gpt-4o` will be rewritten to `qwen-vl-plus`, and all other models will be rewritten to `qwen-turbo`.
+
+For example, if the original request was:
+
+```json
+{
+    "model": "gpt-4o",
+    "frequency_penalty": 0,
+    "max_tokens": 800,
+    "stream": false,
+    "messages": [{
+        "role": "user",
+        "content": "What is the GitHub address of the main repository for the higress project?"
+    }],
+    "presence_penalty": 0,
+    "temperature": 0.7,
+    "top_p": 0.95
+}
+```
+
+After processing by this plugin, the original LLM request body will be modified to:
+
+```json
+{
+    "model": "qwen-vl-plus",
+    "frequency_penalty": 0,
+    "max_tokens": 800,
+    "stream": false,
+    "messages": [{
+        "role": "user",
+        "content": "What is the GitHub address of the main repository for the higress project?"
+    }],
+    "presence_penalty": 0,
+    "temperature": 0.7,
+    "top_p": 0.95
+}
+```
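+
+Like other Higress Wasm plugins, the mapping can also be scoped per route or per service via the `_rules_` / `_match_route_` / `_match_service_` syntax, as exercised in plugin_test.cc. A minimal sketch (route and service names are placeholders):
+
+```yaml
+_rules_:
+- _match_route_:
+  - route-a
+  modelMapping:
+    '*': "qwen-long"
+- _match_service_:
+  - service-2
+  modelMapping:
+    '*': "qwen-max"
+```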
Expected string."); + return false; + } + } + + if (auto it = configuration.find("modelMapping"); it != configuration.end()) { + if (!it->is_object()) { + LOG_ERROR("Invalid type for modelMapping. Expected object."); + return false; + } + auto model_mapping = it->get(); + if (!JsonObjectIterate(model_mapping, [&](std::string key) -> bool { + auto model_json = model_mapping.find(key); + if (!model_json->is_string()) { + LOG_ERROR( + "Invalid type for item in modelMapping. Expected string."); + return false; + } + if (key == "*") { + rule.default_model_mapping_ = model_json->get(); + return true; + } + if (absl::EndsWith(key, "*")) { + rule.prefix_model_mapping_.emplace_back( + absl::StripSuffix(key, "*"), model_json->get()); + return true; + } + auto ret = rule.exact_model_mapping_.emplace( + key, model_json->get()); + if (!ret.second) { + LOG_ERROR("Duplicate key in modelMapping: " + key); + return false; + } + return true; + })) { + return false; + } + } + + if (!JsonArrayIterate( + configuration, "enableOnPathSuffix", [&](const json& item) -> bool { + if (item.is_string()) { + rule.enable_on_path_suffix_.emplace_back(item.get()); + return true; + } + return false; + })) { + LOG_WARN("Invalid type for item in enableOnPathSuffix. Expected string."); + return false; + } + return true; +} + +bool PluginRootContext::onConfigure(size_t size) { + // Parse configuration JSON string. + if (size > 0 && !configure(size)) { + LOG_WARN("configuration has errors initialization will not continue."); + return false; + } + return true; +} + +bool PluginRootContext::configure(size_t configuration_size) { + auto configuration_data = getBufferBytes(WasmBufferType::PluginConfiguration, + 0, configuration_size); + // Parse configuration JSON string. + auto result = ::Wasm::Common::JsonParse(configuration_data->view()); + if (!result) { + LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ", + configuration_data->view())); + return false; + } + if (!parseRuleConfig(result.value())) { + LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ", + configuration_data->view())); + return false; + } + return true; +} + +FilterHeadersStatus PluginRootContext::onHeader( + const ModelMapperConfigRule& rule) { + if (!Wasm::Common::Http::hasRequestBody()) { + return FilterHeadersStatus::Continue; + } + auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString(); + auto params_pos = path.find('?'); + size_t uri_end; + if (params_pos == std::string::npos) { + uri_end = path.size(); + } else { + uri_end = params_pos; + } + bool enable = false; + for (const auto& enable_suffix : rule.enable_on_path_suffix_) { + if (absl::EndsWith({path.c_str(), uri_end}, enable_suffix)) { + enable = true; + break; + } + } + if (!enable) { + return FilterHeadersStatus::Continue; + } + auto content_type_value = + getRequestHeader(Wasm::Common::Http::Header::ContentType); + if (!absl::StrContains(content_type_value->view(), + Wasm::Common::Http::ContentTypeValues::Json)) { + return FilterHeadersStatus::Continue; + } + removeRequestHeader(Wasm::Common::Http::Header::ContentLength); + setFilterState(SetDecoderBufferLimitKey, DefaultMaxBodyBytes); + return FilterHeadersStatus::StopIteration; +} + +FilterDataStatus PluginRootContext::onBody(const ModelMapperConfigRule& rule, + std::string_view body) { + const auto& exact_model_mapping = rule.exact_model_mapping_; + const auto& prefix_model_mapping = rule.prefix_model_mapping_; + const auto& default_model_mapping = rule.default_model_mapping_; + const 
+FilterDataStatus PluginRootContext::onBody(const ModelMapperConfigRule& rule,
+                                           std::string_view body) {
+  const auto& exact_model_mapping = rule.exact_model_mapping_;
+  const auto& prefix_model_mapping = rule.prefix_model_mapping_;
+  const auto& default_model_mapping = rule.default_model_mapping_;
+  const auto& model_key = rule.model_key_;
+  auto body_json_opt = ::Wasm::Common::JsonParse(body);
+  if (!body_json_opt) {
+    LOG_WARN(absl::StrCat("cannot parse body to JSON string: ", body));
+    return FilterDataStatus::Continue;
+  }
+  auto body_json = body_json_opt.value();
+  std::string old_model;
+  if (body_json.contains(model_key)) {
+    old_model = body_json[model_key];
+  }
+  std::string model =
+      default_model_mapping.empty() ? old_model : default_model_mapping;
+  if (auto it = exact_model_mapping.find(old_model);
+      it != exact_model_mapping.end()) {
+    model = it->second;
+  } else {
+    for (auto& prefix_model_pair : prefix_model_mapping) {
+      if (absl::StartsWith(old_model, prefix_model_pair.first)) {
+        model = prefix_model_pair.second;
+        break;
+      }
+    }
+  }
+  if (!model.empty() && model != old_model) {
+    body_json[model_key] = model;
+    setBuffer(WasmBufferType::HttpRequestBody, 0,
+              std::numeric_limits<size_t>::max(), body_json.dump());
+    LOG_DEBUG(
+        absl::StrCat("model mapped, before:", old_model, ", after:", model));
+  }
+  return FilterDataStatus::Continue;
+}
+
+FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) {
+  auto* rootCtx = rootContext();
+  return rootCtx->onHeaders([rootCtx, this](const auto& config) {
+    auto ret = rootCtx->onHeader(config);
+    if (ret == FilterHeadersStatus::StopIteration) {
+      this->config_ = &config;
+    }
+    return ret;
+  });
+}
+
+FilterDataStatus PluginContext::onRequestBody(size_t body_size,
+                                              bool end_stream) {
+  if (config_ == nullptr) {
+    return FilterDataStatus::Continue;
+  }
+  body_total_size_ += body_size;
+  if (!end_stream) {
+    return FilterDataStatus::StopIterationAndBuffer;
+  }
+  auto body =
+      getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
+  auto* rootCtx = rootContext();
+  return rootCtx->onBody(*config_, body->view());
+}
+
+#ifdef NULL_PLUGIN
+
+}  // namespace model_mapper
+}  // namespace null_plugin
+}  // namespace proxy_wasm
+
+#endif
diff --git a/plugins/wasm-cpp/extensions/model_mapper/plugin.h b/plugins/wasm-cpp/extensions/model_mapper/plugin.h
new file mode 100644
index 0000000000..df0c7eb0e1
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/plugin.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2022 Alibaba Group Holding Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <assert.h>
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "common/route_rule_matcher.h"
+#define ASSERT(_X) assert(_X)
+
+#ifndef NULL_PLUGIN
+
+#include "proxy_wasm_intrinsics.h"
+
+#else
+
+#include "include/proxy-wasm/null_plugin.h"
+
+namespace proxy_wasm {
+namespace null_plugin {
+namespace model_mapper {
+
+#endif
+
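+// Per-rule configuration parsed from the plugin's JSON config:
+// exact_model_mapping_ holds exact model-name keys, prefix_model_mapping_
+// holds (prefix, target) pairs derived from keys ending in "*", and
+// default_model_mapping_ holds the "*" fallback target.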
+struct ModelMapperConfigRule {
+  std::string model_key_ = "model";
+  std::map<std::string, std::string> exact_model_mapping_;
+  std::vector<std::pair<std::string, std::string>> prefix_model_mapping_;
+  std::string default_model_mapping_;
+  std::vector<std::string> enable_on_path_suffix_ = {"/v1/chat/completions"};
+};
+
+// PluginRootContext is the root context for all streams processed by the
+// thread. It has the same lifetime as the worker thread and acts as a target
+// for interactions that outlive individual streams, e.g. timers and async
+// calls.
+class PluginRootContext : public RootContext,
+                          public RouteRuleMatcher<ModelMapperConfigRule> {
+ public:
+  PluginRootContext(uint32_t id, std::string_view root_id)
+      : RootContext(id, root_id) {}
+  ~PluginRootContext() {}
+  bool onConfigure(size_t) override;
+  FilterHeadersStatus onHeader(const ModelMapperConfigRule&);
+  FilterDataStatus onBody(const ModelMapperConfigRule&, std::string_view);
+  bool configure(size_t);
+
+ private:
+  bool parsePluginConfig(const json&, ModelMapperConfigRule&) override;
+};
+
+// Per-stream context.
+class PluginContext : public Context {
+ public:
+  explicit PluginContext(uint32_t id, RootContext* root) : Context(id, root) {}
+  FilterHeadersStatus onRequestHeaders(uint32_t, bool) override;
+  FilterDataStatus onRequestBody(size_t, bool) override;
+
+ private:
+  inline PluginRootContext* rootContext() {
+    return dynamic_cast<PluginRootContext*>(this->root());
+  }
+
+  size_t body_total_size_ = 0;
+  const ModelMapperConfigRule* config_ = nullptr;
+};
+
+#ifdef NULL_PLUGIN
+
+}  // namespace model_mapper
+}  // namespace null_plugin
+}  // namespace proxy_wasm
+
+#endif
diff --git a/plugins/wasm-cpp/extensions/model_mapper/plugin_test.cc b/plugins/wasm-cpp/extensions/model_mapper/plugin_test.cc
new file mode 100644
index 0000000000..7f9c706ff9
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/plugin_test.cc
@@ -0,0 +1,301 @@
+// Copyright (c) 2022 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "extensions/model_mapper/plugin.h"
+
+#include <iostream>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "include/proxy-wasm/context.h"
+#include "include/proxy-wasm/null.h"
+
+namespace proxy_wasm {
+namespace null_plugin {
+namespace model_mapper {
+
+NullPluginRegistry* context_registry_;
+RegisterNullVmPluginFactory register_model_mapper_plugin("model_mapper", []() {
+  return std::make_unique<NullPlugin>(model_mapper::context_registry_);
+});
+
+class MockContext : public proxy_wasm::ContextBase {
+ public:
+  MockContext(WasmBase* wasm) : ContextBase(wasm) {}
+
+  MOCK_METHOD(BufferInterface*, getBuffer, (WasmBufferType));
+  MOCK_METHOD(WasmResult, log, (uint32_t, std::string_view));
+  MOCK_METHOD(WasmResult, setBuffer,
+              (WasmBufferType, size_t, size_t, std::string_view));
+  MOCK_METHOD(WasmResult, getHeaderMapValue,
+              (WasmHeaderMapType /* type */, std::string_view /* key */,
+               std::string_view* /* result */));
+  MOCK_METHOD(WasmResult, replaceHeaderMapValue,
+              (WasmHeaderMapType /* type */, std::string_view /* key */,
+               std::string_view /* value */));
+  MOCK_METHOD(WasmResult, removeHeaderMapValue,
+              (WasmHeaderMapType /* type */, std::string_view /* key */));
+  MOCK_METHOD(WasmResult, addHeaderMapValue,
+              (WasmHeaderMapType, std::string_view, std::string_view));
+  MOCK_METHOD(WasmResult, getProperty, (std::string_view, std::string*));
+  MOCK_METHOD(WasmResult, setProperty, (std::string_view, std::string_view));
+};
+
+class ModelMapperTest : public ::testing::Test {
+ protected:
+  ModelMapperTest() {
+    // Initialize test VM
+    test_vm_ = createNullVm();
+    wasm_base_ = std::make_unique<WasmBase>(
+        std::move(test_vm_), "test-vm", "", "",
+        std::unordered_map<std::string, std::string>{},
+        AllowedCapabilitiesMap{});
+    wasm_base_->load("model_mapper");
+    wasm_base_->initialize();
+
+    // Initialize host side context
+    mock_context_ = std::make_unique<MockContext>(wasm_base_.get());
+    current_context_ = mock_context_.get();
+
+    // Initialize Wasm sandbox context
+    root_context_ = std::make_unique<PluginRootContext>(0, "");
+    context_ = std::make_unique<PluginContext>(1, root_context_.get());
+
+    ON_CALL(*mock_context_, log(testing::_, testing::_))
+        .WillByDefault([](uint32_t, std::string_view m) {
+          std::cerr << m << "\n";
+          return WasmResult::Ok;
+        });
+
+    ON_CALL(*mock_context_, getBuffer(testing::_))
+        .WillByDefault([&](WasmBufferType type) {
+          if (type == WasmBufferType::HttpRequestBody) {
+            return &body_;
+          }
+          return &config_;
+        });
+
+    ON_CALL(*mock_context_,
+            getHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_,
+                              testing::_))
+        .WillByDefault([&](WasmHeaderMapType, std::string_view header,
+                           std::string_view* result) {
+          if (header == "content-type") {
+            *result = "application/json";
+          } else if (header == "content-length") {
+            *result = "1024";
+          } else if (header == ":path") {
+            *result = path_;
+          }
+          return WasmResult::Ok;
+        });
+
+    ON_CALL(*mock_context_,
+            replaceHeaderMapValue(WasmHeaderMapType::RequestHeaders,
+                                  testing::_, testing::_))
+        .WillByDefault([&](WasmHeaderMapType, std::string_view key,
+                           std::string_view value) { return WasmResult::Ok; });
+
+    ON_CALL(*mock_context_,
+            removeHeaderMapValue(WasmHeaderMapType::RequestHeaders,
+                                 testing::_))
+        .WillByDefault([&](WasmHeaderMapType, std::string_view key) {
+          return WasmResult::Ok;
+        });
+
+    ON_CALL(*mock_context_,
+            addHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_,
+                              testing::_))
+        .WillByDefault([&](WasmHeaderMapType, std::string_view header,
+                           std::string_view value) { return WasmResult::Ok; });
+
+    ON_CALL(*mock_context_, getProperty(testing::_, testing::_))
+        .WillByDefault([&](std::string_view path, std::string* result) {
+          if (absl::StartsWith(path, "route_name")) {
+            *result = route_name_;
+          } else if (absl::StartsWith(path, "cluster_name")) {
+            *result = service_name_;
+          }
+          return WasmResult::Ok;
+        });
+
+    ON_CALL(*mock_context_, setProperty(testing::_, testing::_))
+        .WillByDefault([&](std::string_view, std::string_view) {
+          return WasmResult::Ok;
+        });
+  }
+
+  ~ModelMapperTest() override {}
+
+  std::unique_ptr<WasmBase> wasm_base_;
+  std::unique_ptr<WasmVm> test_vm_;
+  std::unique_ptr<MockContext> mock_context_;
+  std::unique_ptr<PluginRootContext> root_context_;
+  std::unique_ptr<PluginContext> context_;
+
+  std::string route_name_;
+  std::string service_name_;
+  std::string path_;
+  BufferBase body_;
+  BufferBase config_;
+};
+
+TEST_F(ModelMapperTest, ModelMappingTest) {
+  std::string configuration = R"(
+{
+  "modelMapping": {
+    "*": "qwen-long",
+    "gpt-4*": "qwen-max",
+    "gpt-4o": "qwen-turbo",
+    "gpt-4o-mini": "qwen-plus",
+    "text-embedding-v1": ""
+  }
+})";
+
+  config_.set(configuration);
+  EXPECT_TRUE(root_context_->configure(configuration.size()));
+
+  // "gpt-3.5" matches no exact or prefix rule, so the "*" fallback applies.
+  path_ = "/v1/chat/completions";
+  std::string request_json = R"({"model": "gpt-3.5"})";
+  EXPECT_CALL(*mock_context_,
+              setBuffer(testing::_, testing::_, testing::_, testing::_))
+      .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+        EXPECT_EQ(body, R"({"model":"qwen-long"})");
+        return WasmResult::Ok;
+      });
+
+  body_.set(request_json);
+  EXPECT_EQ(context_->onRequestHeaders(0, false),
+            FilterHeadersStatus::StopIteration);
+  EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+  // "gpt-4" hits the "gpt-4*" prefix rule.
+  request_json = R"({"model": "gpt-4"})";
+  EXPECT_CALL(*mock_context_,
+              setBuffer(testing::_, testing::_, testing::_, testing::_))
+      .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+        EXPECT_EQ(body, R"({"model":"qwen-max"})");
+        return WasmResult::Ok;
+      });
+
+  body_.set(request_json);
+  EXPECT_EQ(context_->onRequestHeaders(0, false),
+            FilterHeadersStatus::StopIteration);
+  EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+  // Exact matches win over the "gpt-4*" prefix rule.
+  request_json = R"({"model": "gpt-4o"})";
+  EXPECT_CALL(*mock_context_,
+              setBuffer(testing::_, testing::_, testing::_, testing::_))
+      .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+        EXPECT_EQ(body, R"({"model":"qwen-turbo"})");
+        return WasmResult::Ok;
+      });
+
+  body_.set(request_json);
+  EXPECT_EQ(context_->onRequestHeaders(0, false),
+            FilterHeadersStatus::StopIteration);
+  EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+  request_json = R"({"model": "gpt-4o-mini"})";
+  EXPECT_CALL(*mock_context_,
+              setBuffer(testing::_, testing::_, testing::_, testing::_))
+      .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+        EXPECT_EQ(body, R"({"model":"qwen-plus"})");
+        return WasmResult::Ok;
+      });
+
+  body_.set(request_json);
+  EXPECT_EQ(context_->onRequestHeaders(0, false),
+            FilterHeadersStatus::StopIteration);
+  EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+  // An empty mapping target keeps the original model name, so the body is
+  // not rewritten.
+  request_json = R"({"model": "text-embedding-v1"})";
+  EXPECT_CALL(*mock_context_,
+              setBuffer(testing::_, testing::_, testing::_, testing::_))
+      .Times(0);
+
+  body_.set(request_json);
+  EXPECT_EQ(context_->onRequestHeaders(0, false),
+            FilterHeadersStatus::StopIteration);
+  EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+  // A body without a model field still receives the "*" fallback.
+  request_json = R"({})";
+  EXPECT_CALL(*mock_context_,
+              setBuffer(testing::_, testing::_, testing::_, testing::_))
+      .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+        EXPECT_EQ(body, R"({"model":"qwen-long"})");
+        return WasmResult::Ok;
+      });
+
+  body_.set(request_json);
+  EXPECT_EQ(context_->onRequestHeaders(0, false),
+            FilterHeadersStatus::StopIteration);
+  EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+}
+
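+// Exercises rule-level configuration: the fixture's getProperty mock feeds
+// route_name_ and service_name_ back to the rule matcher, with the cluster
+// name in Envoy's "outbound|<port>||<service>" form that _match_service_
+// matches against.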
.WillByDefault([&](std::string_view path, std::string* result) { + if (absl::StartsWith(path, "route_name")) { + *result = route_name_; + } else if (absl::StartsWith(path, "cluster_name")) { + *result = service_name_; + } + return WasmResult::Ok; + }); + ON_CALL(*mock_context_, setProperty(testing::_, testing::_)) + .WillByDefault( + [&](std::string_view, std::string_view) { return WasmResult::Ok; }); + } + ~ModelMapperTest() override {} + std::unique_ptr wasm_base_; + std::unique_ptr test_vm_; + std::unique_ptr mock_context_; + std::unique_ptr root_context_; + std::unique_ptr context_; + std::string route_name_; + std::string service_name_; + std::string path_; + BufferBase body_; + BufferBase config_; +}; + +TEST_F(ModelMapperTest, ModelMappingTest) { + std::string configuration = R"( +{ + "modelMapping": { + "*": "qwen-long", + "gpt-4*": "qwen-max", + "gpt-4o": "qwen-turbo", + "gpt-4o-mini": "qwen-plus", + "text-embedding-v1": "" + } +})"; + + config_.set(configuration); + EXPECT_TRUE(root_context_->configure(configuration.size())); + + path_ = "/v1/chat/completions"; + std::string request_json = R"({"model": "gpt-3.5"})"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) { + EXPECT_EQ(body, R"({"model":"qwen-long"})"); + return WasmResult::Ok; + }); + + body_.set(request_json); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue); + + request_json = R"({"model": "gpt-4"})"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) { + EXPECT_EQ(body, R"({"model":"qwen-max"})"); + return WasmResult::Ok; + }); + + body_.set(request_json); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue); + + request_json = R"({"model": "gpt-4o"})"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) { + EXPECT_EQ(body, R"({"model":"qwen-turbo"})"); + return WasmResult::Ok; + }); + + body_.set(request_json); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue); + + request_json = R"({"model": "gpt-4o-mini"})"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) { + EXPECT_EQ(body, R"({"model":"qwen-plus"})"); + return WasmResult::Ok; + }); + + body_.set(request_json); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue); + + request_json = R"({"model": "text-embedding-v1"})"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .Times(0); + + body_.set(request_json); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue); + + request_json = R"({})"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .WillOnce([&](WasmBufferType, size_t, size_t, 
std::string_view body) { + EXPECT_EQ(body, R"({"model":"qwen-long"})"); + return WasmResult::Ok; + }); + + body_.set(request_json); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue); +} + +TEST_F(ModelMapperTest, RouteLevelModelMappingTest) { + std::string configuration = R"( +{ + "_rules_": [ + { + "_match_route_": ["route-a"], + "_match_service_": ["service-1"], + "modelMapping": { + "*": "qwen-long" + } + }, + { + "_match_route_": ["route-b"], + "_match_service_": ["service-2"], + "modelMapping": { + "*": "qwen-max" + } + }, + { + "_match_route_": ["route-b"], + "_match_service_": ["service-3"], + "modelMapping": { + "*": "qwen-turbo" + } + } +]})"; + + config_.set(configuration); + EXPECT_TRUE(root_context_->configure(configuration.size())); + + path_ = "/api/v1/chat/completions"; + std::string request_json = R"({"model": "gpt-4"})"; + body_.set(request_json); + route_name_ = "route-a"; + service_name_ = "outbound|80||service-1"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) { + EXPECT_EQ(body, R"({"model":"qwen-long"})"); + return WasmResult::Ok; + }); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue); + + route_name_ = "route-b"; + service_name_ = "outbound|80||service-2"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) { + EXPECT_EQ(body, R"({"model":"qwen-max"})"); + return WasmResult::Ok; + }); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue); + + route_name_ = "route-b"; + service_name_ = "outbound|80||service-3"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) { + EXPECT_EQ(body, R"({"model":"qwen-turbo"})"); + return WasmResult::Ok; + }); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue); +} + +} // namespace model_mapper +} // namespace null_plugin +} // namespace proxy_wasm diff --git a/plugins/wasm-cpp/extensions/model_router/README.md b/plugins/wasm-cpp/extensions/model_router/README.md index b63be35d8f..e78988e0e0 100644 --- a/plugins/wasm-cpp/extensions/model_router/README.md +++ b/plugins/wasm-cpp/extensions/model_router/README.md @@ -1,33 +1,67 @@ ## 功能说明 `model-router`插件实现了基于LLM协议中的model参数路由的功能 +## 配置字段 + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- | +| `modelKey` | string | 选填 | model | 请求body中model参数的位置 | +| `addProviderHeader` | string | 选填 | - | 从model参数中解析出的provider名字放到哪个请求header中 | +| `modelToHeader` | string | 选填 | - | 直接将model参数放到哪个请求header中 | +| `enableOnPathSuffix` | array of string | 选填 | ["/v1/chat/completions"] | 只对这些特定路径后缀的请求生效 | + ## 运行属性 -插件执行阶段:`默认阶段` -插件执行优先级:`260` +插件执行阶段:认证阶段 +插件执行优先级:900 -## 配置字段 +## 效果说明 -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- | -| 
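+Both modes can be enabled at the same time; the sketch below sets the two headers independently (header names as documented above):
+
+```yaml
+modelToHeader: x-higress-llm-model
+addProviderHeader: x-higress-llm-provider
+```
+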
+### Extracting the provider field from the model parameter for routing
+
+> Note: this mode requires the client to specify the provider in the model parameter, separated by `/`.
+
+The following configuration is required:
+
+```yaml
+addProviderHeader: x-higress-llm-provider
+```
+
+The plugin extracts the provider part (if any) from the model parameter, sets it in the x-higress-llm-provider request header for subsequent routing, and rewrites the model parameter to the model name part. For example, the original LLM request body:
 
 ```json
 {
-    "model": "qwen/qwen-long",
+    "model": "dashscope/qwen-long",
     "frequency_penalty": 0,
     "max_tokens": 800,
     "stream": false,
@@ -43,7 +77,7 @@ enable: true
 
 After this plugin processes the request, the following request header (usable for route matching) is added:
 
-x-higress-llm-provider: qwen
+x-higress-llm-provider: dashscope
 
 The original LLM request body will be changed to:
 
diff --git a/plugins/wasm-cpp/extensions/model_router/README_EN.md b/plugins/wasm-cpp/extensions/model_router/README_EN.md
index 4d2eaf1fee..f5866423b3 100644
--- a/plugins/wasm-cpp/extensions/model_router/README_EN.md
+++ b/plugins/wasm-cpp/extensions/model_router/README_EN.md
@@ -1,38 +1,41 @@
 ## Function Description
-The `model-router` plugin implements the functionality of routing based on the `model` parameter in the LLM protocol.
+The `model-router` plugin implements routing based on the model parameter in the LLM protocol.
 
-## Runtime Properties
+## Configuration Fields
 
-Plugin Execution Phase: `Default Phase`
-Plugin Execution Priority: `260`
+| Name | Data Type | Required | Default | Description |
+| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
+| `modelKey` | string | Optional | model | The location of the model parameter in the request body |
+| `addProviderHeader` | string | Optional | - | The request header that receives the provider name parsed from the model parameter |
+| `modelToHeader` | string | Optional | - | The request header that receives the model parameter unchanged |
+| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only applies to requests whose path ends with one of these suffixes |
 
-## Configuration Fields
+## Runtime Properties
 
-| Name | Data Type | Filling Requirement | Default Value | Description |
-| -------------------- | ------------- | --------------------- | ---------------------- | ----------------------------------------------------- |
-| `enable` | bool | Optional | false | Whether to enable routing based on the `model` parameter |
-| `model_key` | string | Optional | model | The location of the `model` parameter in the request body |
-| `add_header_key` | string | Optional | x-higress-llm-provider | The header where the parsed provider name from the `model` parameter will be placed |
+Plugin execution phase: Authentication phase
+Plugin execution priority: 900
 
 ## Effect Description
 
-To enable routing based on the `model` parameter, use the following configuration:
+### Routing Based on the model Parameter
+
+The following configuration is required:
 
 ```yaml
-enable: true
+modelToHeader: x-higress-llm-model
 ```
 
-After enabling, the plugin extracts the provider part (if any) from the `model` parameter in the request, and sets it in the `x-higress-llm-provider` request header for subsequent routing. It also rewrites the `model` parameter to the model name part. For example, the original LLM request body is:
+The plugin extracts the model parameter from the request and sets it in the x-higress-llm-model request header, which can be used for subsequent routing. For example, the original LLM request body:
 
 ```json
 {
-    "model": "openai/gpt-4o",
+    "model": "qwen-long",
     "frequency_penalty": 0,
     "max_tokens": 800,
     "stream": false,
     "messages": [{
         "role": "user",
-        "content": "What is the GitHub address for the main repository of the Higress project?"
+        "content": "What is the GitHub address of the main repository for the higress project?"
     }],
     "presence_penalty": 0,
     "temperature": 0.7,
@@ -40,24 +43,55 @@ After enabling, the plugin extracts the provider part (if any) from the `model`
 }
 ```
 
-After processing by the plugin, the following request header (which can be used for routing matching) will be added:
+After processing by this plugin, the following request header (which can be used for route matching) will be added:
+
+x-higress-llm-model: qwen-long
+
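+Both modes can be combined; the sketch below sets both headers at once (header names as documented above, values illustrative):
+
+```yaml
+modelToHeader: x-higress-llm-model
+addProviderHeader: x-higress-llm-provider
+```
+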
+### Extracting the provider Field from the model Parameter for Routing
+
+> Note: this mode requires the client to specify the provider in the model parameter, separated by `/`.
 
-`x-higress-llm-provider: openai`
+The following configuration is required:
+
+```yaml
+addProviderHeader: x-higress-llm-provider
+```
 
-The original LLM request body will be modified to:
+The plugin extracts the provider part (if present) from the model parameter, sets it in the x-higress-llm-provider request header for subsequent routing, and rewrites the model parameter to the model name part. For example, the original LLM request body:
 
 ```json
 {
-    "model": "gpt-4o",
+    "model": "dashscope/qwen-long",
     "frequency_penalty": 0,
     "max_tokens": 800,
     "stream": false,
     "messages": [{
         "role": "user",
-        "content": "What is the GitHub address for the main repository of the Higress project?"
+        "content": "What is the GitHub address of the main repository for the higress project?"
     }],
     "presence_penalty": 0,
     "temperature": 0.7,
     "top_p": 0.95
 }
 ```
+
+After processing by this plugin, the following request header (which can be used for route matching) will be added:
+
+x-higress-llm-provider: dashscope
+
+The original LLM request body will be changed to:
+
+```json
+{
+    "model": "qwen-long",
+    "frequency_penalty": 0,
+    "max_tokens": 800,
+    "stream": false,
+    "messages": [{
+        "role": "user",
+        "content": "What is the GitHub address of the main repository for the higress project?"
+    }],
+    "presence_penalty": 0,
+    "temperature": 0.7,
+    "top_p": 0.95
+}
+```
diff --git a/plugins/wasm-cpp/extensions/model_router/plugin.cc b/plugins/wasm-cpp/extensions/model_router/plugin.cc
index 66a90973ff..0e45ec473f 100644
--- a/plugins/wasm-cpp/extensions/model_router/plugin.cc
+++ b/plugins/wasm-cpp/extensions/model_router/plugin.cc
@@ -51,41 +51,54 @@ constexpr std::string_view DefaultMaxBodyBytes = "10485760";
 
 bool PluginRootContext::parsePluginConfig(const json& configuration,
                                           ModelRouterConfigRule& rule) {
-  if (auto it = configuration.find("enable"); it != configuration.end()) {
-    if (it->is_boolean()) {
-      rule.enable_ = it->get<bool>();
+  if (auto it = configuration.find("modelKey"); it != configuration.end()) {
+    if (it->is_string()) {
+      rule.model_key_ = it->get<std::string>();
     } else {
-      LOG_WARN("Invalid type for enable. Expected boolean.");
+      LOG_ERROR("Invalid type for modelKey. Expected string.");
       return false;
     }
   }
-  if (auto it = configuration.find("model_key"); it != configuration.end()) {
+  if (auto it = configuration.find("addProviderHeader");
+      it != configuration.end()) {
     if (it->is_string()) {
-      rule.model_key_ = it->get<std::string>();
+      rule.add_provider_header_ = it->get<std::string>();
     } else {
-      LOG_WARN("Invalid type for model_key. Expected string.");
+      LOG_ERROR("Invalid type for addProviderHeader. Expected string.");
      return false;
     }
   }
-  if (auto it = configuration.find("add_header_key");
+  if (auto it = configuration.find("modelToHeader");
       it != configuration.end()) {
     if (it->is_string()) {
-      rule.add_header_key_ = it->get<std::string>();
+      rule.model_to_header_ = it->get<std::string>();
     } else {
-      LOG_WARN("Invalid type for add_header_key. Expected string.");
+      LOG_ERROR("Invalid type for modelToHeader. Expected string.");
       return false;
     }
   }
+  if (!JsonArrayIterate(
+          configuration, "enableOnPathSuffix", [&](const json& item) -> bool {
+            if (item.is_string()) {
+              rule.enable_on_path_suffix_.emplace_back(
+                  item.get<std::string>());
+              return true;
+            }
+            return false;
+          })) {
+    LOG_ERROR("Invalid type for item in enableOnPathSuffix. Expected string.");
+    return false;
+  }
   return true;
 }
 
 bool PluginRootContext::onConfigure(size_t size) {
   // Parse configuration JSON string.
   if (size > 0 && !configure(size)) {
-    LOG_WARN("configuration has errors initialization will not continue.");
+    LOG_ERROR("configuration has errors, initialization will not continue.");
     return false;
   }
   return true;
@@ -97,13 +110,13 @@ bool PluginRootContext::configure(size_t configuration_size) {
   // Parse configuration JSON string.
   auto result = ::Wasm::Common::JsonParse(configuration_data->view());
   if (!result) {
-    LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
-                          configuration_data->view()));
+    LOG_ERROR(absl::StrCat("cannot parse plugin configuration JSON string: ",
+                           configuration_data->view()));
     return false;
   }
   if (!parseRuleConfig(result.value())) {
-    LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
-                          configuration_data->view()));
+    LOG_ERROR(absl::StrCat("cannot parse plugin configuration JSON string: ",
+                           configuration_data->view()));
     return false;
   }
   return true;
@@ -111,7 +124,25 @@ bool PluginRootContext::configure(size_t configuration_size) {
 
 FilterHeadersStatus PluginRootContext::onHeader(
     const ModelRouterConfigRule& rule) {
-  if (!rule.enable_ || !Wasm::Common::Http::hasRequestBody()) {
+  if (!Wasm::Common::Http::hasRequestBody()) {
+    return FilterHeadersStatus::Continue;
+  }
+  auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString();
+  auto params_pos = path.find('?');
+  size_t uri_end;
+  if (params_pos == std::string::npos) {
+    uri_end = path.size();
+  } else {
+    uri_end = params_pos;
+  }
+  bool enable = false;
+  for (const auto& enable_suffix : rule.enable_on_path_suffix_) {
+    if (absl::EndsWith({path.c_str(), uri_end}, enable_suffix)) {
+      enable = true;
+      break;
+    }
+  }
+  if (!enable) {
     return FilterHeadersStatus::Continue;
   }
   auto content_type_value =
@@ -128,7 +159,8 @@ FilterHeadersStatus PluginRootContext::onHeader(
 FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
                                            std::string_view body) {
   const auto& model_key = rule.model_key_;
-  const auto& add_header_key = rule.add_header_key_;
+  const auto& add_provider_header = rule.add_provider_header_;
+  const auto& model_to_header = rule.model_to_header_;
   auto body_json_opt = ::Wasm::Common::JsonParse(body);
   if (!body_json_opt) {
     LOG_WARN(absl::StrCat("cannot parse body to JSON string: ", body));
@@ -137,18 +169,24 @@ FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
   auto body_json = body_json_opt.value();
   if (body_json.contains(model_key)) {
     std::string model_value = body_json[model_key];
-    auto pos = model_value.find('/');
-    if (pos != std::string::npos) {
-      const auto& provider = model_value.substr(0, pos);
-      const auto& model = model_value.substr(pos + 1);
-      replaceRequestHeader(add_header_key, provider);
-      body_json[model_key] = model;
-      setBuffer(WasmBufferType::HttpRequestBody, 0,
-                std::numeric_limits<size_t>::max(), body_json.dump());
-      LOG_DEBUG(absl::StrCat("model route to provider:", provider,
-                             ", model:", model));
-    } else {
-      LOG_DEBUG(absl::StrCat("model route not work, model:", model_value));
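+    // modelToHeader mirrors the full model value into its header; the
+    // provider split below runs only when addProviderHeader is configured
+    // and the model value contains a '/' separator.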
+    if (!model_to_header.empty()) {
+      replaceRequestHeader(model_to_header, model_value);
+    }
+    if (!add_provider_header.empty()) {
+      auto pos = model_value.find('/');
+      if (pos != std::string::npos) {
+        const auto& provider = model_value.substr(0, pos);
+        const auto& model = model_value.substr(pos + 1);
+        replaceRequestHeader(add_provider_header, provider);
+        body_json[model_key] = model;
+        setBuffer(WasmBufferType::HttpRequestBody, 0,
+                  std::numeric_limits<size_t>::max(), body_json.dump());
+        LOG_DEBUG(absl::StrCat("model route to provider:", provider,
+                               ", model:", model));
+      } else {
+        LOG_DEBUG(absl::StrCat("model route to provider not work, model:",
+                               model_value));
+      }
     }
   }
   return FilterDataStatus::Continue;
diff --git a/plugins/wasm-cpp/extensions/model_router/plugin.h b/plugins/wasm-cpp/extensions/model_router/plugin.h
index 16cfdf8509..14ef0b632a 100644
--- a/plugins/wasm-cpp/extensions/model_router/plugin.h
+++ b/plugins/wasm-cpp/extensions/model_router/plugin.h
@@ -37,9 +37,10 @@ namespace model_router {
 #endif
 
 struct ModelRouterConfigRule {
-  bool enable_ = false;
   std::string model_key_ = "model";
-  std::string add_header_key_ = "x-higress-llm-provider";
+  std::string add_provider_header_;
+  std::string model_to_header_;
+  std::vector<std::string> enable_on_path_suffix_ = {"/v1/chat/completions"};
 };
 
 // PluginRootContext is the root context for all streams processed by the
diff --git a/plugins/wasm-cpp/extensions/model_router/plugin_test.cc b/plugins/wasm-cpp/extensions/model_router/plugin_test.cc
index dc351ecdc8..1204e16b31 100644
--- a/plugins/wasm-cpp/extensions/model_router/plugin_test.cc
+++ b/plugins/wasm-cpp/extensions/model_router/plugin_test.cc
@@ -89,6 +89,8 @@ class ModelRouterTest : public ::testing::Test {
             *result = "application/json";
           } else if (header == "content-length") {
             *result = "1024";
+          } else if (header == ":path") {
+            *result = path_;
           }
           return WasmResult::Ok;
         });
@@ -122,6 +124,7 @@ class ModelRouterTest : public ::testing::Test {
   std::unique_ptr<PluginRootContext> root_context_;
   std::unique_ptr<PluginContext> context_;
   std::string route_name_;
+  std::string path_;
   BufferBase body_;
   BufferBase config_;
 };
@@ -129,12 +132,13 @@ class ModelRouterTest : public ::testing::Test {
 TEST_F(ModelRouterTest, RewriteModelAndHeader) {
   std::string configuration = R"(
 {
-  "enable": true
+  "addProviderHeader": "x-higress-llm-provider"
 })";
 
   config_.set(configuration);
   EXPECT_TRUE(root_context_->configure(configuration.size()));
 
+  path_ = "/v1/chat/completions";
   std::string request_json = R"({"model": "qwen/qwen-long"})";
   EXPECT_CALL(*mock_context_,
               setBuffer(testing::_, testing::_, testing::_, testing::_))
@@ -154,19 +158,73 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
   EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
 }
 
+TEST_F(ModelRouterTest, ModelToHeader) {
+  std::string configuration = R"(
+{
+  "modelToHeader": "x-higress-llm-model"
+})";
+
+  config_.set(configuration);
+  EXPECT_TRUE(root_context_->configure(configuration.size()));
+
+  path_ = "/v1/chat/completions";
+  std::string request_json = R"({"model": "qwen-long"})";
+  EXPECT_CALL(*mock_context_,
+              setBuffer(testing::_, testing::_, testing::_, testing::_))
+      .Times(0);
+
+  EXPECT_CALL(
+      *mock_context_,
+      replaceHeaderMapValue(testing::_,
+                            std::string_view("x-higress-llm-model"),
+                            std::string_view("qwen-long")));
+
+  body_.set(request_json);
+  EXPECT_EQ(context_->onRequestHeaders(0, false),
+            FilterHeadersStatus::StopIteration);
+  EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
+}
+
+TEST_F(ModelRouterTest, IgnorePath) {
+  std::string configuration = R"(
+{
+  "addProviderHeader": "x-higress-llm-provider"
+})";
+
+  config_.set(configuration);
+  EXPECT_TRUE(root_context_->configure(configuration.size()));
+
+  path_ = "/v1/chat/xxxx";
+  std::string request_json = R"({"model": "qwen/qwen-long"})";
+  EXPECT_CALL(*mock_context_,
+              setBuffer(testing::_, testing::_, testing::_, testing::_))
+      .Times(0);
+
+  EXPECT_CALL(*mock_context_,
+              replaceHeaderMapValue(testing::_,
+                                    std::string_view("x-higress-llm-provider"),
+                                    std::string_view("qwen")))
+      .Times(0);
+
+  body_.set(request_json);
+  EXPECT_EQ(context_->onRequestHeaders(0, false),
+            FilterHeadersStatus::Continue);
+  EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
+}
+
 TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
   std::string configuration = R"(
 {
   "_rules_": [
     {
       "_match_route_": ["route-a"],
-      "enable": true
+      "addProviderHeader": "x-higress-llm-provider"
     }
 ]})";
 
   config_.set(configuration);
   EXPECT_TRUE(root_context_->configure(configuration.size()));
 
+  path_ = "/api/v1/chat/completions";
   std::string request_json = R"({"model": "qwen/qwen-long"})";
   EXPECT_CALL(*mock_context_,
               setBuffer(testing::_, testing::_, testing::_, testing::_))