Add server throughput logging #608

Merged · 8 commits · Jul 23, 2024
68 changes: 68 additions & 0 deletions mistralrs-core/src/engine/mod.rs
@@ -49,6 +49,7 @@ pub struct Engine {
prefix_cacher: PrefixCacheManager,
is_debug: bool,
disable_eos_stop: bool,
throughput_logging_enabled: bool,
}

impl Engine {
@@ -84,9 +85,16 @@ impl Engine {
),
is_debug: DEBUG.load(Ordering::Relaxed),
disable_eos_stop,
throughput_logging_enabled: false,
}
}

// TODO(EricLBuehler): On v0.3.0 move this into the Engine constructor
/// Enable throughput logging.
pub fn enable_throughput_logging(&mut self) {
self.throughput_logging_enabled = true;
}

pub async fn run(&mut self) {
let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
let mut last_completion_ids: Vec<usize> = vec![];
@@ -101,7 +109,10 @@ impl Engine {
SchedulerOutput::DefaultScheduler {
output: mut scheduled,
} => {
let mut prompt_ts = None;
let mut completion_ts = None;
if scheduled.completion.len() > 0 {
let throughput_start = Instant::now();
let current_completion_ids: Vec<usize> =
scheduled.completion.iter().map(|seq| *seq.id()).collect();
let res = {
@@ -152,11 +163,22 @@ impl Engine {
'lp,
self.prefix_cacher
);
let throughput_end = Instant::now();
#[allow(clippy::cast_precision_loss)]
if self.throughput_logging_enabled {
completion_ts = Some(
scheduled.completion.len() as f64
/ throughput_end
.duration_since(throughput_start)
.as_secs_f64(),
);
}

last_completion_ids = current_completion_ids;
}

if scheduled.prompt.len() > 0 {
let throughput_start = Instant::now();
let logits = {
let mut pipeline = get_mut_arcmutex!(self.pipeline);

@@ -202,6 +224,20 @@ impl Engine {
'lp,
self.prefix_cacher
);
let throughput_end = Instant::now();
#[allow(clippy::cast_precision_loss)]
if self.throughput_logging_enabled {
prompt_ts = Some(
scheduled
.prompt
.iter()
.map(|seq| seq.get_toks().len())
.sum::<usize>() as f64
/ throughput_end
.duration_since(throughput_start)
.as_secs_f64(),
);
}

for seq in scheduled.prompt.iter_mut() {
seq.set_state(SequenceState::RunningCompletion);
@@ -245,6 +281,21 @@ impl Engine {
}
}

if self.throughput_logging_enabled {
match (prompt_ts, completion_ts) {
(Some(prompt), Some(completion)) => {
info!("Throughput (scheduler V1): Prompt: {prompt} T/s Completion {completion} T/s");
}
(None, Some(completion)) => {
info!("Throughput (scheduler V1): Completion {completion} T/s");
}
(Some(prompt), None) => {
info!("Throughput (scheduler V1): Prompt: {prompt} T/s");
}
(None, None) => (),
}
}

if scheduled.prompt.len() == 0
&& scheduled.completion.len() == 0
&& self.scheduler.waiting_len() == 0
@@ -257,6 +308,8 @@
}
SchedulerOutput::PagedAttention { mut output } => {
if !output.scheduled.is_empty() {
let throughput_start = Instant::now();

let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

let mut guards = output
@@ -330,6 +383,21 @@ impl Engine {
}
}

let throughput_end = Instant::now();
#[allow(clippy::cast_precision_loss)]
if self.throughput_logging_enabled {
let n_toks = if is_prompt {
guards.iter().map(|seq| seq.get_toks().len()).sum::<usize>()
} else {
guards.len()
};
let ts = n_toks as f64
/ throughput_end
.duration_since(throughput_start)
.as_secs_f64();
info!("Throughput (scheduler V2): {ts} T/s");
}

if is_prompt {
for mut seq in guards {
let now = SystemTime::now()
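The logging added above is a straightforward measurement: count the tokens handled in one scheduler step and divide by that step's wall-clock duration taken with `std::time::Instant`. Below is a minimal, self-contained sketch of the same calculation; the token counts are made up for illustration and the helper is not part of the PR.

```rust
use std::time::{Duration, Instant};

/// Tokens per second for one engine step: total tokens processed in the step
/// divided by the step's wall-clock duration.
fn tokens_per_second(per_seq_tokens: &[usize], start: Instant, end: Instant) -> f64 {
    let n_toks: usize = per_seq_tokens.iter().sum();
    #[allow(clippy::cast_precision_loss)]
    let ts = n_toks as f64 / end.duration_since(start).as_secs_f64();
    ts
}

fn main() {
    let start = Instant::now();
    // A real step would run the pipeline here; the sleep only keeps the
    // example's elapsed time non-trivial.
    std::thread::sleep(Duration::from_millis(50));
    let end = Instant::now();

    // Prompt batches sum the full prompt lengths; completion batches contribute
    // one new token per scheduled sequence per step.
    let prompt_lengths = vec![128, 64, 256];
    println!("Prompt: {} T/s", tokens_per_second(&prompt_lengths, start, end));
}
```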
15 changes: 15 additions & 0 deletions mistralrs-core/src/lib.rs
@@ -109,6 +109,7 @@ struct RebootState {
no_prefix_cache: bool,
prefix_cache_n: usize,
disable_eos_stop: bool,
throughput_logging_enabled: bool,
}

#[derive(Debug)]
@@ -145,6 +146,7 @@ pub struct MistralRsBuilder {
prefix_cache_n: Option<usize>,
disable_eos_stop: Option<bool>,
gemm_full_precision_f16: Option<bool>,
throughput_logging_enabled: Option<()>,
}

impl MistralRsBuilder {
@@ -159,6 +161,7 @@
prefix_cache_n: None,
disable_eos_stop: None,
gemm_full_precision_f16: None,
throughput_logging_enabled: None,
}
}
pub fn with_log(mut self, log: String) -> Self {
@@ -193,6 +196,10 @@ impl MistralRsBuilder {
self.gemm_full_precision_f16 = Some(gemm_full_precision);
self
}
pub fn with_throughput_logging(mut self) -> Self {
self.throughput_logging_enabled = Some(());
self
}

pub fn build(self) -> Arc<MistralRs> {
MistralRs::new(self)
@@ -248,6 +255,7 @@ impl MistralRs {
prefix_cache_n,
disable_eos_stop,
gemm_full_precision_f16,
throughput_logging_enabled,
} = config;

let model_supports_reduced_gemm = match pipeline.try_lock().unwrap().category() {
@@ -273,6 +281,7 @@
no_prefix_cache,
prefix_cache_n,
disable_eos_stop,
throughput_logging_enabled: throughput_logging_enabled.is_some(),
};

let (tx, rx) = channel(10_000);
@@ -293,6 +302,9 @@
prefix_cache_n,
disable_eos_stop,
);
if throughput_logging_enabled.is_some() {
engine.enable_throughput_logging();
}
engine.run().await;
});
});
@@ -343,6 +355,9 @@ impl MistralRs {
reboot_state.prefix_cache_n,
reboot_state.disable_eos_stop,
);
if reboot_state.throughput_logging_enabled {
engine.enable_throughput_logging();
}
engine.run().await;
});
});
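The builder records the choice as an `Option<()>`, converts it to a bool for `RebootState`, and both the initial engine spawn and the reboot path then call `enable_throughput_logging()`. The sketch below traces that flow end to end; the types are simplified stand-ins for illustration, not the actual mistral.rs structs.

```rust
/// Simplified stand-ins for the RebootState / Engine types in the diff above.
struct RebootState {
    throughput_logging_enabled: bool,
}

struct Engine {
    throughput_logging_enabled: bool,
}

impl Engine {
    fn new() -> Self {
        Engine { throughput_logging_enabled: false }
    }

    /// Mirrors Engine::enable_throughput_logging from the engine diff.
    fn enable_throughput_logging(&mut self) {
        self.throughput_logging_enabled = true;
    }
}

/// Both the initial spawn and the reboot path apply the flag the same way.
fn spawn_engine(reboot_state: &RebootState) -> Engine {
    let mut engine = Engine::new();
    if reboot_state.throughput_logging_enabled {
        engine.enable_throughput_logging();
    }
    engine
}

fn main() {
    // with_throughput_logging() stores Some(()); builders that never call it keep None.
    let throughput_logging_enabled: Option<()> = Some(());
    let reboot_state = RebootState {
        throughput_logging_enabled: throughput_logging_enabled.is_some(),
    };
    let engine = spawn_engine(&reboot_state);
    assert!(engine.throughput_logging_enabled);
}
```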
21 changes: 16 additions & 5 deletions mistralrs-server/src/main.rs
Expand Up @@ -142,6 +142,10 @@ struct Args {
/// Disable PagedAttention on CUDA.
#[arg(long = "no-paged-attn", default_value_t = false)]
no_paged_attn: bool,

/// Enable server throughput logging when not using interactive mode
#[arg(long = "throughput", default_value_t = false)]
throughput_log: bool,
}

#[utoipa::path(
@@ -403,23 +407,30 @@ async fn main() -> Result<()> {
method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
}
};
- let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config)
+ // Throughput logging in the server
+ let builder = MistralRsBuilder::new(pipeline, scheduler_config)
.with_opt_log(args.log)
.with_truncate_sequence(args.truncate_sequence)
.with_no_kv_cache(args.no_kv_cache)
- .with_prefix_cache_n(args.prefix_cache_n)
- .build();
+ .with_prefix_cache_n(args.prefix_cache_n);

if args.interactive_mode && args.vision_interactive_mode {
anyhow::bail!("Interactive mode and vision interactive mode are exclusive.");
} else if args.interactive_mode {
- interactive_mode(mistralrs, false).await;
+ interactive_mode(builder.build(), false).await;
return Ok(());
} else if args.vision_interactive_mode {
- interactive_mode(mistralrs, true).await;
+ interactive_mode(builder.build(), true).await;
return Ok(());
}

let builder = if args.throughput_log {
builder.with_throughput_logging()
} else {
builder
};
let mistralrs = builder.build();

let port = args.port.expect("Expected port to be specified.");

let app = get_router(mistralrs);
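Deferring `build()` until after the interactive-mode branches is what lets the `--throughput` flag be applied conditionally for server mode only. A minimal sketch of that conditional-builder pattern, using placeholder types rather than the real MistralRsBuilder API:

```rust
/// Placeholder config/builder pair illustrating the pattern; the real code uses
/// MistralRsBuilder and args.throughput_log.
struct Config {
    throughput_logging: bool,
}

struct Builder {
    cfg: Config,
}

impl Builder {
    fn new() -> Self {
        Builder { cfg: Config { throughput_logging: false } }
    }

    fn with_throughput_logging(mut self) -> Self {
        self.cfg.throughput_logging = true;
        self
    }

    fn build(self) -> Config {
        self.cfg
    }
}

fn main() {
    let throughput_log = true; // stands in for the --throughput CLI flag

    let builder = Builder::new();
    // Only server mode reaches this point; interactive mode calls build() earlier
    // and never applies the flag, matching the flag's doc comment.
    let builder = if throughput_log {
        builder.with_throughput_logging()
    } else {
        builder
    };

    let cfg = builder.build();
    assert!(cfg.throughput_logging);
}
```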