From a625394370e447f51e83aa7743cd76830dbce72b Mon Sep 17 00:00:00 2001
From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com>
Date: Thu, 5 Sep 2024 16:18:11 -0400
Subject: [PATCH] Nicer error when misconfigured pagedattn metadata (#753)

---
 mistralrs-core/src/pipeline/gguf.rs   | 46 ++++++++++++---------------
 mistralrs-core/src/pipeline/normal.rs | 16 ++++++----
 mistralrs-core/src/pipeline/vision.rs | 16 ++++++----
 3 files changed, 39 insertions(+), 39 deletions(-)

diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs
index 44154b511..8ffc8350b 100644
--- a/mistralrs-core/src/pipeline/gguf.rs
+++ b/mistralrs-core/src/pipeline/gguf.rs
@@ -40,7 +40,7 @@ use crate::{
     utils::tokens::get_token,
     xlora_models::{XLoraQLlama, XLoraQPhi3},
 };
-use anyhow::{bail, Result};
+use anyhow::{bail, Context, Result};
 use candle_core::{DType, Device, Tensor};
 use either::Either;
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
@@ -645,29 +645,29 @@ impl Pipeline for GGUFPipeline {
             flash_meta,
             flash_meta_full,
         } = *inputs.downcast().expect("Downcast failed.");
+        let paged_attn_meta = paged_attn_meta
+            .as_mut()
+            .with_context(|| "Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
+            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
         let logits = match self.model {
             Model::Llama(ref model) => model.forward(
                 &input_ids,
                 &seqlen_offsets,
                 seqlen_offsets_kernel,
                 context_lens,
-                self.get_metadata().cache_engine.as_ref().map(|engine| {
-                    (
-                        engine.get_kv_cache().clone(),
-                        paged_attn_meta.as_mut().unwrap(),
-                    )
-                }),
+                self.get_metadata()
+                    .cache_engine
+                    .as_ref()
+                    .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
             )?,
             Model::Phi2(ref model) => model.forward(
                 &input_ids,
                 &seqlen_offsets,
                 context_lens,
-                self.get_metadata().cache_engine.as_ref().map(|engine| {
-                    (
-                        engine.get_kv_cache().clone(),
-                        paged_attn_meta.as_mut().unwrap(),
-                    )
-                }),
+                self.get_metadata()
+                    .cache_engine
+                    .as_ref()
+                    .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
             )?,
             Model::XLoraLlama(ref model) => model.forward(
                 &input_ids,
@@ -685,12 +685,10 @@ impl Pipeline for GGUFPipeline {
             Model::Phi3(ref model) => model.forward(
                 &input_ids,
                 &seqlen_offsets,
-                self.get_metadata().cache_engine.as_ref().map(|engine| {
-                    (
-                        engine.get_kv_cache().clone(),
-                        paged_attn_meta.as_mut().unwrap(),
-                    )
-                }),
+                self.get_metadata()
+                    .cache_engine
+                    .as_ref()
+                    .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
             )?,
             Model::XLoraPhi3(ref model) => model.forward(
                 &input_ids,
@@ -709,12 +707,10 @@ impl Pipeline for GGUFPipeline {
                 &input_ids,
                 &seqlen_offsets,
                 seqlen_offsets_kernel,
-                self.get_metadata().cache_engine.as_ref().map(|engine| {
-                    (
-                        engine.get_kv_cache().clone(),
-                        paged_attn_meta.as_mut().unwrap(),
-                    )
-                }),
+                self.get_metadata()
+                    .cache_engine
+                    .as_ref()
+                    .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
             )?,
         };
         Ok(ForwardInputsResult::CausalGeneration { logits })
diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs
index aaab1ada5..5ea931468 100644
--- a/mistralrs-core/src/pipeline/normal.rs
+++ b/mistralrs-core/src/pipeline/normal.rs
@@ -32,7 +32,7 @@ use crate::{
     normal_model_loader, xlora_model_loader, DeviceMapMetadata, PagedAttentionConfig, Pipeline,
     Topology, TryIntoDType,
 };
-use anyhow::Result;
+use anyhow::{Context, Result};
 use candle_core::{Device, Tensor, Var};
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use mistralrs_quant::IsqType;
@@ -515,6 +515,10 @@ impl Pipeline for NormalPipeline {
             flash_meta,
             flash_meta_full,
         } = *inputs.downcast().expect("Downcast failed.");
+        let paged_attn_meta = paged_attn_meta
+            .as_mut()
+            .with_context(|| "Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
+            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
         let logits = match self.model.is_xlora() {
             false => self.model.forward(
                 &input_ids,
@@ -522,12 +526,10 @@ impl Pipeline for NormalPipeline {
                 seqlen_offsets_kernel,
                 context_lens,
                 position_ids,
-                self.get_metadata().cache_engine.as_ref().map(|engine| {
-                    (
-                        engine.get_kv_cache().clone(),
-                        paged_attn_meta.as_mut().unwrap(),
-                    )
-                }),
+                self.get_metadata()
+                    .cache_engine
+                    .as_ref()
+                    .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
                 &flash_meta,
             )?,
             true => self.model.xlora_forward(
diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs
index 0bfaf1e06..7d883afa3 100644
--- a/mistralrs-core/src/pipeline/vision.rs
+++ b/mistralrs-core/src/pipeline/vision.rs
@@ -24,7 +24,7 @@ use crate::{
     api_dir_list, api_get_file, get_paths, vision_normal_model_loader, AnyMoeExpertType,
     DeviceMapMetadata, Ordering, PagedAttentionConfig, Pipeline, Topology, TryIntoDType,
 };
-use anyhow::Result;
+use anyhow::{Context, Result};
 use candle_core::{Device, Tensor, Var};
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use mistralrs_quant::IsqType;
@@ -411,6 +411,10 @@ impl Pipeline for VisionPipeline {
             mut paged_attn_meta,
             flash_meta,
         } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
+        let paged_attn_meta = paged_attn_meta
+            .as_mut()
+            .with_context(|| "Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
+            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
         let logits = self.model.forward(
            &input_ids,
            pixel_values,
@@ -419,12 +423,10 @@ impl Pipeline for VisionPipeline {
             context_lens,
             position_ids,
             model_specific_args,
-            self.get_metadata().cache_engine.as_ref().map(|engine| {
-                (
-                    engine.get_kv_cache().clone(),
-                    paged_attn_meta.as_mut().unwrap(),
-                )
-            }),
+            self.get_metadata()
+                .cache_engine
+                .as_ref()
+                .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
             &flash_meta,
         )?;
         Ok(ForwardInputsResult::CausalGeneration { logits })
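
Note on the pattern: every hunk above makes the same change. Instead of calling `paged_attn_meta.as_mut().unwrap()` deep inside each `forward` call, which panics with no explanation when the scheduler is not configured for PagedAttention, the metadata `Option` is unwrapped once at the top of the forward step via `anyhow::Context`, and the failure is surfaced as `candle_core::Error::Msg`. Below is a minimal, self-contained sketch of that pattern, assuming only the `anyhow` crate; `PagedAttentionInputMetadata` and `PipelineError` are illustrative placeholders standing in for the crate's real types, not code from the repository.

```rust
use anyhow::Context;

/// Illustrative stand-in for the crate's PagedAttention metadata type.
struct PagedAttentionInputMetadata;

/// Plays the role of `candle_core::Error` in the real pipelines.
#[derive(Debug)]
enum PipelineError {
    Msg(String),
}

fn forward_step(
    mut paged_attn_meta: Option<PagedAttentionInputMetadata>,
) -> Result<(), PipelineError> {
    // Before the patch: `paged_attn_meta.as_mut().unwrap()` panicked with no hint.
    // After the patch: the missing Option becomes a descriptive, recoverable error.
    let _meta = paged_attn_meta
        .as_mut()
        .with_context(|| "Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
        .map_err(|e| PipelineError::Msg(e.to_string()))?;

    // ... the forward pass would hand `_meta` to the model here ...
    Ok(())
}

fn main() {
    // Misconfigured case: no PagedAttention metadata was provided.
    match forward_step(None) {
        Ok(()) => println!("forward step ok"),
        Err(PipelineError::Msg(msg)) => eprintln!("error: {msg}"),
    }
}
```

Calling `forward_step(None)` reports the descriptive message instead of panicking, which mirrors the behavior change this patch makes for a misconfigured PagedAttention setup.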