Skip to content

Commit

Permalink
[Developer QoL]: Use nicer Candle Error APIs (#767)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Sep 11, 2024
1 parent f85ddcd commit 1cdec83
Show file tree
Hide file tree
Showing 15 changed files with 38 additions and 43 deletions.
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ license = "MIT"

[workspace.dependencies]
anyhow = "1.0.80"
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e" }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e", optional = true }
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486", optional = true }
dirs = "5.0.1"
hf-hub = "0.3.2"
thiserror = "1.0.57"
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/device_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ impl DeviceMapper for LayerDeviceMapper {
fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
dtype
.try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
.map_err(|e| candle_core::Error::Msg(format!("{e:?}")))
.map_err(candle_core::Error::msg)
}
}

Expand Down Expand Up @@ -249,6 +249,6 @@ impl DeviceMapper for DummyDeviceMapper {
fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
dtype
.try_into_dtype(&[&self.nm_device])
.map_err(|e| candle_core::Error::Msg(format!("{e:?}")))
.map_err(candle_core::Error::msg)
}
}
2 changes: 1 addition & 1 deletion mistralrs-core/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ impl Engine {
let prompt = get_mut_arcmutex!(self.pipeline)
.tokenizer()
.encode(text, true)
.map_err(|e| anyhow::Error::msg(e.to_string()));
.map_err(anyhow::Error::msg);
handle_seq_error!(prompt, request.response)
.get_ids()
.to_vec()
Expand Down
15 changes: 6 additions & 9 deletions mistralrs-core/src/pipeline/amoe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
0.0,
vec![],
)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;

