Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use nicer Candle Error APIs #767

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ license = "MIT"

[workspace.dependencies]
anyhow = "1.0.80"
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e" }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e", optional = true }
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486", optional = true }
dirs = "5.0.1"
hf-hub = "0.3.2"
thiserror = "1.0.57"
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/device_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ impl DeviceMapper for LayerDeviceMapper {
fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
dtype
.try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
.map_err(|e| candle_core::Error::Msg(format!("{e:?}")))
.map_err(candle_core::Error::msg)
}
}

Expand Down Expand Up @@ -249,6 +249,6 @@ impl DeviceMapper for DummyDeviceMapper {
fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
dtype
.try_into_dtype(&[&self.nm_device])
.map_err(|e| candle_core::Error::Msg(format!("{e:?}")))
.map_err(candle_core::Error::msg)
}
}
2 changes: 1 addition & 1 deletion mistralrs-core/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ impl Engine {
let prompt = get_mut_arcmutex!(self.pipeline)
.tokenizer()
.encode(text, true)
.map_err(|e| anyhow::Error::msg(e.to_string()));
.map_err(anyhow::Error::msg);
handle_seq_error!(prompt, request.response)
.get_ids()
.to_vec()
Expand Down
15 changes: 6 additions & 9 deletions mistralrs-core/src/pipeline/amoe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
0.0,
vec![],
)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;

