Skip to content

Commit

Permalink
Improve Auto dtype determination (#438)
Browse files Browse the repository at this point in the history
* Improve auto dtype determination

* Clippy
  • Loading branch information
EricLBuehler authored Jun 15, 2024
1 parent d66a8f4 commit b9e9eca
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions mistralrs-core/src/utils/normal.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{fmt::Display, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device};
use candle_core::{DType, Device, Tensor};
use serde::Deserialize;
use tracing::info;

Expand All @@ -10,9 +10,7 @@ use tracing::info;
///
/// If the model is quantized, this is ignored so it is reasonable to use the [`Default`] impl.
///
/// ## `Auto` rules
/// - If CUDA device or CPU, use BF16
/// - Fallback to F16
/// Note: When using `Auto`, fallback pattern is: BF16 -> F16 -> 32
pub enum ModelDType {
#[default]
#[serde(rename = "auto")]
Expand Down Expand Up @@ -64,16 +62,26 @@ impl TryIntoDType for DType {
}
}

fn determine_auto_dtype(device: &Device) -> candle_core::Result<DType> {
for dtype in [DType::BF16, DType::F16] {
// Try a matmul
let x = Tensor::zeros((2, 2), dtype, device)?;
let y = x.matmul(&x);
match y {
Ok(_) => return Ok(dtype),
Err(e) => match e {
candle_core::Error::UnsupportedDTypeForOp(_, _) => continue,
other => return Err(other),
},
}
}
Ok(DType::F32)
}

impl TryIntoDType for ModelDType {
fn try_into_dtype(&self, device: &Device) -> Result<DType> {
let dtype = match self {
Self::Auto => {
if device.is_cuda() || device.is_cpu() {
Ok(DType::BF16)
} else {
Ok(DType::F32)
}
}
Self::Auto => Ok(determine_auto_dtype(device).map_err(anyhow::Error::msg)?),
Self::BF16 => Ok(DType::BF16),
Self::F16 => Ok(DType::F16),
Self::F32 => Ok(DType::F32),
Expand Down

0 comments on commit b9e9eca

Please sign in to comment.