let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
1, false, false, 0,
Expand Down Expand Up @@ -402,7 +402,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
true,
Vec::new(),
)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
let images = image_urls.as_ref().map(|urls| {
urls.iter()
.map(|url| -> anyhow::Result<DynamicImage> {
Expand Down Expand Up @@ -511,26 +511,23 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
candle_core::bail!("`loss_csv_path` must have an extension `csv`.");
}

let mut writer =
csv::Writer::from_path(path).map_err(|e| candle_core::Error::Msg(e.to_string()))?;
let mut writer = csv::Writer::from_path(path).map_err(candle_core::Error::msg)?;

let mut header = vec![format!("Step")];
header.extend((0..all_losses[0].len()).map(|i| format!("Gating layer {i}")));
writer
.write_record(&header)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;

for (i, row) in all_losses.into_iter().enumerate() {
let mut new_row = vec![format!("Step {i}")];
new_row.extend(row.iter().map(|x| format!("{x:.4}")));
writer
.write_record(&new_row)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
}

writer
.flush()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
writer.flush().map_err(candle_core::Error::msg)?;
}

Ok(Some(AnyMoeTrainingResult {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/isq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ pub trait IsqModel {
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(minimum_max_threads)
.build()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;

pool.install(|| {
use indicatif::ParallelProgressIterator;
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ pub trait Pipeline:
let InputProcessorOutput {
inputs,
seq_indices,
} = inputs.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
} = inputs.map_err(candle_core::Error::msg)?;
if i == 0 {
match pre_op {
CacheInstruction::In(ref adapter_inst) => {
Expand Down Expand Up @@ -404,7 +404,7 @@ pub trait Pipeline:
let InputProcessorOutput {
inputs,
seq_indices,
} = inputs.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
} = inputs.map_err(candle_core::Error::msg)?;

let raw_logits = self.forward_inputs(inputs)?;

Expand Down
10 changes: 5 additions & 5 deletions mistralrs-core/src/pipeline/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -602,16 +602,16 @@ impl AnyMoePipelineMixin for NormalPipeline {
) -> candle_core::Result<()> {
let mut vbs = Vec::new();
// Precompile regex here
let regex = Regex::new(match_regex).map_err(|e| candle_core::Error::Msg(e.to_string()))?;
let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
for model_id in model_ids {
let model_id_str = &model_id;
let model_id = Path::new(&model_id);

let api = ApiBuilder::new()
.with_progress(!silent)
.with_token(get_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?)
.with_token(get_token(token).map_err(candle_core::Error::msg)?)
.build()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
let revision = revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(
model_id_str.clone(),
Expand Down Expand Up @@ -651,9 +651,9 @@ impl AnyMoePipelineMixin for NormalPipeline {

let api = ApiBuilder::new()
.with_progress(!silent)
.with_token(get_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?)
.with_token(get_token(token).map_err(candle_core::Error::msg)?)
.build()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
let revision = revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(
model_id_str.clone(),
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub trait Processor {
let encoding = pipeline
.tokenizer()
.encode(prompt, true)
.map_err(|e| anyhow::Error::msg(e.to_string()))?;
.map_err(anyhow::Error::msg)?;
Ok(encoding.get_ids().to_vec())
}
fn inputs_processor(&self) -> Arc<dyn InputsProcessor>;
Expand Down
8 changes: 3 additions & 5 deletions mistralrs-core/src/pipeline/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
let mut tool_calls = Vec::new();
let mut text_new = Some(text.clone());
if let Some(ref matcher) = seq.tools {
let calls = matcher
.get_call(&text)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
let calls = matcher.get_call(&text).map_err(candle_core::Error::msg)?;
if !calls.is_empty() {
text_new = None;
}
Expand Down Expand Up @@ -358,12 +356,12 @@ pub async fn sample_sequence(
SequenceRecognizer::Regex(ref mut rx) => {
seq.tok_trie
.append_token(rx.as_mut(), second_logprobs_response.token)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::Cfg(ref mut cfg) => {
seq.tok_trie
.append_token(cfg.as_mut(), second_logprobs_response.token)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::None => {}
}
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/speculative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,14 +570,14 @@ impl Pipeline for SpeculativePipeline {
.get_metadata()
.tok_trie
.append_token(rx.as_mut(), accepted.token)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::Cfg(ref mut cfg) => {
get_mut_arcmutex!(self.target)
.get_metadata()
.tok_trie
.append_token(cfg.as_mut(), accepted.token)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::None => {}
}
Expand Down
10 changes: 5 additions & 5 deletions mistralrs-core/src/pipeline/vision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,16 +485,16 @@ impl AnyMoePipelineMixin for VisionPipeline {
) -> candle_core::Result<()> {
let mut vbs = Vec::new();
// Precompile regex here
let regex = Regex::new(match_regex).map_err(|e| candle_core::Error::Msg(e.to_string()))?;
let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
for model_id in model_ids {
let model_id_str = &model_id;
let model_id = Path::new(&model_id);

let api = ApiBuilder::new()
.with_progress(!silent)
.with_token(get_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?)
.with_token(get_token(token).map_err(candle_core::Error::msg)?)
.build()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
let revision = revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(
model_id_str.clone(),
Expand Down Expand Up @@ -534,9 +534,9 @@ impl AnyMoePipelineMixin for VisionPipeline {

let api = ApiBuilder::new()
.with_progress(!silent)
.with_token(get_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?)
.with_token(get_token(token).map_err(candle_core::Error::msg)?)
.build()
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
.map_err(candle_core::Error::msg)?;
let revision = revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(
model_id_str.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl Processor for Idefics2Processor {
let encoding = pipeline
.tokenizer()
.encode(prompt, true)
.map_err(|e| anyhow::Error::msg(e.to_string()))?;
.map_err(anyhow::Error::msg)?;
Ok(encoding.get_ids().to_vec())
}

Expand Down
2 changes: 1 addition & 1 deletion mistralrs-pyo3/Cargo_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pyo3.workspace = true
mistralrs-core = { version = "0.3.0", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] }
serde.workspace = true
serde_json.workspace = true
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e", features=["$feature_name"] }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486", features=["$feature_name"] }
indexmap.workspace = true
accelerate-src = { workspace = true, optional = true }
intel-mkl-src = { workspace = true, optional = true }
Expand Down

0 comments on commit 1cdec83

Please sign in to comment.