diff --git a/plugins/wasm-cpp/extensions/model_mapper/BUILD b/plugins/wasm-cpp/extensions/model_mapper/BUILD
new file mode 100644
index 0000000000..da489367c7
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/BUILD
@@ -0,0 +1,70 @@
+# Copyright (c) 2022 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
+load("//bazel:wasm.bzl", "declare_wasm_image_targets")
+
+proxy_wasm_cc_binary(
+ name = "model_mapper.wasm",
+ srcs = [
+ "plugin.cc",
+ "plugin.h",
+ ],
+ deps = [
+ "@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics_higress",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "//common:json_util",
+ "//common:http_util",
+ "//common:rule_util",
+ ],
+)
+
+cc_library(
+ name = "model_mapper_lib",
+ srcs = [
+ "plugin.cc",
+ ],
+ hdrs = [
+ "plugin.h",
+ ],
+ copts = ["-DNULL_PLUGIN"],
+ deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "//common:json_util",
+ "@proxy_wasm_cpp_host//:lib",
+ "//common:http_util_nullvm",
+ "//common:rule_util_nullvm",
+ ],
+)
+
+cc_test(
+ name = "model_mapper_test",
+ srcs = [
+ "plugin_test.cc",
+ ],
+ copts = ["-DNULL_PLUGIN"],
+ deps = [
+ ":model_mapper_lib",
+ "@com_google_googletest//:gtest",
+ "@com_google_googletest//:gtest_main",
+ "@proxy_wasm_cpp_host//:lib",
+ ],
+)
+
+declare_wasm_image_targets(
+ name = "model_mapper",
+ wasm_file = ":model_mapper.wasm",
+)
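+
+# Typical local build and test commands, assuming this directory sits under the
+# plugins/wasm-cpp Bazel workspace root (target paths inferred from the deps
+# above):
+#   bazel build //extensions/model_mapper:model_mapper.wasm
+#   bazel test //extensions/model_mapper:model_mapper_test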
diff --git a/plugins/wasm-cpp/extensions/model_mapper/README.md b/plugins/wasm-cpp/extensions/model_mapper/README.md
new file mode 100644
index 0000000000..1e9fc747f7
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/README.md
@@ -0,0 +1,63 @@
+## 功能说明
+`model-mapper`插件实现了基于 LLM 协议中的 model 参数进行模型名称映射的功能
+
+## 配置字段
+
+| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
+| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
+| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
+| `modelMapping` | map of string | 选填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持的模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以 "gpt-3-" 开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
+| `enableOnPathSuffix` | array of string | 选填 | ["/v1/chat/completions"] | 只对这些特定路径后缀的请求生效 |
+
+## 运行属性
+
+插件执行阶段:认证阶段
+插件执行优先级:800
+
+## 效果说明
+
+如下配置:
+
+```yaml
+modelMapping:
+ 'gpt-4-*': "qwen-max"
+ 'gpt-4o': "qwen-vl-plus"
+ '*': "qwen-turbo"
+```
+
+开启后,`gpt-4-` 开头的模型参数会被改写为 `qwen-max`,`gpt-4o` 会被改写为 `qwen-vl-plus`,其他所有模型会被改写为 `qwen-turbo`。
+
+例如原本的请求是:
+
+```json
+{
+ "model": "gpt-4o",
+ "frequency_penalty": 0,
+ "max_tokens": 800,
+ "stream": false,
+ "messages": [{
+ "role": "user",
+ "content": "higress项目主仓库的github地址是什么"
+ }],
+ "presence_penalty": 0,
+ "temperature": 0.7,
+ "top_p": 0.95
+}
+```
+
+
+经过这个插件后,原始的 LLM 请求体将被改成:
+
+```json
+{
+ "model": "qwen-vl-plus",
+ "frequency_penalty": 0,
+ "max_tokens": 800,
+ "stream": false,
+ "messages": [{
+ "role": "user",
+ "content": "higress项目主仓库的github地址是什么"
+ }],
+ "presence_penalty": 0,
+ "temperature": 0.7,
+ "top_p": 0.95
+}
+```
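+
+如果希望部分模型保留原始名称,可以将映射目标配置为空字符串,例如(以下模型名仅作示意):
+
+```yaml
+modelMapping:
+  'text-embedding-*': ""
+  '*': "qwen-turbo"
+```
+
+此时以 `text-embedding-` 开头的模型名称会原样透传,其余模型仍会被改写为 `qwen-turbo`。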
diff --git a/plugins/wasm-cpp/extensions/model_mapper/README_EN.md b/plugins/wasm-cpp/extensions/model_mapper/README_EN.md
new file mode 100644
index 0000000000..f38cc84f68
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/README_EN.md
@@ -0,0 +1,65 @@
+## Function Description
+The `model-mapper` plugin implements model name mapping based on the model parameter in the LLM protocol.
+
+## Configuration Fields
+
+| Name | Data Type | Filling Requirement | Default Value | Description |
+| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
+| `modelKey` | string | Optional | model | The location of the model parameter in the request body. |
+| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map the model names in the request to the model names supported by the service provider.<br/>1. Supports prefix matching. For example, use "gpt-3-*" to match all models whose names start with "gpt-3-";<br/>2. Supports using "*" as the key to configure a generic fallback mapping relationship;<br/>3. If the target name in the mapping is an empty string "", it means to keep the original model name. |
+| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only applies to requests with these specific path suffixes. |
+
+## Runtime Properties
+
+Plugin execution phase: Authentication phase
+Plugin execution priority: 800
+
+## Effect Description
+
+With the following configuration:
+
+```yaml
+modelMapping:
+ 'gpt-4-*': "qwen-max"
+ 'gpt-4o': "qwen-vl-plus"
+ '*': "qwen-turbo"
+```
+
+After enabling, model parameters starting with `gpt-4-` will be rewritten to `qwen-max`, `gpt-4o` will be rewritten to `qwen-vl-plus`, and all other models will be rewritten to `qwen-turbo`.
+
+For example, if the original request was:
+
+```json
+{
+ "model": "gpt-4o",
+ "frequency_penalty": 0,
+ "max_tokens": 800,
+ "stream": false,
+ "messages": [{
+ "role": "user",
+ "content": "What is the GitHub address of the main repository for the higress project?"
+ }],
+ "presence_penalty": 0,
+ "temperature": 0.7,
+ "top_p": 0.95
+}
+```
+
+
+After processing by this plugin, the original LLM request body will be modified to:
+
+```json
+{
+ "model": "qwen-vl-plus",
+ "frequency_penalty": 0,
+ "max_tokens": 800,
+ "stream": false,
+ "messages": [{
+ "role": "user",
+ "content": "What is the GitHub address of the main repository for the higress project?"
+ }],
+ "presence_penalty": 0,
+ "temperature": 0.7,
+ "top_p": 0.95
+}
+```
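+
+If some models should keep their original names, map them to an empty string; for example (the model names below are illustrative only):
+
+```yaml
+modelMapping:
+  'text-embedding-*': ""
+  '*': "qwen-turbo"
+```
+
+With this configuration, model names starting with `text-embedding-` are passed through unchanged, while all other models are still rewritten to `qwen-turbo`.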
diff --git a/plugins/wasm-cpp/extensions/model_mapper/plugin.cc b/plugins/wasm-cpp/extensions/model_mapper/plugin.cc
new file mode 100644
index 0000000000..69c2e7c7c5
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/plugin.cc
@@ -0,0 +1,243 @@
+// Copyright (c) 2022 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "extensions/model_mapper/plugin.h"
+
+#include <limits>
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "common/http_util.h"
+#include "common/json_util.h"
+
+using ::nlohmann::json;
+using ::Wasm::Common::JsonArrayIterate;
+using ::Wasm::Common::JsonGetField;
+using ::Wasm::Common::JsonObjectIterate;
+using ::Wasm::Common::JsonValueAs;
+
+#ifdef NULL_PLUGIN
+
+namespace proxy_wasm {
+namespace null_plugin {
+namespace model_mapper {
+
+PROXY_WASM_NULL_PLUGIN_REGISTRY
+
+#endif
+
+static RegisterContextFactory register_ModelMapper(
+ CONTEXT_FACTORY(PluginContext), ROOT_FACTORY(PluginRootContext));
+
+namespace {
+
+constexpr std::string_view SetDecoderBufferLimitKey =
+ "SetRequestBodyBufferLimit";
+constexpr std::string_view DefaultMaxBodyBytes = "10485760";
+
+} // namespace
+
+bool PluginRootContext::parsePluginConfig(const json& configuration,
+ ModelMapperConfigRule& rule) {
+ if (auto it = configuration.find("modelKey"); it != configuration.end()) {
+ if (it->is_string()) {
+      rule.model_key_ = it->get<std::string>();
+ } else {
+ LOG_ERROR("Invalid type for modelKey. Expected string.");
+ return false;
+ }
+ }
+
+ if (auto it = configuration.find("modelMapping"); it != configuration.end()) {
+ if (!it->is_object()) {
+ LOG_ERROR("Invalid type for modelMapping. Expected object.");
+ return false;
+ }
+    auto model_mapping = it->get<json>();
+ if (!JsonObjectIterate(model_mapping, [&](std::string key) -> bool {
+ auto model_json = model_mapping.find(key);
+ if (!model_json->is_string()) {
+ LOG_ERROR(
+ "Invalid type for item in modelMapping. Expected string.");
+ return false;
+ }
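+          // A bare "*" key configures the generic fallback; any other key
+          // ending in "*" becomes a prefix rule, and all remaining keys are
+          // matched exactly.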
+ if (key == "*") {
+            rule.default_model_mapping_ = model_json->get<std::string>();
+ return true;
+ }
+ if (absl::EndsWith(key, "*")) {
+            rule.prefix_model_mapping_.emplace_back(
+                absl::StripSuffix(key, "*"), model_json->get<std::string>());
+ return true;
+ }
+          auto ret = rule.exact_model_mapping_.emplace(
+              key, model_json->get<std::string>());
+ if (!ret.second) {
+ LOG_ERROR("Duplicate key in modelMapping: " + key);
+ return false;
+ }
+ return true;
+ })) {
+ return false;
+ }
+ }
+
+ if (!JsonArrayIterate(
+ configuration, "enableOnPathSuffix", [&](const json& item) -> bool {
+ if (item.is_string()) {
+              rule.enable_on_path_suffix_.emplace_back(item.get<std::string>());
+ return true;
+ }
+ return false;
+ })) {
+ LOG_WARN("Invalid type for item in enableOnPathSuffix. Expected string.");
+ return false;
+ }
+ return true;
+}
+
+bool PluginRootContext::onConfigure(size_t size) {
+ // Parse configuration JSON string.
+ if (size > 0 && !configure(size)) {
+    LOG_WARN("configuration has errors, initialization will not continue.");
+ return false;
+ }
+ return true;
+}
+
+bool PluginRootContext::configure(size_t configuration_size) {
+ auto configuration_data = getBufferBytes(WasmBufferType::PluginConfiguration,
+ 0, configuration_size);
+ // Parse configuration JSON string.
+ auto result = ::Wasm::Common::JsonParse(configuration_data->view());
+ if (!result) {
+ LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
+ configuration_data->view()));
+ return false;
+ }
+ if (!parseRuleConfig(result.value())) {
+ LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
+ configuration_data->view()));
+ return false;
+ }
+ return true;
+}
+
+FilterHeadersStatus PluginRootContext::onHeader(
+ const ModelMapperConfigRule& rule) {
+ if (!Wasm::Common::Http::hasRequestBody()) {
+ return FilterHeadersStatus::Continue;
+ }
+ auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString();
+ auto params_pos = path.find('?');
+ size_t uri_end;
+ if (params_pos == std::string::npos) {
+ uri_end = path.size();
+ } else {
+ uri_end = params_pos;
+ }
+ bool enable = false;
+ for (const auto& enable_suffix : rule.enable_on_path_suffix_) {
+ if (absl::EndsWith({path.c_str(), uri_end}, enable_suffix)) {
+ enable = true;
+ break;
+ }
+ }
+ if (!enable) {
+ return FilterHeadersStatus::Continue;
+ }
+ auto content_type_value =
+ getRequestHeader(Wasm::Common::Http::Header::ContentType);
+ if (!absl::StrContains(content_type_value->view(),
+ Wasm::Common::Http::ContentTypeValues::Json)) {
+ return FilterHeadersStatus::Continue;
+ }
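+  // The body may be rewritten below, so the original Content-Length can no
+  // longer be trusted; drop it and raise the request body buffer limit so the
+  // complete body can be buffered for rewriting.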
+ removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
+ setFilterState(SetDecoderBufferLimitKey, DefaultMaxBodyBytes);
+ return FilterHeadersStatus::StopIteration;
+}
+
+FilterDataStatus PluginRootContext::onBody(const ModelMapperConfigRule& rule,
+ std::string_view body) {
+ const auto& exact_model_mapping = rule.exact_model_mapping_;
+ const auto& prefix_model_mapping = rule.prefix_model_mapping_;
+ const auto& default_model_mapping = rule.default_model_mapping_;
+ const auto& model_key = rule.model_key_;
+ auto body_json_opt = ::Wasm::Common::JsonParse(body);
+ if (!body_json_opt) {
+ LOG_WARN(absl::StrCat("cannot parse body to JSON string: ", body));
+ return FilterDataStatus::Continue;
+ }
+ auto body_json = body_json_opt.value();
+ std::string old_model;
+ if (body_json.contains(model_key)) {
+ old_model = body_json[model_key];
+ }
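+  // Resolution order: an exact match wins over prefix rules, and the "*"
+  // fallback applies only when nothing else matches. An empty target keeps
+  // the original model name, because the rewrite below is skipped for empty
+  // values.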
+ std::string model =
+ default_model_mapping.empty() ? old_model : default_model_mapping;
+ if (auto it = exact_model_mapping.find(old_model);
+ it != exact_model_mapping.end()) {
+ model = it->second;
+ } else {
+ for (auto& prefix_model_pair : prefix_model_mapping) {
+ if (absl::StartsWith(old_model, prefix_model_pair.first)) {
+ model = prefix_model_pair.second;
+ break;
+ }
+ }
+ }
+ if (!model.empty() && model != old_model) {
+ body_json[model_key] = model;
+ setBuffer(WasmBufferType::HttpRequestBody, 0,
+              std::numeric_limits<size_t>::max(), body_json.dump());
+ LOG_DEBUG(
+ absl::StrCat("model mapped, before:", old_model, ", after:", model));
+ }
+ return FilterDataStatus::Continue;
+}
+
+FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) {
+ auto* rootCtx = rootContext();
+ return rootCtx->onHeaders([rootCtx, this](const auto& config) {
+ auto ret = rootCtx->onHeader(config);
+ if (ret == FilterHeadersStatus::StopIteration) {
+ this->config_ = &config;
+ }
+ return ret;
+ });
+}
+
+FilterDataStatus PluginContext::onRequestBody(size_t body_size,
+ bool end_stream) {
+ if (config_ == nullptr) {
+ return FilterDataStatus::Continue;
+ }
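+  // Accumulate chunks until the final one arrives, then hand the complete
+  // buffered body to the root context for rewriting.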
+ body_total_size_ += body_size;
+ if (!end_stream) {
+ return FilterDataStatus::StopIterationAndBuffer;
+ }
+ auto body =
+ getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
+ auto* rootCtx = rootContext();
+ return rootCtx->onBody(*config_, body->view());
+}
+
+#ifdef NULL_PLUGIN
+
+} // namespace model_mapper
+} // namespace null_plugin
+} // namespace proxy_wasm
+
+#endif
diff --git a/plugins/wasm-cpp/extensions/model_mapper/plugin.h b/plugins/wasm-cpp/extensions/model_mapper/plugin.h
new file mode 100644
index 0000000000..df0c7eb0e1
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/plugin.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2022 Alibaba Group Holding Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <assert.h>
+
+#include <map>
+#include <vector>
+
+#include "common/route_rule_matcher.h"
+#define ASSERT(_X) assert(_X)
+
+#ifndef NULL_PLUGIN
+
+#include "proxy_wasm_intrinsics.h"
+
+#else
+
+#include "include/proxy-wasm/null_plugin.h"
+
+namespace proxy_wasm {
+namespace null_plugin {
+namespace model_mapper {
+
+#endif
+
+struct ModelMapperConfigRule {
+ std::string model_key_ = "model";
+  std::map<std::string, std::string> exact_model_mapping_;
+  std::vector<std::pair<std::string, std::string>> prefix_model_mapping_;
+  std::string default_model_mapping_;
+  std::vector<std::string> enable_on_path_suffix_ = {"/v1/chat/completions"};
+};
+
+// PluginRootContext is the root context for all streams processed by the
+// thread. It has the same lifetime as the worker thread and acts as a target
+// for interactions that outlive individual streams, e.g. timers and async
+// calls.
+class PluginRootContext : public RootContext,
+                          public RouteRuleMatcher<ModelMapperConfigRule> {
+ public:
+ PluginRootContext(uint32_t id, std::string_view root_id)
+ : RootContext(id, root_id) {}
+ ~PluginRootContext() {}
+ bool onConfigure(size_t) override;
+ FilterHeadersStatus onHeader(const ModelMapperConfigRule&);
+ FilterDataStatus onBody(const ModelMapperConfigRule&, std::string_view);
+ bool configure(size_t);
+
+ private:
+ bool parsePluginConfig(const json&, ModelMapperConfigRule&) override;
+};
+
+// Per-stream context.
+class PluginContext : public Context {
+ public:
+ explicit PluginContext(uint32_t id, RootContext* root) : Context(id, root) {}
+ FilterHeadersStatus onRequestHeaders(uint32_t, bool) override;
+ FilterDataStatus onRequestBody(size_t, bool) override;
+
+ private:
+ inline PluginRootContext* rootContext() {
+    return dynamic_cast<PluginRootContext*>(this->root());
+ }
+
+ size_t body_total_size_ = 0;
+ const ModelMapperConfigRule* config_ = nullptr;
+};
+
+#ifdef NULL_PLUGIN
+
+} // namespace model_mapper
+} // namespace null_plugin
+} // namespace proxy_wasm
+
+#endif
diff --git a/plugins/wasm-cpp/extensions/model_mapper/plugin_test.cc b/plugins/wasm-cpp/extensions/model_mapper/plugin_test.cc
new file mode 100644
index 0000000000..7f9c706ff9
--- /dev/null
+++ b/plugins/wasm-cpp/extensions/model_mapper/plugin_test.cc
@@ -0,0 +1,301 @@
+// Copyright (c) 2022 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "extensions/model_mapper/plugin.h"
+
+#include <iostream>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "include/proxy-wasm/context.h"
+#include "include/proxy-wasm/null.h"
+
+namespace proxy_wasm {
+namespace null_plugin {
+namespace model_mapper {
+
+NullPluginRegistry* context_registry_;
+RegisterNullVmPluginFactory register_model_mapper_plugin("model_mapper", []() {
+  return std::make_unique<NullPlugin>(model_mapper::context_registry_);
+});
+
+class MockContext : public proxy_wasm::ContextBase {
+ public:
+ MockContext(WasmBase* wasm) : ContextBase(wasm) {}
+ MOCK_METHOD(BufferInterface*, getBuffer, (WasmBufferType));
+ MOCK_METHOD(WasmResult, log, (uint32_t, std::string_view));
+ MOCK_METHOD(WasmResult, setBuffer,
+ (WasmBufferType, size_t, size_t, std::string_view));
+ MOCK_METHOD(WasmResult, getHeaderMapValue,
+ (WasmHeaderMapType /* type */, std::string_view /* key */,
+ std::string_view* /*result */));
+ MOCK_METHOD(WasmResult, replaceHeaderMapValue,
+ (WasmHeaderMapType /* type */, std::string_view /* key */,
+ std::string_view /* value */));
+ MOCK_METHOD(WasmResult, removeHeaderMapValue,
+ (WasmHeaderMapType /* type */, std::string_view /* key */));
+ MOCK_METHOD(WasmResult, addHeaderMapValue,
+ (WasmHeaderMapType, std::string_view, std::string_view));
+ MOCK_METHOD(WasmResult, getProperty, (std::string_view, std::string*));
+ MOCK_METHOD(WasmResult, setProperty, (std::string_view, std::string_view));
+};
+class ModelMapperTest : public ::testing::Test {
+ protected:
+ ModelMapperTest() {
+ // Initialize test VM
+ test_vm_ = createNullVm();
+    wasm_base_ = std::make_unique<WasmBase>(
+ std::move(test_vm_), "test-vm", "", "",
+        std::unordered_map<std::string, std::string>{},
+ AllowedCapabilitiesMap{});
+ wasm_base_->load("model_mapper");
+ wasm_base_->initialize();
+ // Initialize host side context
+    mock_context_ = std::make_unique<MockContext>(wasm_base_.get());
+ current_context_ = mock_context_.get();
+ // Initialize Wasm sandbox context
+    root_context_ = std::make_unique<PluginRootContext>(0, "");
+    context_ = std::make_unique<PluginContext>(1, root_context_.get());
+
+ ON_CALL(*mock_context_, log(testing::_, testing::_))
+ .WillByDefault([](uint32_t, std::string_view m) {
+ std::cerr << m << "\n";
+ return WasmResult::Ok;
+ });
+
+ ON_CALL(*mock_context_, getBuffer(testing::_))
+ .WillByDefault([&](WasmBufferType type) {
+ if (type == WasmBufferType::HttpRequestBody) {
+ return &body_;
+ }
+ return &config_;
+ });
+ ON_CALL(*mock_context_, getHeaderMapValue(WasmHeaderMapType::RequestHeaders,
+ testing::_, testing::_))
+ .WillByDefault([&](WasmHeaderMapType, std::string_view header,
+ std::string_view* result) {
+ if (header == "content-type") {
+ *result = "application/json";
+ } else if (header == "content-length") {
+ *result = "1024";
+ } else if (header == ":path") {
+ *result = path_;
+ }
+ return WasmResult::Ok;
+ });
+ ON_CALL(*mock_context_,
+ replaceHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_,
+ testing::_))
+ .WillByDefault([&](WasmHeaderMapType, std::string_view key,
+ std::string_view value) { return WasmResult::Ok; });
+ ON_CALL(*mock_context_,
+ removeHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_))
+ .WillByDefault([&](WasmHeaderMapType, std::string_view key) {
+ return WasmResult::Ok;
+ });
+ ON_CALL(*mock_context_, addHeaderMapValue(WasmHeaderMapType::RequestHeaders,
+ testing::_, testing::_))
+ .WillByDefault([&](WasmHeaderMapType, std::string_view header,
+ std::string_view value) { return WasmResult::Ok; });
+ ON_CALL(*mock_context_, getProperty(testing::_, testing::_))
+ .WillByDefault([&](std::string_view path, std::string* result) {
+ if (absl::StartsWith(path, "route_name")) {
+ *result = route_name_;
+ } else if (absl::StartsWith(path, "cluster_name")) {
+ *result = service_name_;
+ }
+ return WasmResult::Ok;
+ });
+ ON_CALL(*mock_context_, setProperty(testing::_, testing::_))
+ .WillByDefault(
+ [&](std::string_view, std::string_view) { return WasmResult::Ok; });
+ }
+ ~ModelMapperTest() override {}
+  std::unique_ptr<WasmBase> wasm_base_;
+  std::unique_ptr<WasmVm> test_vm_;
+  std::unique_ptr<MockContext> mock_context_;
+  std::unique_ptr<PluginRootContext> root_context_;
+  std::unique_ptr<PluginContext> context_;
+ std::string route_name_;
+ std::string service_name_;
+ std::string path_;
+ BufferBase body_;
+ BufferBase config_;
+};
+
+TEST_F(ModelMapperTest, ModelMappingTest) {
+ std::string configuration = R"(
+{
+ "modelMapping": {
+ "*": "qwen-long",
+ "gpt-4*": "qwen-max",
+ "gpt-4o": "qwen-turbo",
+ "gpt-4o-mini": "qwen-plus",
+ "text-embedding-v1": ""
+ }
+})";
+
+ config_.set(configuration);
+ EXPECT_TRUE(root_context_->configure(configuration.size()));
+
+ path_ = "/v1/chat/completions";
+ std::string request_json = R"({"model": "gpt-3.5"})";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+ EXPECT_EQ(body, R"({"model":"qwen-long"})");
+ return WasmResult::Ok;
+ });
+
+ body_.set(request_json);
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::StopIteration);
+ EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+ request_json = R"({"model": "gpt-4"})";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+ EXPECT_EQ(body, R"({"model":"qwen-max"})");
+ return WasmResult::Ok;
+ });
+
+ body_.set(request_json);
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::StopIteration);
+ EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+ request_json = R"({"model": "gpt-4o"})";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+ EXPECT_EQ(body, R"({"model":"qwen-turbo"})");
+ return WasmResult::Ok;
+ });
+
+ body_.set(request_json);
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::StopIteration);
+ EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+ request_json = R"({"model": "gpt-4o-mini"})";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+ EXPECT_EQ(body, R"({"model":"qwen-plus"})");
+ return WasmResult::Ok;
+ });
+
+ body_.set(request_json);
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::StopIteration);
+ EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+ request_json = R"({"model": "text-embedding-v1"})";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .Times(0);
+
+ body_.set(request_json);
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::StopIteration);
+ EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+ request_json = R"({})";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+ EXPECT_EQ(body, R"({"model":"qwen-long"})");
+ return WasmResult::Ok;
+ });
+
+ body_.set(request_json);
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::StopIteration);
+ EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+}
+
+TEST_F(ModelMapperTest, RouteLevelModelMappingTest) {
+ std::string configuration = R"(
+{
+ "_rules_": [
+ {
+ "_match_route_": ["route-a"],
+ "_match_service_": ["service-1"],
+ "modelMapping": {
+ "*": "qwen-long"
+ }
+ },
+ {
+ "_match_route_": ["route-b"],
+ "_match_service_": ["service-2"],
+ "modelMapping": {
+ "*": "qwen-max"
+ }
+ },
+ {
+ "_match_route_": ["route-b"],
+ "_match_service_": ["service-3"],
+ "modelMapping": {
+ "*": "qwen-turbo"
+ }
+ }
+]})";
+
+ config_.set(configuration);
+ EXPECT_TRUE(root_context_->configure(configuration.size()));
+
+ path_ = "/api/v1/chat/completions";
+ std::string request_json = R"({"model": "gpt-4"})";
+ body_.set(request_json);
+ route_name_ = "route-a";
+ service_name_ = "outbound|80||service-1";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+ EXPECT_EQ(body, R"({"model":"qwen-long"})");
+ return WasmResult::Ok;
+ });
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::StopIteration);
+ EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+ route_name_ = "route-b";
+ service_name_ = "outbound|80||service-2";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+ EXPECT_EQ(body, R"({"model":"qwen-max"})");
+ return WasmResult::Ok;
+ });
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::StopIteration);
+ EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+
+ route_name_ = "route-b";
+ service_name_ = "outbound|80||service-3";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
+ EXPECT_EQ(body, R"({"model":"qwen-turbo"})");
+ return WasmResult::Ok;
+ });
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::StopIteration);
+ EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
+}
+
+} // namespace model_mapper
+} // namespace null_plugin
+} // namespace proxy_wasm
diff --git a/plugins/wasm-cpp/extensions/model_router/README.md b/plugins/wasm-cpp/extensions/model_router/README.md
index b63be35d8f..e78988e0e0 100644
--- a/plugins/wasm-cpp/extensions/model_router/README.md
+++ b/plugins/wasm-cpp/extensions/model_router/README.md
@@ -1,33 +1,67 @@
## 功能说明
`model-router`插件实现了基于LLM协议中的model参数路由的功能
+## 配置字段
+
+| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
+| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
+| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
+| `addProviderHeader` | string | 选填 | - | 从model参数中解析出的provider名字放到哪个请求header中 |
+| `modelToHeader` | string | 选填 | - | 直接将model参数放到哪个请求header中 |
+| `enableOnPathSuffix` | array of string | 选填 | ["/v1/chat/completions"] | 只对这些特定路径后缀的请求生效 |
+
## 运行属性
-插件执行阶段:`默认阶段`
-插件执行优先级:`260`
+插件执行阶段:认证阶段
+插件执行优先级:900
-## 配置字段
+## 效果说明
-| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
-| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
-| `enable` | bool | 选填 | false | 是否开启基于model参数路由 |
-| `model_key` | string | 选填 | model | 请求body中model参数的位置 |
-| `add_header_key` | string | 选填 | x-higress-llm-provider | 从model参数中解析出的provider名字放到哪个请求header中 |
+### 基于 model 参数进行路由
+需要做如下配置:
-## 效果说明
+```yaml
+modelToHeader: x-higress-llm-model
+```
+
+插件会将请求中 model 参数提取出来,设置到 x-higress-llm-model 这个请求 header 中,用于后续路由。举例来说,原生的 LLM 请求体是:
+
+```json
+{
+ "model": "qwen-long",
+ "frequency_penalty": 0,
+ "max_tokens": 800,
+ "stream": false,
+ "messages": [{
+ "role": "user",
+ "content": "higress项目主仓库的github地址是什么"
+ }],
+ "presence_penalty": 0,
+ "temperature": 0.7,
+ "top_p": 0.95
+}
+```
+
+经过这个插件后,将添加下面这个请求头(可以用于路由匹配):
+
+x-higress-llm-model: qwen-long
+
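+`modelToHeader` 与 `addProviderHeader` 可以同时配置:插件会先把完整的 model 参数写入 `modelToHeader` 指定的 header,再按 `/` 解析其中的 provider 部分,示例配置如下:
+
+```yaml
+modelToHeader: x-higress-llm-model
+addProviderHeader: x-higress-llm-provider
+```
+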
+### 提取 model 参数中的 provider 字段用于路由
+
+> 注意这种模式需要客户端在 model 参数中通过`/`分隔的方式,来指定 provider
-如下开启基于model参数路由的功能:
+需要做如下配置:
```yaml
-enable: true
+addProviderHeader: x-higress-llm-provider
```
-开启后,插件将请求中 model 参数的 provider 部分(如果有)提取出来,设置到 x-higress-llm-provider 这个请求 header 中,用于后续路由,并将 model 参数重写为模型名称部分。举例来说,原生的 LLM 请求体是:
+插件会将请求中 model 参数的 provider 部分(如果有)提取出来,设置到 x-higress-llm-provider 这个请求 header 中,用于后续路由,并将 model 参数重写为模型名称部分。举例来说,原生的 LLM 请求体是:
```json
{
- "model": "qwen/qwen-long",
+ "model": "dashscope/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
@@ -43,7 +77,7 @@ enable: true
经过这个插件后,将添加下面这个请求头(可以用于路由匹配):
-x-higress-llm-provider: qwen
+x-higress-llm-provider: dashscope
原始的 LLM 请求体将被改成:
diff --git a/plugins/wasm-cpp/extensions/model_router/README_EN.md b/plugins/wasm-cpp/extensions/model_router/README_EN.md
index 4d2eaf1fee..f5866423b3 100644
--- a/plugins/wasm-cpp/extensions/model_router/README_EN.md
+++ b/plugins/wasm-cpp/extensions/model_router/README_EN.md
@@ -1,38 +1,41 @@
## Function Description
-The `model-router` plugin implements the functionality of routing based on the `model` parameter in the LLM protocol.
+The `model-router` plugin implements routing based on the `model` parameter in the LLM protocol.
-## Runtime Properties
+## Configuration Fields
-Plugin Execution Phase: `Default Phase`
-Plugin Execution Priority: `260`
+| Name | Data Type | Filling Requirement | Default Value | Description |
+| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
+| `modelKey` | string | Optional | model | The location of the model parameter in the request body |
+| `addProviderHeader` | string | Optional | - | The request header that receives the provider name parsed from the model parameter |
+| `modelToHeader` | string | Optional | - | The request header that directly receives the model parameter |
+| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only effective for requests with these specific path suffixes |
-## Configuration Fields
+## Runtime Properties
-| Name | Data Type | Filling Requirement | Default Value | Description |
-| -------------------- | ------------- | --------------------- | ---------------------- | ----------------------------------------------------- |
-| `enable` | bool | Optional | false | Whether to enable routing based on the `model` parameter |
-| `model_key` | string | Optional | model | The location of the `model` parameter in the request body |
-| `add_header_key` | string | Optional | x-higress-llm-provider | The header where the parsed provider name from the `model` parameter will be placed |
+Plugin execution phase: Authentication phase
+Plugin execution priority: 900
## Effect Description
-To enable routing based on the `model` parameter, use the following configuration:
+### Routing Based on the model Parameter
+
+The following configuration is required:
```yaml
-enable: true
+modelToHeader: x-higress-llm-model
```
-After enabling, the plugin extracts the provider part (if any) from the `model` parameter in the request, and sets it in the `x-higress-llm-provider` request header for subsequent routing. It also rewrites the `model` parameter to the model name part. For example, the original LLM request body is:
+The plugin will extract the model parameter from the request and set it in the x-higress-llm-model request header, which can be used for subsequent routing. For example, the original LLM request body is:
```json
{
- "model": "openai/gpt-4o",
+ "model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
- "content": "What is the GitHub address for the main repository of the Higress project?"
+    "content": "What is the GitHub address of the main repository for the higress project?"
}],
"presence_penalty": 0,
"temperature": 0.7,
@@ -40,24 +43,55 @@ After enabling, the plugin extracts the provider part (if any) from the `model`
}
```
-After processing by the plugin, the following request header (which can be used for routing matching) will be added:
+After processing by this plugin, the following request header (which can be used for route matching) will be added:
+
+x-higress-llm-model: qwen-long
-`x-higress-llm-provider: openai`
+### Extracting the provider Field from the model Parameter for Routing
-The original LLM request body will be modified to:
+> Note that this mode requires the client to specify the provider using a `/` separator in the model parameter.
+
+The following configuration is required:
+
+```yaml
+addProviderHeader: x-higress-llm-provider
+```
+
+The plugin will extract the provider part (if present) from the model parameter in the request and set it in the x-higress-llm-provider request header for subsequent routing, and rewrite the model parameter to just the model name. For example, the original LLM request body is:
```json
{
- "model": "gpt-4o",
+ "model": "dashscope/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
- "content": "What is the GitHub address for the main repository of the Higress project?"
+    "content": "What is the GitHub address of the main repository for the higress project?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
+
+After processing by this plugin, the following request header (which can be used for route matching) will be added:
+
+x-higress-llm-provider: dashscope
+
+The original LLM request body will be changed to:
+
+```json
+{
+ "model": "qwen-long",
+ "frequency_penalty": 0,
+ "max_tokens": 800,
+ "stream": false,
+ "messages": [{
+ "role": "user",
+    "content": "What is the GitHub address of the main repository for the higress project?"
+ }],
+ "presence_penalty": 0,
+ "temperature": 0.7,
+ "top_p": 0.95
+}
+```
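+
+`modelToHeader` and `addProviderHeader` can also be combined. In that case the plugin first writes the complete model parameter into the header named by `modelToHeader`, and then splits off the provider part as described above; a sketch of such a configuration:
+
+```yaml
+modelToHeader: x-higress-llm-model
+addProviderHeader: x-higress-llm-provider
+```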
diff --git a/plugins/wasm-cpp/extensions/model_router/plugin.cc b/plugins/wasm-cpp/extensions/model_router/plugin.cc
index 66a90973ff..0e45ec473f 100644
--- a/plugins/wasm-cpp/extensions/model_router/plugin.cc
+++ b/plugins/wasm-cpp/extensions/model_router/plugin.cc
@@ -51,41 +51,54 @@ constexpr std::string_view DefaultMaxBodyBytes = "10485760";
bool PluginRootContext::parsePluginConfig(const json& configuration,
ModelRouterConfigRule& rule) {
- if (auto it = configuration.find("enable"); it != configuration.end()) {
- if (it->is_boolean()) {
-      rule.enable_ = it->get<bool>();
+ if (auto it = configuration.find("modelKey"); it != configuration.end()) {
+ if (it->is_string()) {
+      rule.model_key_ = it->get<std::string>();
} else {
- LOG_WARN("Invalid type for enable. Expected boolean.");
+ LOG_ERROR("Invalid type for modelKey. Expected string.");
return false;
}
}
- if (auto it = configuration.find("model_key"); it != configuration.end()) {
+ if (auto it = configuration.find("addProviderHeader");
+ it != configuration.end()) {
if (it->is_string()) {
-      rule.model_key_ = it->get<std::string>();
+      rule.add_provider_header_ = it->get<std::string>();
} else {
- LOG_WARN("Invalid type for model_key. Expected string.");
+ LOG_ERROR("Invalid type for addProviderHeader. Expected string.");
return false;
}
}
- if (auto it = configuration.find("add_header_key");
+ if (auto it = configuration.find("modelToHeader");
it != configuration.end()) {
if (it->is_string()) {
-      rule.add_header_key_ = it->get<std::string>();
+      rule.model_to_header_ = it->get<std::string>();
} else {
- LOG_WARN("Invalid type for add_header_key. Expected string.");
+ LOG_ERROR("Invalid type for modelToHeader. Expected string.");
return false;
}
}
+ if (!JsonArrayIterate(
+ configuration, "enableOnPathSuffix", [&](const json& item) -> bool {
+ if (item.is_string()) {
+              rule.enable_on_path_suffix_.emplace_back(item.get<std::string>());
+ return true;
+ }
+ return false;
+ })) {
+ LOG_ERROR("Invalid type for item in enableOnPathSuffix. Expected string.");
+ return false;
+ }
+
return true;
}
bool PluginRootContext::onConfigure(size_t size) {
// Parse configuration JSON string.
if (size > 0 && !configure(size)) {
- LOG_WARN("configuration has errors initialization will not continue.");
+    LOG_ERROR("configuration has errors, initialization will not continue.");
return false;
}
return true;
@@ -97,13 +110,13 @@ bool PluginRootContext::configure(size_t configuration_size) {
// Parse configuration JSON string.
auto result = ::Wasm::Common::JsonParse(configuration_data->view());
if (!result) {
- LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
- configuration_data->view()));
+ LOG_ERROR(absl::StrCat("cannot parse plugin configuration JSON string: ",
+ configuration_data->view()));
return false;
}
if (!parseRuleConfig(result.value())) {
- LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
- configuration_data->view()));
+ LOG_ERROR(absl::StrCat("cannot parse plugin configuration JSON string: ",
+ configuration_data->view()));
return false;
}
return true;
@@ -111,7 +124,25 @@ bool PluginRootContext::configure(size_t configuration_size) {
FilterHeadersStatus PluginRootContext::onHeader(
const ModelRouterConfigRule& rule) {
- if (!rule.enable_ || !Wasm::Common::Http::hasRequestBody()) {
+ if (!Wasm::Common::Http::hasRequestBody()) {
+ return FilterHeadersStatus::Continue;
+ }
+ auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString();
+ auto params_pos = path.find('?');
+ size_t uri_end;
+ if (params_pos == std::string::npos) {
+ uri_end = path.size();
+ } else {
+ uri_end = params_pos;
+ }
+ bool enable = false;
+ for (const auto& enable_suffix : rule.enable_on_path_suffix_) {
+ if (absl::EndsWith({path.c_str(), uri_end}, enable_suffix)) {
+ enable = true;
+ break;
+ }
+ }
+ if (!enable) {
return FilterHeadersStatus::Continue;
}
auto content_type_value =
@@ -128,7 +159,8 @@ FilterHeadersStatus PluginRootContext::onHeader(
FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
std::string_view body) {
const auto& model_key = rule.model_key_;
- const auto& add_header_key = rule.add_header_key_;
+ const auto& add_provider_header = rule.add_provider_header_;
+ const auto& model_to_header = rule.model_to_header_;
auto body_json_opt = ::Wasm::Common::JsonParse(body);
if (!body_json_opt) {
LOG_WARN(absl::StrCat("cannot parse body to JSON string: ", body));
@@ -137,18 +169,24 @@ FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
auto body_json = body_json_opt.value();
if (body_json.contains(model_key)) {
std::string model_value = body_json[model_key];
- auto pos = model_value.find('/');
- if (pos != std::string::npos) {
- const auto& provider = model_value.substr(0, pos);
- const auto& model = model_value.substr(pos + 1);
- replaceRequestHeader(add_header_key, provider);
- body_json[model_key] = model;
- setBuffer(WasmBufferType::HttpRequestBody, 0,
-                std::numeric_limits<size_t>::max(), body_json.dump());
- LOG_DEBUG(absl::StrCat("model route to provider:", provider,
- ", model:", model));
- } else {
- LOG_DEBUG(absl::StrCat("model route not work, model:", model_value));
+ if (!model_to_header.empty()) {
+ replaceRequestHeader(model_to_header, model_value);
+ }
+ if (!add_provider_header.empty()) {
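+      // The client encodes the provider as a "provider/model" prefix: split
+      // it off into the configured header and keep only the model name in
+      // the request body.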
+ auto pos = model_value.find('/');
+ if (pos != std::string::npos) {
+ const auto& provider = model_value.substr(0, pos);
+ const auto& model = model_value.substr(pos + 1);
+ replaceRequestHeader(add_provider_header, provider);
+ body_json[model_key] = model;
+ setBuffer(WasmBufferType::HttpRequestBody, 0,
+                  std::numeric_limits<size_t>::max(), body_json.dump());
+ LOG_DEBUG(absl::StrCat("model route to provider:", provider,
+ ", model:", model));
+ } else {
+ LOG_DEBUG(absl::StrCat("model route to provider not work, model:",
+ model_value));
+ }
}
}
return FilterDataStatus::Continue;
diff --git a/plugins/wasm-cpp/extensions/model_router/plugin.h b/plugins/wasm-cpp/extensions/model_router/plugin.h
index 16cfdf8509..14ef0b632a 100644
--- a/plugins/wasm-cpp/extensions/model_router/plugin.h
+++ b/plugins/wasm-cpp/extensions/model_router/plugin.h
@@ -37,9 +37,10 @@ namespace model_router {
#endif
struct ModelRouterConfigRule {
- bool enable_ = false;
std::string model_key_ = "model";
- std::string add_header_key_ = "x-higress-llm-provider";
+ std::string add_provider_header_;
+ std::string model_to_header_;
+  std::vector<std::string> enable_on_path_suffix_ = {"/v1/chat/completions"};
};
// PluginRootContext is the root context for all streams processed by the
diff --git a/plugins/wasm-cpp/extensions/model_router/plugin_test.cc b/plugins/wasm-cpp/extensions/model_router/plugin_test.cc
index dc351ecdc8..1204e16b31 100644
--- a/plugins/wasm-cpp/extensions/model_router/plugin_test.cc
+++ b/plugins/wasm-cpp/extensions/model_router/plugin_test.cc
@@ -89,6 +89,8 @@ class ModelRouterTest : public ::testing::Test {
*result = "application/json";
} else if (header == "content-length") {
*result = "1024";
+ } else if (header == ":path") {
+ *result = path_;
}
return WasmResult::Ok;
});
@@ -122,6 +124,7 @@ class ModelRouterTest : public ::testing::Test {
  std::unique_ptr<PluginRootContext> root_context_;
  std::unique_ptr<PluginContext> context_;
std::string route_name_;
+ std::string path_;
BufferBase body_;
BufferBase config_;
};
@@ -129,12 +132,13 @@ class ModelRouterTest : public ::testing::Test {
TEST_F(ModelRouterTest, RewriteModelAndHeader) {
std::string configuration = R"(
{
- "enable": true
+ "addProviderHeader": "x-higress-llm-provider"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
+ path_ = "/v1/chat/completions";
std::string request_json = R"({"model": "qwen/qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
@@ -154,19 +158,73 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
}
+TEST_F(ModelRouterTest, ModelToHeader) {
+ std::string configuration = R"(
+{
+ "modelToHeader": "x-higress-llm-model"
+ })";
+
+ config_.set(configuration);
+ EXPECT_TRUE(root_context_->configure(configuration.size()));
+
+ path_ = "/v1/chat/completions";
+ std::string request_json = R"({"model": "qwen-long"})";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .Times(0);
+
+ EXPECT_CALL(
+ *mock_context_,
+ replaceHeaderMapValue(testing::_, std::string_view("x-higress-llm-model"),
+ std::string_view("qwen-long")));
+
+ body_.set(request_json);
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::StopIteration);
+ EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
+}
+
+TEST_F(ModelRouterTest, IgnorePath) {
+ std::string configuration = R"(
+{
+ "addProviderHeader": "x-higress-llm-provider"
+ })";
+
+ config_.set(configuration);
+ EXPECT_TRUE(root_context_->configure(configuration.size()));
+
+ path_ = "/v1/chat/xxxx";
+ std::string request_json = R"({"model": "qwen/qwen-long"})";
+ EXPECT_CALL(*mock_context_,
+ setBuffer(testing::_, testing::_, testing::_, testing::_))
+ .Times(0);
+
+ EXPECT_CALL(*mock_context_,
+ replaceHeaderMapValue(testing::_,
+ std::string_view("x-higress-llm-provider"),
+ std::string_view("qwen")))
+ .Times(0);
+
+ body_.set(request_json);
+ EXPECT_EQ(context_->onRequestHeaders(0, false),
+ FilterHeadersStatus::Continue);
+ EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
+}
+
TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
std::string configuration = R"(
{
"_rules_": [
{
"_match_route_": ["route-a"],
- "enable": true
+ "addProviderHeader": "x-higress-llm-provider"
}
]})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
+ path_ = "/api/v1/chat/completions";
std::string request_json = R"({"model": "qwen/qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))