From a3d76e913ebec6bc3d97d9672209af931c3e3249 Mon Sep 17 00:00:00 2001
From: EricLBuehler
Date: Wed, 8 May 2024 05:22:49 -0400
Subject: [PATCH] Warmup pass for mistralrs bench

---
 mistralrs-bench/src/main.rs | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/mistralrs-bench/src/main.rs b/mistralrs-bench/src/main.rs
index ac1d278ae..f70343efe 100644
--- a/mistralrs-bench/src/main.rs
+++ b/mistralrs-bench/src/main.rs
@@ -209,6 +209,53 @@ fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
     print_stdout(table).expect("print table");
 }
 
+/// Runs a single short completion request through the engine before the
+/// benchmarks start, blocking until a response arrives or the response
+/// channel closes.
+///
+/// # Panics
+///
+/// Panics if the request cannot be handed to the engine's request channel.
+fn warmup_run(mistralrs: Arc<MistralRs>) {
+    let sampling_params = SamplingParams {
+        temperature: Some(0.1),
+        top_k: Some(32),
+        top_p: Some(0.1),
+        top_n_logprobs: 0,
+        frequency_penalty: Some(0.1),
+        presence_penalty: Some(0.1),
+        max_len: Some(5),
+        stop_toks: None,
+        logits_bias: None,
+        n_choices: 1,
+    };
+    let sender = mistralrs.get_sender();
+    let (tx, mut rx) = channel(10_000);
+
+    let req = Request {
+        id: mistralrs.next_request_id(),
+        messages: RequestMessage::Completion {
+            text: "Hello!".to_string(),
+            echo_prompt: false,
+            best_of: 1,
+        },
+        sampling_params,
+        response: tx,
+        return_logprobs: false,
+        is_streaming: false,
+        constraint: Constraint::None,
+        suffix: None,
+    };
+
+    // Hand the request over by value so no copy of the response sender `tx` stays
+    // alive in this function: if the engine drops the request unanswered,
+    // `blocking_recv` then returns `None` instead of blocking forever.
+    sender.blocking_send(req).expect("Expected receiver.");
+
+    // The content of the warmup response is not needed; it is only awaited.
+    let _ = rx.blocking_recv();
+}
+
 #[derive(Parser)]
 #[command(version, about, long_about = None)]
 struct Args {
@@ -311,6 +358,11 @@ fn main() -> anyhow::Result<()> {
         .with_disable_eos_stop(true)
         .build();
 
+    info!("Starting warmup run.");
+    warmup_run(mistralrs.clone());
+    info!("Finished warmup run.");
+    info!("Starting benchmarks.");
+
     for concurrency in args.concurrency.as_ref().unwrap() {
         let mut results = vec![];
         if args.n_gen > 0 {