Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow model rewrites #1009

Merged
merged 6 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cargo-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
with:
contains_the_file: Cargo.toml
changed_relative_to_ref: origin/${{ github.base_ref || 'not-a-branch' }}
ignore_dirs: ".coredb examples tembo-cli/temboclient tembo-cli/tembodataclient"
ignore_dirs: ".coredb examples tembo-cli/temboclient tembo-cli/tembodataclient inference-gateway"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This project has its own CI workflow file, tembo-ai.yml, which runs tests, clippy, etc. with its own configuration.


lint:
name: Run linters
Expand Down
4 changes: 2 additions & 2 deletions inference-gateway/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ run-mock-server:
docker compose up -d mock-server

unit-test:
cargo test
cargo test -- --test-threads=1

integration-test: run-mock-server
RUST_LOG=${RUST_LOG} MODEL_SERVICE_PORT_MAP=${MODEL_SERVICE_PORT_MAP} cargo test ${TEST_NAME} -- --ignored --nocapture
RUST_LOG=${RUST_LOG} MODEL_SERVICE_PORT_MAP=${MODEL_SERVICE_PORT_MAP} cargo test ${TEST_NAME} -- --ignored --nocapture --test-threads=1

test-all: unit-test integration-test
124 changes: 124 additions & 0 deletions inference-gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ use std::env;

use url::Url;

use crate::errors::PlatformError;

