Skip to content

Commit

Permalink
Add percentage utilization support to paged attn (#574)
Browse files Browse the repository at this point in the history
* Support MISTRALRS_DEBUG=1 in paged attn

* Try to handle deadlock

* Try to fix deadlock again

* Add percentage utilization support

* Update docs

* Clippy
  • Loading branch information
EricLBuehler committed Jul 15, 2024
1 parent 639268c commit ebe032e
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 76 deletions.
6 changes: 3 additions & 3 deletions docs/PAGED_ATTENTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Our Paged Attention implementation has 2 inputs: GPU KV cache memory size, and b

> Note: The default block size if not specified is 32.
> Warning: When using dynamic adapter activation or sending re-ISQ requests, it may trigger OOM because the Paged Attention KV cache has already been allocated. To counter this, either set the KV cache memory to a lower amount (recommended) or disable paged attention.
> Note: if OOM happens (this can be caused by a variety of factors including adapter activation, re-ISQ, and others), it happens because the Paged Attention KV cache has already been allocated. To counter this, either set the KV cache memory to a lower amount or usage percentage (recommended) or disable paged attention entirely for a dynamically allocated cache.
**There are more features being added to this:**
- GGML model support
Expand All @@ -23,14 +23,14 @@ Our Paged Attention implementation has 2 inputs: GPU KV cache memory size, and b
## Using the CLI

Add the `--pa-gpu-mem` and `--pa-blk-size` parameters before the model kind selector. The GPU memory is in MBs and the block size means the number of tokens per block. These parameters may be passed on any supported model type.
Add the `--pa-gpu-mem`/`--pa-gpu-mem-usage` and `--pa-blk-size` parameters before the model kind selector. The GPU memory is in MBs and the block size means the number of tokens per block. These parameters may be passed on any supported model type.

```
cargo run --release --features cuda -- -i --pa-gpu-mem 8192 --pa-blk-size 32 --isq Q4K plain -m microsoft/Phi-3-mini-128k-instruct -a phi3
```

```
cargo run --release --features cuda -- -i --pa-gpu-mem 8192 --pa-blk-size 32 gguf -t mistralai/Mistral-7B-Instruct-v0.1 -m TheBloke/Mistral-7B-Instruct-v0.1-GGUF -f mistral-7b-instruct-v0.1.Q4_K_M.gguf
cargo run --release --features cuda -- -i --pa-gpu-mem-usage .95 --pa-blk-size 32 gguf -t mistralai/Mistral-7B-Instruct-v0.1 -m TheBloke/Mistral-7B-Instruct-v0.1-GGUF -f mistral-7b-instruct-v0.1.Q4_K_M.gguf
```

## Using the Rust API
Expand Down
37 changes: 30 additions & 7 deletions mistralrs-bench/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use either::Either;
use mistralrs_core::{
initialize_logging, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata,
DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder, ModelDType,
Expand Down Expand Up @@ -278,11 +279,17 @@ struct Args {
#[arg(short, long, value_parser, value_delimiter = ';')]
num_device_layers: Option<Vec<String>>,

/// GPU memory to allocate for KV cache with Paged Attention in MBs. If this is not set and the device is CUDA, it will default to to the
/// available GPU memory. Paged Attention is only supported on CUDA and is always automatically activated.
/// GPU memory to allocate for KV cache with Paged Attention in MBs. If this is not set and the device is CUDA, it will default to
/// using `pa-gpu-mem-usage` set to `0.9`. Paged Attention is only supported on CUDA and is always automatically activated.
#[arg(long = "pa-gpu-mem")]
paged_attn_gpu_mem: Option<usize>,

/// Percentage of GPU memory to utilize after allocation of KV cache with Paged Attention, from 0 to 1.
/// If this is not set and the device is CUDA, it will default to `0.9`. Paged Attention is only supported on CUDA and is always automatically activated.
/// This is always used over `pa-gpu-mem` if both are specified.
#[arg(long = "pa-gpu-mem-usage")]
paged_attn_gpu_mem_usage: Option<f32>,

/// Block size (number of tokens per block) for Paged Attention. If this is not set and the device is CUDA, it will default to 32.
/// Paged Attention is only supported on CUDA and is always automatically activated.
#[arg(long = "pa-blk-size")]
Expand Down Expand Up @@ -373,16 +380,32 @@ fn main() -> anyhow::Result<()> {
let cache_config = match (
args.paged_attn_block_size,
args.paged_attn_gpu_mem,
args.paged_attn_gpu_mem_usage,
device.is_cuda(),
args.no_paged_attn,
) {
(block_size, None, true, false) => Some(PagedAttentionConfig::new(
block_size, 512, None, // Autodetermine KV cache size
(block_size, None, None, true, false) => Some(PagedAttentionConfig::new(
block_size,
512,
Either::Right(0.9), // NOTE(EricLBuehler): default is to use 90% of memory
)?),
(block_size, Some(m), None, true, false) => {
Some(PagedAttentionConfig::new(block_size, 512, Either::Left(m))?)
}
(block_size, None, Some(f), true, false) => Some(PagedAttentionConfig::new(
block_size,
512,
Either::Right(f),
)?),
(block_size, Some(gpu_mem), _, false) => {
Some(PagedAttentionConfig::new(block_size, 512, Some(gpu_mem))?)
(block_size, Some(_m), Some(f), true, false) => {
info!("Both memory size and usage were specified, defaulting to the usage value.");
Some(PagedAttentionConfig::new(
block_size,
512,
Either::Right(f),
)?)
}
(_, _, _, _) => None,
(_, _, _, _, _) => None,
};

let pipeline = loader.load_model_from_hf(
Expand Down
25 changes: 13 additions & 12 deletions mistralrs-core/src/dummy_paged_attention/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub use block_engine_sequence::BlockEngineSequence;
pub use cache_engine::{CacheConfig, CacheEngine};
use candle_core::{DType, Device};
pub use config::{ModelConfigLike, ModelConfigMetadata};
use either::Either;
pub use layers::PagedAttention;
pub use scheduler::{
PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput,
Expand All @@ -29,14 +30,14 @@ use tracing::info;
pub struct PagedAttentionConfig {
pub(crate) block_size: Option<usize>,
pub(crate) mem_cpu: usize,
pub(crate) mem_gpu: Option<usize>,
pub(crate) mem_gpu: Either<usize, f32>,
}

impl PagedAttentionConfig {
pub fn new(
_block_size: Option<usize>,
_mem_cpu: usize,
_mem_gpu: Option<usize>,
_mem_gpu: Either<usize, f32>,
) -> anyhow::Result<Self> {
anyhow::bail!("PagedAttention is only supported for CUDA, compile with feature `cuda`.")
}
Expand Down Expand Up @@ -64,9 +65,9 @@ macro_rules! mb_to_blocks {
};
}

/// Memory values are in MBs. Specify block size or the default is 32.
/// Memory values are in MBs or a percentage in [0,1]. Specify block size or the default is 32.
pub fn calculate_cache_config(
mem_gpu: Option<usize>,
mem_gpu: Either<usize, f32>,
mem_cpu: usize,
block_size: Option<usize>,
dtype: DType,
Expand All @@ -79,15 +80,15 @@ pub fn calculate_cache_config(
}
let dtype_size = dtype.size_in_bytes();

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
let mem_gpu = match mem_gpu {
Some(v) => v,
None => {
let free = MemoryUsage.get_memory_available(device)? / SIZE_IN_MB;
info!(
"Automatically using {} MB for Paged Attention KV cache",
free - 512
);
free - 512
Either::Left(v) => v,
Either::Right(f) => {
let free = MemoryUsage.get_memory_available(device)? as f32 / SIZE_IN_MB as f32;
let total = MemoryUsage.get_total_memory(device)? as f32 / SIZE_IN_MB as f32 * f;
let size = (total - free) as usize;
info!("Allocating {size} MB for Paged Attention KV cache");
size
}
};

Expand Down
84 changes: 57 additions & 27 deletions mistralrs-core/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,50 +257,80 @@ impl Engine {
}
SchedulerOutput::PagedAttention { mut output } => {
if !output.scheduled.is_empty() {
let mut pipeline = get_mut_arcmutex!(self.pipeline);

let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

let block_tables = self.scheduler.block_tables().unwrap();
let block_size = self.scheduler.block_size().unwrap();

let metadata = PagedAttentionMeta {
block_tables,
block_size,
sliding_window: pipeline.get_metadata().sliding_window,
};

let mut guards = output
.scheduled
.iter_mut()
.map(|seq| seq.lock().unwrap())
.collect::<Vec<_>>();

let res = pipeline
.step(
&mut guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>(),
is_prompt,
&mut self.prefix_cacher,
self.disable_eos_stop,
rng.clone(),
CacheBackendMetadata::PagedAttention {
metadata,
blocks_to_copy: output.blocks_to_copy,
blocks_to_swap_in: output.blocks_to_swap_in,
blocks_to_swap_out: output.blocks_to_swap_out,
},
)
.await;
let mut guards_mut =
guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

let res = {
let mut pipeline = get_mut_arcmutex!(self.pipeline);

let block_tables = self.scheduler.block_tables().unwrap();
let block_size = self.scheduler.block_size().unwrap();

let metadata = PagedAttentionMeta {
block_tables,
block_size,
sliding_window: pipeline.get_metadata().sliding_window,
};

pipeline
.step(
&mut guards_mut,
is_prompt,
&mut self.prefix_cacher,
self.disable_eos_stop,
rng.clone(),
CacheBackendMetadata::PagedAttention {
metadata,
blocks_to_copy: output.blocks_to_copy,
blocks_to_swap_in: output.blocks_to_swap_in,
blocks_to_swap_out: output.blocks_to_swap_out,
},
)
.await
};

handle_pipeline_forward_error!(
"step",
res,
&mut guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>(),
&mut guards_mut,
self.pipeline,
'lp,
self.prefix_cacher
);

if self.is_debug {
let ms_from_last_run = run_start.elapsed().as_secs_f64();
let total_len = guards.len();
if total_len > 0 {
let lengths = guards
.iter()
.map(|seq| seq.len().to_string())
.collect::<Vec<_>>()
.join(", ");

let (prompt_lengths, completion_lengths) = if is_prompt {
(lengths, "".to_string())
} else {
("".to_string(), lengths)
};

tracing::info!(
"Prompt[{}] Completion[{}] - {}ms",
prompt_lengths,
completion_lengths,
ms_from_last_run * 1000.,
);
}
}

if is_prompt {
for mut seq in guards {
let now = SystemTime::now()
Expand Down
25 changes: 13 additions & 12 deletions mistralrs-core/src/paged_attention/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub use block_engine_sequence::BlockEngineSequence;
pub use cache_engine::{CacheConfig, CacheEngine};
use candle_core::{DType, Device};
pub use config::{ModelConfigLike, ModelConfigMetadata};
use either::Either;
pub use layers::PagedAttention;
pub use scheduler::{
PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput,
Expand All @@ -29,14 +30,14 @@ use tracing::info;
pub struct PagedAttentionConfig {
pub(crate) block_size: Option<usize>,
pub(crate) mem_cpu: usize,
pub(crate) mem_gpu: Option<usize>,
pub(crate) mem_gpu: Either<usize, f32>,
}

impl PagedAttentionConfig {
pub fn new(
block_size: Option<usize>,
mem_cpu: usize,
mem_gpu: Option<usize>,
mem_gpu: Either<usize, f32>,
) -> anyhow::Result<Self> {
Ok(Self {
block_size,
Expand Down Expand Up @@ -68,9 +69,9 @@ macro_rules! mb_to_blocks {
};
}

/// Memory values are in MBs. Specify block size or the default is 32.
/// Memory values are in MBs or a percentage in [0,1]. Specify block size or the default is 32.
pub fn calculate_cache_config(
mem_gpu: Option<usize>,
mem_gpu: Either<usize, f32>,
mem_cpu: usize,
block_size: Option<usize>,
dtype: DType,
Expand All @@ -83,15 +84,15 @@ pub fn calculate_cache_config(
}
let dtype_size = dtype.size_in_bytes();

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
let mem_gpu = match mem_gpu {
Some(v) => v,
None => {
let free = MemoryUsage.get_memory_available(device)? / SIZE_IN_MB;
info!(
"Automatically using {} MB for Paged Attention KV cache",
free - 512
);
free - 512
Either::Left(v) => v,
Either::Right(f) => {
let free = MemoryUsage.get_memory_available(device)? as f32 / SIZE_IN_MB as f32;
let total = MemoryUsage.get_total_memory(device)? as f32 / SIZE_IN_MB as f32 * f;
let size = (total - free) as usize;
info!("Allocating {size} MB for Paged Attention KV cache");
size
}
};

Expand Down
24 changes: 24 additions & 0 deletions mistralrs-core/src/utils/memory_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,28 @@ impl MemoryUsage {
}
}
}

pub fn get_total_memory(&self, device: &Device) -> Result<usize> {
match device {
Device::Cpu => {
let mut sys = System::new_all();
sys.refresh_cpu();
Ok(usize::try_from(sys.total_memory())? * KB_TO_BYTES)
}
#[cfg(feature = "cuda")]
Device::Cuda(_) => {
use candle_core::cuda_backend::WrapErr;
Ok(candle_core::cuda::cudarc::driver::result::mem_get_info()
.w()?
.1)
}
#[cfg(not(feature = "cuda"))]
Device::Cuda(_) => {
candle_core::bail!("Cannot get total memory for CUDA device")
}
Device::Metal(_) => {
candle_core::bail!("Cannot get total memory for Metal device")
}
}
}
}
6 changes: 3 additions & 3 deletions mistralrs-pyo3/mistralrs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class Runner:
num_device_layers: list[str] | None = None,
in_situ_quant: str | None = None,
anymoe_config: AnyMoeConfig | None = None,
pa_gpu_mem: int | None = None,
pa_gpu_mem: int | float | None = None,
pa_blk_size: int | None = None,
no_paged_attn: bool = False,
) -> None:
Expand All @@ -211,8 +211,8 @@ class Runner:
the corresponding number of layers.
- `in_situ_quant` sets the optional in-situ quantization for models that are not quantized (not GGUF or GGML).
- `anymoe_config` specifies the AnyMoE config. If this is set, then the model will be loaded as an AnyMoE model.
- `pa_gpu_mem` sets GPU memory to allocate for KV cache with Paged Attention in MBs. If this is not set and the device is
CUDA, it will default to to the available GPU memory. Paged Attention is only supported on CUDA and is always automatically activated.
- `pa_gpu_mem` sets GPU memory to allocate for KV cache with Paged Attention in MBs *OR* the percentage utilization, from 0 to 1. If this is not set and the device is
CUDA, it will default to using 90% of the total memory after allocation of the KV cache. Paged Attention is only supported on CUDA and is always automatically activated.
- `pa_blk_size` sets the block size (number of tokens per block) for Paged Attention. If this is not set and the device is CUDA,
it will default to 32. Paged Attention is only supported on CUDA and is always automatically activated.
- `no_paged_attn` disables Paged Attention on CUDA
Expand Down
10 changes: 6 additions & 4 deletions mistralrs-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ impl Runner {
num_device_layers: Option<Vec<String>>,
in_situ_quant: Option<String>,
anymoe_config: Option<AnyMoeConfig>,
pa_gpu_mem: Option<usize>,
pa_gpu_mem: Option<Either<usize, f32>>,
pa_blk_size: Option<usize>,
no_paged_attn: bool,
) -> PyResult<Self> {
Expand Down Expand Up @@ -466,10 +466,12 @@ impl Runner {
// Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
let cache_config = match (pa_blk_size, pa_gpu_mem, device.is_cuda(), no_paged_attn) {
(block_size, None, true, false) => Some(PagedAttentionConfig::new(
block_size, 512, None, // Autodetermine KV cache size
block_size,
512,
Either::Right(0.9), // NOTE(EricLBuehler): default is to use 90% of memory
)?),
(block_size, Some(gpu_mem), _, false) => {
Some(PagedAttentionConfig::new(block_size, 512, Some(gpu_mem))?)
(block_size, Some(either), true, false) => {
Some(PagedAttentionConfig::new(block_size, 512, either)?)
}
(_, _, _, _) => None,
};
Expand Down
Loading

0 comments on commit ebe032e

Please sign in to comment.