Commit
* add quantized version of qwen2 and corresponding example for qwen2-instruct
* fix quantized qwen2 clippy error
Showing 4 changed files with 641 additions and 0 deletions.
candle-examples/examples/quantized-qwen2-instruct/README.md (11 additions, 0 deletions)
# candle-quantized-qwen2-instruct

[Qwen2](https://qwenlm.github.io/blog/qwen2/) is an upgraded version of Qwen1.5, released by Alibaba Cloud.

## Running the example

```bash
cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N."
```

The 0.5b, 1.5b, 7b and 72b model sizes can be selected via the `--which` argument; `--model` instead accepts a path to a local GGUF file.
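For example, to run the 7b variant with a lower sampling temperature (both flags are defined in `main.rs` below):

```bash
cargo run --example quantized-qwen2-instruct --release -- \
  --which 7b --temperature 0.6 \
  --prompt "Write a function to count prime numbers up to N."
```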
candle-examples/examples/quantized-qwen2-instruct/main.rs (306 additions, 0 deletions)

```rust
#[cfg(feature = "mkl")] | ||
extern crate intel_mkl_src; | ||
|
||
#[cfg(feature = "accelerate")] | ||
extern crate accelerate_src; | ||
|
||
use clap::{Parser, ValueEnum}; | ||
use std::io::Write; | ||
use tokenizers::Tokenizer; | ||
|
||
use candle::quantized::gguf_file; | ||
use candle::Tensor; | ||
use candle_transformers::generation::{LogitsProcessor, Sampling}; | ||
|
||
use candle_examples::token_output_stream::TokenOutputStream; | ||
use candle_transformers::models::quantized_qwen2::ModelWeights as Qwen2; | ||
|
||
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. "; | ||
|
||
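/// Model sizes selectable on the command line; the `value` names below
/// ("0.5b", "1.5b", ...) are what gets passed to `--which`.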
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "0.5b")]
    W2_0_5b,
    #[value(name = "1.5b")]
    W2_1_5b,
    #[value(name = "7b")]
    W2_7b,
    #[value(name = "72b")]
    W2_72b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
    #[arg(long)]
    model: Option<String>,

    /// The initial prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 1000)]
    sample_len: usize,

    /// The tokenizer config in json format.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "0.5b")]
    which: Which,
}

impl Args {
    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = match self.which {
                    Which::W2_0_5b => "Qwen/Qwen2-0.5B-Instruct",
                    Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
                    Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
                    Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
                };
                let api = api.model(repo.to_string());
                api.get("tokenizer.json")?
            }
        };
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }

    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
        let model_path = match &self.model {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let (repo, filename, revision) = match self.which {
                    Which::W2_0_5b => (
                        "Qwen/Qwen2-0.5B-Instruct-GGUF",
                        "qwen2-0_5b-instruct-q4_0.gguf",
                        "main",
                    ),
                    Which::W2_1_5b => (
                        "Qwen/Qwen2-1.5B-Instruct-GGUF",
                        "qwen2-1_5b-instruct-q4_0.gguf",
                        "main",
                    ),
                    Which::W2_7b => (
                        "Qwen/Qwen2-7B-Instruct-GGUF",
                        "qwen2-7b-instruct-q4_0.gguf",
                        "main",
                    ),
                    Which::W2_72b => (
                        "Qwen/Qwen2-72B-Instruct-GGUF",
                        "qwen2-72b-instruct-q4_0.gguf",
                        "main",
                    ),
                };
                let api = hf_hub::api::sync::Api::new()?;
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    revision.to_string(),
                ))
                .get(filename)?
            }
        };
        Ok(model_path)
    }
}

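/// Render a byte count in human-readable form using decimal (SI) units.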
fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

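    // Report which SIMD feature sets this build of candle was compiled with.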
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

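    // Resolve the GGUF file (downloading it from the hub when --model is not
    // set) and pick the compute device.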
    let model_path = args.model()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

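    // Read the GGUF metadata and report the total on-disk tensor size: each
    // tensor holds elem_count / block_size quantized blocks of type_size bytes.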
    let mut model = {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter() {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes +=
                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        println!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
        );
        Qwen2::from_gguf(model, &mut file, &device)?
    };
    println!("model built");

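    // Wrap the raw prompt in Qwen2's ChatML-style instruct template so the
    // model sees a user turn and generates the assistant reply.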
    let tokenizer = args.tokenizer()?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
    let prompt_str = format!(
        "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
        prompt_str
    );
    print!("formatted instruct prompt: {}", &prompt_str);
    let tokens = tos
        .tokenizer()
        .encode(prompt_str, true)
        .map_err(anyhow::Error::msg)?;
    let tokens = tokens.get_ids();
    let to_sample = args.sample_len.saturating_sub(1);
    let mut all_tokens = vec![];
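    // Temperature <= 0 falls back to greedy decoding; otherwise combine the
    // requested top-k and/or top-p (nucleus) sampling strategies.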
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };
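    // Process the prompt either in one batched forward pass, or one token at a
    // time when --split-prompt is set; both end by sampling the first new token.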
    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = if !args.split_prompt {
        let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
        let mut next_token = 0;
        for (pos, token) in tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?
        }
        next_token
    };
    let prompt_dt = start_prompt_processing.elapsed();
    all_tokens.push(next_token);
    if let Some(t) = tos.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }
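    // Qwen2's chat template closes each turn with <|im_end|>, so treat that
    // token as end-of-sequence.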
    let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
    let start_post_prompt = std::time::Instant::now();
    let mut sampled = 0;
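    // Autoregressive generation: feed the previously sampled token back in and
    // continue until <|im_end|> or the requested sample length is reached.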
    for index in 0..to_sample {
        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, tokens.len() + index)?;
        let logits = logits.squeeze(0)?;
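        // Only penalize repetition within the last repeat_last_n tokens.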
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        sampled += 1;
        if next_token == eos_token {
            break;
        }
    }
    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }
    std::io::stdout().flush()?;
    let dt = start_post_prompt.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.2} token/s",
        tokens.len(),
        tokens.len() as f64 / prompt_dt.as_secs_f64(),
    );
    println!(
        "{sampled:4} tokens generated: {:.2} token/s",
        sampled as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
```