#[derive(Clone, Debug)]
pub struct Config {
pub model_rewrites: HashMap<String, String>,
pub model_service_map: HashMap<String, Url>,
/// Postgres connection string to the timeseries database which logs token usage
pub pg_conn_str: String,
Expand All @@ -25,6 +28,7 @@ pub struct Config {
impl Config {
pub async fn new() -> Self {
Self {
model_rewrites: parse_model_rewrite(),
model_service_map: parse_model_service_port_map(),
pg_conn_str: from_env_default(
"DATABASE_URL",
Expand Down Expand Up @@ -86,11 +90,131 @@ fn parse_model_service_port_map() -> HashMap<String, Url> {
model_map
}

/// Parses the `MODEL_REWRITES` environment variable into a map of
/// requested model name -> replacement model name.
///
/// The expected format is a comma-separated list of `from:to` pairs, for
/// example `cat:dog,old:young`. Each pair is split on its first `:`.
///
/// Whitespace surrounding either name is ignored, so `cat:dog, old:young`
/// produces the key `old` rather than ` old`. Entries that have no `:`
/// separator, or whose name is empty on either side, are skipped: they could
/// never match a request or resolve to a model service.
///
/// Returns an empty map when the variable is unset or cannot be read.
fn parse_model_rewrite() -> HashMap<String, String> {
    let raw = match env::var("MODEL_REWRITES") {
        Ok(raw) => raw,
        Err(_) => return HashMap::new(),
    };

    raw.split(',')
        .filter_map(|pair| pair.split_once(':'))
        .map(|(from, to)| (from.trim(), to.trim()))
        // An empty name on either side is a configuration typo, not a rewrite.
        .filter(|(from, to)| !from.is_empty() && !to.is_empty())
        .map(|(from, to)| (from.to_string(), to.to_string()))
        .collect()
}

/// Outcome of `rewrite_model_request`: the model that will actually be used,
/// the service that serves it, and the request body to forward.
#[derive(Debug)]
pub struct MappedRequest {
/// The mapped model name (the rewritten name when a rewrite applied,
/// otherwise the name that was requested).
pub model: String,
/// URL of the service registered for `model` in `Config::model_service_map`.
pub base_url: Url,
/// Request body whose `model` field carries the mapped model name.
pub body: serde_json::Value,
}

/// Resolves the `model` named in a request body to the service that serves it.
///
/// When `config.model_rewrites` has an entry for the requested model, the
/// `model` field of `body` is overwritten with the replacement name before the
/// service lookup; otherwise the body is returned untouched.
///
/// # Errors
///
/// Returns `PlatformError::InvalidQuery` when `body` has no `model` field,
/// when that field is not a string, or when the (possibly rewritten) model has
/// no entry in `config.model_service_map`.
pub fn rewrite_model_request(
    mut body: serde_json::Value,
    config: &Config,
) -> Result<MappedRequest, PlatformError> {
    // Pull the requested model name out of the body, failing fast if absent.
    let requested = body
        .get("model")
        .ok_or_else(|| {
            PlatformError::InvalidQuery("missing `model` parameter in request body".to_string())
        })?
        .as_str()
        .ok_or_else(|| {
            PlatformError::InvalidQuery("empty value in `model` parameter".to_string())
        })?;

    // Apply the rewrite, if one is configured, and keep the body in sync.
    let model = match config.model_rewrites.get(requested) {
        Some(replacement) => {
            body["model"] = serde_json::Value::String(replacement.clone());
            replacement.clone()
        }
        None => requested.to_owned(),
    };

    // Find the service that hosts the final model name.
    let base_url = match config.model_service_map.get(&model) {
        Some(url) => url.clone(),
        None => {
            return Err(PlatformError::InvalidQuery(format!(
                "model {} not found",
                model
            )))
        }
    };

    Ok(MappedRequest {
        model,
        base_url,
        body,
    })
}

#[cfg(test)]
mod tests {
use super::*;
use std::env;

/// Each requested model is rewritten, routed to the rewritten model's service,
/// and the rest of the request body is passed through unchanged.
#[tokio::test]
async fn test_rewrite() {
    env::set_var("MODEL_REWRITES", "cat:dog,old:young");
    env::set_var(
        "MODEL_SERVICE_PORT_MAP",
        "dog=http://dog:8000/,young=http://young:8000/",
    );

    let cfg = Config::new().await;

    // (requested model, extra body value, expected model, expected service url)
    let cases = [
        ("cat", "value", "dog", "http://dog:8000/"),
        ("old", "value2", "young", "http://young:8000/"),
    ];

    for &(requested, extra, expected_model, expected_url) in cases.iter() {
        let body = serde_json::json!({
            "model": requested,
            "key": extra
        });

        let rewritten = rewrite_model_request(body, &cfg).unwrap();
        assert_eq!(rewritten.model, expected_model);
        assert_eq!(rewritten.base_url.to_string(), expected_url);
        assert_eq!(rewritten.body.get("key").unwrap(), extra);
    }
}

/// A well-formed `MODEL_REWRITES` value yields one map entry per pair.
#[test]
fn test_valid_env_var() {
    env::set_var("MODEL_REWRITES", "cat:dog,old:young");

    let expected: HashMap<String, String> = [("cat", "dog"), ("old", "young")]
        .iter()
        .map(|(from, to)| (from.to_string(), to.to_string()))
        .collect();

    let result = parse_model_rewrite();
    assert_eq!(result, expected);
}

/// An empty `MODEL_REWRITES` value produces no rewrites.
#[test]
fn test_empty_env_var() {
    env::set_var("MODEL_REWRITES", "");

    let rewrites = parse_model_rewrite();

    assert!(rewrites.is_empty());
}

/// Entries without a `:` separator are dropped; the valid pairs around them
/// are still parsed.
#[test]
fn test_invalid_format() {
    env::set_var("MODEL_REWRITES", "cat:dog,invalidpair,old:young");

    let expected: HashMap<String, String> = [("cat", "dog"), ("old", "young")]
        .iter()
        .map(|(from, to)| (from.to_string(), to.to_string()))
        .collect();

    let result = parse_model_rewrite();
    assert_eq!(result, expected);
}

#[test]
fn test_default_values() {
env::remove_var("MODEL_SERVICE_PORT_MAP");
Expand Down
14 changes: 3 additions & 11 deletions inference-gateway/src/routes/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::Arc;
use tokio::sync::RwLock;

use crate::authorization;
use crate::config::rewrite_model_request;
use crate::errors::{AuthError, PlatformError};

pub async fn forward_request(
Expand Down Expand Up @@ -45,18 +46,9 @@ pub async fn forward_request(
return Ok(HttpResponse::BadRequest().body("Embedding generation is not yet supported"));
}

let requested_model = body
.get("model")
.ok_or_else(|| PlatformError::InvalidQuery("missing model in request body".to_string()))?
.as_str()
.ok_or_else(|| PlatformError::InvalidQuery("missing model in request body".to_string()))?;
let rewrite_request = rewrite_model_request(body.clone(), &config)?;

let mut new_url = config
.model_service_map
.get(requested_model)
.ok_or_else(|| PlatformError::InvalidQuery(format!("model {} not found", requested_model)))?
.clone();
// let mut new_url = config.llm_service_host_port.clone();
let mut new_url = rewrite_request.base_url;
new_url.set_path(path);
new_url.set_query(req.uri().query());

Expand Down
Loading