let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
1, false, false, 0,
Expand Down Expand Up @@ -402,7 +402,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
true,
Vec::new(),
)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
let images = image_urls.as_ref().map(|urls| {
urls.iter()
.map(|url| -> anyhow::Result<DynamicImage> {
Expand Down Expand Up @@ -511,26 +511,23 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
candle_core::bail!("`loss_csv_path` must have an extension `csv`.");
}

let mut writer =
csv::Writer::from_path(path).map_err(|e| candle_core::Error::Msg(e.to_string()))?;
let mut writer = csv::Writer::from_path(path).map_err(candle_core::Error::msg)?;

let mut header = vec![format!("Step")];
header.extend((0..all_losses[0].len()).map(|i| format!("Gating layer {i}")));
writer
.write_record(&header)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;

for (i, row) in all_losses.into_iter().enumerate() {
let mut new_row = vec![format!("Step {i}")];
new_row.extend(row.iter().map(|x| format!("{x:.4}")));
writer
.write_record(&new_row)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
}

writer
.flush()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
writer.flush().map_err(candle_core::Error::msg)?;
}

Ok(Some(AnyMoeTrainingResult {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/isq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
#[serde(rename = "default")]
Default,
/// Only quantize MoE experts, if applicable. The enables MoQE.
/// https://arxiv.org/abs/2310.02410

Check warning on line 86 in mistralrs-core/src/pipeline/isq.rs

View workflow job for this annotation

GitHub Actions / Docs

this URL is not a hyperlink
#[serde(rename = "moqe")]
MoeExpertsOnly,
}
Expand Down Expand Up @@ -225,7 +225,7 @@
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(minimum_max_threads)
.build()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;

pool.install(|| {
use indicatif::ParallelProgressIterator;
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ pub trait Pipeline:
let InputProcessorOutput {
inputs,
seq_indices,
} = inputs.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
} = inputs.map_err(candle_core::Error::msg)?;
if i == 0 {
match pre_op {
CacheInstruction::In(ref adapter_inst) => {
Expand Down Expand Up @@ -404,7 +404,7 @@ pub trait Pipeline:
let InputProcessorOutput {
inputs,
seq_indices,
} = inputs.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
} = inputs.map_err(candle_core::Error::msg)?;

let raw_logits = self.forward_inputs(inputs)?;

Expand Down
10 changes: 5 additions & 5 deletions mistralrs-core/src/pipeline/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -602,16 +602,16 @@ impl AnyMoePipelineMixin for NormalPipeline {
) -> candle_core::Result<()> {
let mut vbs = Vec::new();
// Precompile regex here
let regex = Regex::new(match_regex).map_err(|e| candle_core::Error::Msg(e.to_string()))?;
let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
for model_id in model_ids {
let model_id_str = &model_id;
let model_id = Path::new(&model_id);

let api = ApiBuilder::new()
.with_progress(!silent)
.with_token(get_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?)
.with_token(get_token(token).map_err(candle_core::Error::msg)?)
.build()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
let revision = revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(
model_id_str.clone(),
Expand Down Expand Up @@ -651,9 +651,9 @@ impl AnyMoePipelineMixin for NormalPipeline {

let api = ApiBuilder::new()
.with_progress(!silent)
.with_token(get_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?)
.with_token(get_token(token).map_err(candle_core::Error::msg)?)
.build()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
let revision = revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(
model_id_str.clone(),
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub trait Processor {
let encoding = pipeline
.tokenizer()
.encode(prompt, true)
.map_err(|e| anyhow::Error::msg(e.to_string()))?;
.map_err(anyhow::Error::msg)?;
Ok(encoding.get_ids().to_vec())
}
fn inputs_processor(&self) -> Arc<dyn InputsProcessor>;
Expand Down
8 changes: 3 additions & 5 deletions mistralrs-core/src/pipeline/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
let mut tool_calls = Vec::new();
let mut text_new = Some(text.clone());
if let Some(ref matcher) = seq.tools {
let calls = matcher
.get_call(&text)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
let calls = matcher.get_call(&text).map_err(candle_core::Error::msg)?;
if !calls.is_empty() {
text_new = None;
}
Expand Down Expand Up @@ -358,12 +356,12 @@ pub async fn sample_sequence(
SequenceRecognizer::Regex(ref mut rx) => {
seq.tok_trie
.append_token(rx.as_mut(), second_logprobs_response.token)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::Cfg(ref mut cfg) => {
seq.tok_trie
.append_token(cfg.as_mut(), second_logprobs_response.token)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::None => {}
}
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/speculative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,14 +570,14 @@ impl Pipeline for SpeculativePipeline {
.get_metadata()
.tok_trie
.append_token(rx.as_mut(), accepted.token)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::Cfg(ref mut cfg) => {
get_mut_arcmutex!(self.target)
.get_metadata()
.tok_trie
.append_token(cfg.as_mut(), accepted.token)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::None => {}
}
Expand Down
10 changes: 5 additions & 5 deletions mistralrs-core/src/pipeline/vision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,16 +485,16 @@ impl AnyMoePipelineMixin for VisionPipeline {
) -> candle_core::Result<()> {
let mut vbs = Vec::new();
// Precompile regex here
let regex = Regex::new(match_regex).map_err(|e| candle_core::Error::Msg(e.to_string()))?;
let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
for model_id in model_ids {
let model_id_str = &model_id;
let model_id = Path::new(&model_id);

let api = ApiBuilder::new()
.with_progress(!silent)
.with_token(get_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?)
.with_token(get_token(token).map_err(candle_core::Error::msg)?)
.build()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
let revision = revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(
model_id_str.clone(),
Expand Down Expand Up @@ -534,9 +534,9 @@ impl AnyMoePipelineMixin for VisionPipeline {

let api = ApiBuilder::new()
.with_progress(!silent)
.with_token(get_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?)
.with_token(get_token(token).map_err(candle_core::Error::msg)?)
.build()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
let revision = revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(
model_id_str.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl Processor for Idefics2Processor {
let encoding = pipeline
.tokenizer()
.encode(prompt, true)
.map_err(|e| anyhow::Error::msg(e.to_string()))?;
.map_err(anyhow::Error::msg)?;
Ok(encoding.get_ids().to_vec())
}

Expand Down
2 changes: 1 addition & 1 deletion mistralrs-pyo3/Cargo_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pyo3.workspace = true
mistralrs-core = { version = "0.3.0", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] }
serde.workspace = true
serde_json.workspace = true
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e", features=["$feature_name"] }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486", features=["$feature_name"] }
indexmap.workspace = true
accelerate-src = { workspace = true, optional = true }
intel-mkl-src = { workspace = true, optional = true }
Expand Down
Loading