Skip to content

Commit

Permalink
Remove test and add custom error type to Python API (#738)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Sep 1, 2024
1 parent b84e836 commit 8725ce9
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 162 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mistralrs-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ reqwest.workspace = true
base64.workspace = true
url.workspace = true
data-url.workspace = true
anyhow.workspace = true

[build-dependencies]
pyo3-build-config = "0.22"
Expand Down
1 change: 1 addition & 0 deletions mistralrs-pyo3/Cargo_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ reqwest.workspace = true
base64.workspace = true
url.workspace = true
data-url.workspace = true
anyhow.workspace = true

[build-dependencies]
pyo3-build-config = "0.22"
109 changes: 45 additions & 64 deletions mistralrs-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::{
};
use stream::ChatCompletionStreamer;
use tokio::sync::mpsc::channel;
use util::{PyApiErr, PyApiResult};

use candle_core::Device;
use mistralrs_core::{
Expand All @@ -25,7 +26,7 @@ use mistralrs_core::{
SchedulerConfig, SpeculativeConfig, SpeculativeLoader, StopTokens, TokenSource, Tool, Topology,
VisionLoaderBuilder, VisionSpecificConfig,
};
use pyo3::{exceptions::PyValueError, prelude::*};
use pyo3::prelude::*;
use std::fs::File;
mod anymoe;
mod requests;
Expand Down Expand Up @@ -62,7 +63,7 @@ fn parse_which(
no_kv_cache: bool,
chat_template: Option<String>,
prompt_batchsize: Option<NonZeroUsize>,
) -> PyResult<Box<dyn Loader>> {
) -> PyApiResult<Box<dyn Loader>> {
#[cfg(not(feature = "flash-attn"))]
let use_flash_attn = false;
#[cfg(feature = "flash-attn")]
Expand All @@ -84,8 +85,7 @@ fn parse_which(
tokenizer_json,
Some(model_id),
)
.build(arch.into())
.map_err(|e| PyValueError::new_err(e.to_string()))?,
.build(arch.into())?,
Which::XLora {
model_id,
xlora_model_id,
Expand All @@ -109,13 +109,11 @@ fn parse_which(
serde_json::from_reader(
File::open(order.clone())
.unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
)?,
no_kv_cache,
tgt_non_granular_index,
)
.build(arch.into())
.map_err(|e| PyValueError::new_err(e.to_string()))?,
.build(arch.into())?,
Which::Lora {
model_id,
tokenizer_json,
Expand All @@ -138,11 +136,9 @@ fn parse_which(
serde_json::from_reader(
File::open(order.clone())
.unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
)?,
)
.build(arch.into())
.map_err(|e| PyValueError::new_err(e.to_string()))?,
.build(arch.into())?,
Which::GGUF {
tok_model_id,
quantized_model_id,
Expand Down Expand Up @@ -182,8 +178,7 @@ fn parse_which(
serde_json::from_reader(
File::open(order.clone())
.unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
)?,
no_kv_cache,
tgt_non_granular_index,
)
Expand All @@ -210,8 +205,7 @@ fn parse_which(
serde_json::from_reader(
File::open(order.clone())
.unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
)?,
)
.build(),
Which::GGML {
Expand Down Expand Up @@ -261,8 +255,7 @@ fn parse_which(
serde_json::from_reader(
File::open(order.clone())
.unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
)?,
no_kv_cache,
tgt_non_granular_index,
)
Expand Down Expand Up @@ -293,8 +286,7 @@ fn parse_which(
serde_json::from_reader(
File::open(order.clone())
.unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
)?,
)
.build(),
Which::VisionPlain {
Expand Down Expand Up @@ -356,7 +348,7 @@ impl Runner {
pa_blk_size: Option<usize>,
no_paged_attn: bool,
prompt_batchsize: Option<usize>,
) -> PyResult<Self> {
) -> PyApiResult<Self> {
let tgt_non_granular_index = match which {
Which::Plain { .. }
| Which::Lora { .. }
Expand Down Expand Up @@ -386,7 +378,7 @@ impl Runner {

let prompt_batchsize = match prompt_batchsize {
Some(0) => {
return Err(PyValueError::new_err(
return Err(PyApiErr::from(
"`prompt_batchsize` must be a strictly positive integer, got 0.",
))
}
Expand Down Expand Up @@ -432,7 +424,7 @@ impl Runner {

let device = get_device();
let isq = if let Some(isq) = in_situ_quant {
Some(parse_isq_value(&isq).map_err(|e| PyValueError::new_err(e.to_string()))?)
Some(parse_isq_value(&isq).map_err(PyApiErr::from)?)
} else {
None
};
Expand Down Expand Up @@ -518,16 +510,15 @@ impl Runner {
let pipeline = loader
.load_model_from_hf(
None,
TokenSource::from_str(token_source)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
TokenSource::from_str(token_source).map_err(PyApiErr::from)?,
&ModelDType::Auto,
&device,
true, // Silent for jupyter
mapper,
isq,
cache_config,
)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
.map_err(PyApiErr::from)?;

let scheduler_config = if cache_config.is_some() {
// Handle case where we may have device mapping
Expand All @@ -541,7 +532,7 @@ impl Runner {
method: DefaultSchedulerMethod::Fixed(
max_seqs
.try_into()
.map_err(|e| PyValueError::new_err(format!("{e:?}")))?,
.map_err(|e| PyApiErr::from(format!("{e:?}")))?,
),
}
}
Expand All @@ -550,7 +541,7 @@ impl Runner {
method: DefaultSchedulerMethod::Fixed(
max_seqs
.try_into()
.map_err(|e| PyValueError::new_err(format!("{e:?}")))?,
.map_err(|e| PyApiErr::from(format!("{e:?}")))?,
),
}
};
Expand All @@ -566,7 +557,7 @@ impl Runner {
fn send_chat_completion_request(
&mut self,
request: Py<ChatCompletionRequest>,
) -> PyResult<Either<ChatCompletionResponse, ChatCompletionStreamer>> {
) -> PyApiResult<Either<ChatCompletionResponse, ChatCompletionStreamer>> {
let (tx, mut rx) = channel(10_000);
Python::with_gil(|py| {
let request = request.bind(py).borrow();
Expand All @@ -576,20 +567,20 @@ impl Runner {
.map(|x| StopTokens::Seqs(x.to_vec()));
let constraint = if request.grammar_type == Some("regex".to_string()) {
if request.grammar.is_none() {
return Err(PyValueError::new_err(
return Err(PyApiErr::from(
"Grammar type is specified but not grammar text",
));
}
Constraint::Regex(request.grammar.as_ref().unwrap().clone())
} else if request.grammar_type == Some("yacc".to_string()) {
if request.grammar.is_none() {
return Err(PyValueError::new_err(
return Err(PyApiErr::from(
"Grammar type is specified but not grammar text",
));
}
Constraint::Yacc(request.grammar.as_ref().unwrap().clone())
} else if request.grammar_type.is_some() {
return Err(PyValueError::new_err(
return Err(PyApiErr::from(
"Grammar type is specified but is not `regex` or `yacc`",
));
} else {
Expand Down Expand Up @@ -630,12 +621,12 @@ impl Runner {
}
Either::Right(image_messages) => {
if image_messages.len() != 2 {
return Err(PyValueError::new_err(
"Expected 2 items for the content of a message with an image."
.to_string()));
return Err(PyApiErr::from(
"Expected 2 items for the content of a message with an image."
));
}
if message["role"].as_ref().left().unwrap() != "user" {
return Err(PyValueError::new_err(format!(
return Err(PyApiErr::from(format!(
"Role for an image message must be `user`, but it is {}",
&message["role"].as_ref().left().unwrap()
)));
Expand All @@ -644,15 +635,15 @@ impl Runner {
let mut items = Vec::new();
for image_message in image_messages {
if image_message.len() != 2 {
return Err(PyValueError::new_err("Expected 2 items for the sub-content of a message with an image.".to_string()));
return Err(PyApiErr::from("Expected 2 items for the sub-content of a message with an image.".to_string()));
}
if !image_message.contains_key("type") {
return Err(PyValueError::new_err(
return Err(PyApiErr::from(
"Expected `type` key in input message.".to_string(),
));
}
if image_message["type"].is_right() {
return Err(PyValueError::new_err(
return Err(PyApiErr::from(
"Expected string value in `type`.".to_string(),
));
}
Expand All @@ -667,9 +658,9 @@ impl Runner {
String,
Either<String, HashMap<String, String>>,
>],
) -> PyResult<(String, String)> {
) -> PyApiResult<(String, String)> {
if image_messages[text_idx]["text"].is_right() {
return Err(PyValueError::new_err(
return Err(PyApiErr::from(
"Expected string value in `text`.".to_string(),
));
}
Expand All @@ -683,7 +674,7 @@ impl Runner {
.unwrap_right()
.contains_key("url")
{
return Err(PyValueError::new_err("Expected content of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}".to_string()));
return Err(PyApiErr::from("Expected content of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}".to_string()));
}
let url = image_messages[url_idx]["image_url"]
.as_ref()
Expand Down Expand Up @@ -758,10 +749,7 @@ impl Runner {
let tools = if let Some(tools) = &request.tool_schemas {
let mut new_tools = Vec::new();
for schema in tools {
new_tools.push(
serde_json::from_str::<Tool>(schema)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
);
new_tools.push(serde_json::from_str::<Tool>(schema)?);
}
Some(new_tools)
} else {
Expand Down Expand Up @@ -813,10 +801,10 @@ impl Runner {

match response {
Response::ValidationError(e) | Response::InternalError(e) => {
Err(PyValueError::new_err(e.to_string()))
Err(PyApiErr::from(e.to_string()))
}
Response::Done(response) => Ok(Either::Left(response)),
Response::ModelError(msg, _) => Err(PyValueError::new_err(msg.to_string())),
Response::ModelError(msg, _) => Err(PyApiErr::from(msg.to_string())),
Response::Chunk(_) => unreachable!(),
Response::CompletionDone(_) => unreachable!(),
Response::CompletionModelError(_, _) => unreachable!(),
Expand All @@ -830,7 +818,7 @@ impl Runner {
fn send_completion_request(
&mut self,
request: Py<CompletionRequest>,
) -> PyResult<CompletionResponse> {
) -> PyApiResult<CompletionResponse> {
let (tx, mut rx) = channel(10_000);
Python::with_gil(|py| {
let request = request.bind(py).borrow();
Expand All @@ -840,20 +828,20 @@ impl Runner {
.map(|x| StopTokens::Seqs(x.to_vec()));
let constraint = if request.grammar_type == Some("regex".to_string()) {
if request.grammar.is_none() {
return Err(PyValueError::new_err(
return Err(PyApiErr::from(
"Grammar type is specified but not grammar text",
));
}
Constraint::Regex(request.grammar.as_ref().unwrap().clone())
} else if request.grammar_type == Some("yacc".to_string()) {
if request.grammar.is_none() {
return Err(PyValueError::new_err(
return Err(PyApiErr::from(
"Grammar type is specified but not grammar text",
));
}
Constraint::Yacc(request.grammar.as_ref().unwrap().clone())
} else if request.grammar_type.is_some() {
return Err(PyValueError::new_err(
return Err(PyApiErr::from(
"Grammar type is specified but is not `regex` or `yacc`",
));
} else {
Expand All @@ -868,10 +856,7 @@ impl Runner {
let tools = if let Some(tools) = &request.tool_schemas {
let mut new_tools = Vec::new();
for schema in tools {
new_tools.push(
serde_json::from_str::<Tool>(schema)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
);
new_tools.push(serde_json::from_str::<Tool>(schema)?);
}
Some(new_tools)
} else {
Expand Down Expand Up @@ -934,12 +919,10 @@ impl Runner {

match response {
Response::ValidationError(e) | Response::InternalError(e) => {
Err(PyValueError::new_err(e.to_string()))
Err(PyApiErr::from(e.to_string()))
}
Response::CompletionDone(response) => Ok(response),
Response::CompletionModelError(msg, _) => {
Err(PyValueError::new_err(msg.to_string()))
}
Response::CompletionModelError(msg, _) => Err(PyApiErr::from(msg.to_string())),
Response::Chunk(_) => unreachable!(),
Response::Done(_) => unreachable!(),
Response::ModelError(_, _) => unreachable!(),
Expand All @@ -950,10 +933,8 @@ impl Runner {

/// Send a request to re-ISQ the model. If the model was loaded as GGUF or GGML
/// then nothing will happen.
fn send_re_isq(&self, dtype: String) -> PyResult<()> {
let request = _Request::ReIsq(
parse_isq_value(&dtype).map_err(|e| PyValueError::new_err(e.to_string()))?,
);
fn send_re_isq(&self, dtype: String) -> PyApiResult<()> {
let request = _Request::ReIsq(parse_isq_value(&dtype)?);
self.runner.get_sender()?.blocking_send(request).unwrap();
Ok(())
}
Expand Down
Loading

0 comments on commit 8725ce9

Please sign in to comment.