Apply clippy (#83)
* Fix some warnings

* Clippy
EricLBuehler authored Aug 20, 2024
1 parent 584f2d4 commit 20703f8
Showing 24 changed files with 222 additions and 280 deletions.
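Most of the diffs below follow common clippy suggestions: implementing std::fmt::Display instead of ToString, replacing is_some()/unwrap() pairs with if let, taking slices rather than &Vec references, and using Option::cloned(). As a rough, simplified sketch of the Display pattern that the src/lib.rs diff adopts (the enum here is illustrative, not the full ModelSelected definition):

use std::fmt::{self, Display};

// Illustrative stand-in for the real ModelSelected enum in src/lib.rs.
enum ModelSelected {
    Llama { repeat_last_n: usize },
    Phi3 { repeat_last_n: usize },
}

// Clippy prefers implementing Display over ToString; ToString then comes
// for free through the blanket `impl<T: Display> ToString for T`.
impl Display for ModelSelected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // `{ .. }` ignores all fields instead of listing each one as `_`.
            ModelSelected::Llama { .. } => write!(f, "llama"),
            ModelSelected::Phi3 { .. } => write!(f, "phi3"),
        }
    }
}

fn main() {
    let m = ModelSelected::Llama { repeat_last_n: 64 };
    assert_eq!(m.to_string(), "llama"); // to_string() is provided via Display
}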
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -59,5 +59,6 @@
"__bit_reference": "cpp",
"__functional_base": "cpp",
"__memory": "cpp"
}
},
"rust-analyzer.cargo.features": ["cuda"],
}
6 changes: 2 additions & 4 deletions kernels/src/lib.rs
@@ -1,6 +1,4 @@
pub const COPY_BLOCKS_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx"));
pub const COPY_BLOCKS_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx"));
pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx"));
pub const RESHAPE_AND_CACHE_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx"));
pub const RESHAPE_AND_CACHE_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx"));
pub mod ffi;
2 changes: 1 addition & 1 deletion src/backend/cache.rs
@@ -142,7 +142,7 @@ pub unsafe fn copy_blocks(
COPY_BLOCKS_KERNEL_NAME,
key_caches.first().unwrap().dtype(),
None,
&dev,
dev,
));

try_api!(unsafe {
4 changes: 2 additions & 2 deletions src/backend/paged_attention.rs
@@ -261,7 +261,7 @@ impl candle::CustomOp1 for PagedAttention {
///
/// * `q` - Query tensor with shape `(num_sequences, num_heads_q, head_size)`.
/// * `key_cache` - Key cache paged tensor of shape `(num_blocks, num_heads_kv, head_size / x, block_size, x)`
/// with `x` being the size of an element in bytes.
/// with `x` being the size of an element in bytes.
/// * `value_cache` - Value cache paged tensor of shape `(num_blocks, num_heads_kv, head_size, block_size)`.
/// * `block_tables` - Padded table associating blocks to each sequence of shape `(num_sequences, max_context_len // block_size)`
/// * `context_lens` - Tensor associating lengths to each sequence of shape `(num_sequences)`
@@ -440,7 +440,7 @@ fn update_cache<
/// * `key` - Key tensor of shape `(num_tokens, num_heads, head_size)`.
/// * `value` - Value tensor of shape `(num_tokens, num_heads, head_size)`.
/// * `key_cache` - Key cache paged tensor of shape `(num_blocks, num_heads, head_size / x, block_size, x)`
/// with `x` being the size of an element in bytes.
/// with `x` being the size of an element in bytes.
/// * `value_cache` - Value cache paged tensor of shape `(num_blocks, num_heads, head_size, block_size)`.
/// * `slot_mapping` - Mapping associating a slot to each token of shape `(num_tokens)`.
pub fn reshape_and_cache(
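The doc comments above describe the paged KV-cache layout that paged_attention and reshape_and_cache operate on. As a loose illustration of how a flat slot index (the kind of value slot_mapping carries) relates to a cache block and an offset within it, assuming the usual paged-attention convention and made-up sizes rather than values taken from this code:

// Sketch only: maps a flat slot index to a (block, offset) pair under the
// common paged-attention convention; sizes are hypothetical.
fn main() {
    let block_size: usize = 16;        // tokens per cache block (assumed)
    let slot: usize = 37;              // one hypothetical slot_mapping entry
    let block_idx = slot / block_size; // which cache block holds the token
    let block_off = slot % block_size; // token's position inside that block
    assert_eq!((block_idx, block_off), (2, 5));
}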
102 changes: 31 additions & 71 deletions src/lib.rs
@@ -1,4 +1,6 @@
#![warn(clippy::cast_lossless)]
use std::fmt::Display;

use candle::Result;
use candle_core as candle;
use clap::Subcommand;
@@ -190,30 +192,12 @@ pub enum ModelSelected {
},
}

impl ToString for ModelSelected {
fn to_string(&self) -> String {
impl Display for ModelSelected {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelSelected::Llama {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
quant: _,
} => "llama".to_string(),
ModelSelected::Llama3 {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
quant: _,
} => "llama3".to_string(),
ModelSelected::Phi2 {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
quant: _,
} => "phi2".to_string(),
ModelSelected::Llama { .. } => write!(f, "llama"),
ModelSelected::Llama3 { .. } => write!(f, "llama3"),
ModelSelected::Phi2 { .. } => write!(f, "phi2"),
ModelSelected::Phi3 {
repeat_last_n: _,
temperature: _,
@@ -222,7 +206,7 @@ impl ToString for ModelSelected {
penalty: _,
max_gen_tokens: _,
quant: _,
} => "phi3".to_string(),
} => write!(f, "phi3"),
ModelSelected::Qwen2 {
repeat_last_n: _,
temperature: _,
@@ -231,35 +215,11 @@ impl ToString for ModelSelected {
penalty: _,
max_gen_tokens: _,
quant: _,
} => "qwen2".to_string(),
ModelSelected::Gemma {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
quant: _,
} => "gemma".to_string(),
ModelSelected::Mistral {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
quant: _,
} => "mistral".to_string(),
ModelSelected::Yi {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
quant: _,
} => "yi".to_string(),
ModelSelected::StableLM {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
quant: _,
} => "stablelm".to_string(),
} => write!(f, "qwen2"),
ModelSelected::Gemma { .. } => write!(f, "gemma"),
ModelSelected::Mistral { .. } => write!(f, "mistral"),
ModelSelected::Yi { .. } => write!(f, "yi"),
ModelSelected::StableLM { .. } => write!(f, "stablelm"),
}
}
}
@@ -321,8 +281,8 @@ pub fn get_model_loader(
),
"llama".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
if let Some(model_id) = model_id {
model_id
} else {
"meta-llama/Llama-2-7b-chat-hf".to_string()
},
@@ -346,8 +306,8 @@ pub fn get_model_loader(
),
"llama3".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
if let Some(model_id) = model_id {
model_id
} else {
"meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()
},
@@ -371,8 +331,8 @@ pub fn get_model_loader(
),
"phi2".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
if let Some(model_id) = model_id {
model_id
} else {
"microsoft/microsoft/phi-2".to_string()
},
@@ -398,8 +358,8 @@ pub fn get_model_loader(
),
"phi3".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
if let Some(model_id) = model_id {
model_id
} else {
"microsoft/Phi-3-mini-4k-instruct".to_string()
},
@@ -425,8 +385,8 @@ pub fn get_model_loader(
),
"qwen2".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
if let Some(model_id) = model_id {
model_id
} else {
"Qwen/Qwen1.5-1.8B-Chat".to_string()
},
@@ -450,8 +410,8 @@ pub fn get_model_loader(
),
"gemma".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
if let Some(model_id) = model_id {
model_id
} else {
"google/gemma-2b-it".to_string()
},
@@ -475,8 +435,8 @@ pub fn get_model_loader(
),
"mistral".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
if let Some(model_id) = model_id {
model_id
} else {
"mistralai/Mistral-7B-Instruct-v0.3".to_string()
},
@@ -501,8 +461,8 @@ pub fn get_model_loader(
),
"yi".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
if let Some(model_id) = model_id {
model_id
} else {
"01-ai/Yi-6B-Chat".to_string()
},
@@ -527,8 +487,8 @@ pub fn get_model_loader(
),
"stablelm".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
if let Some(model_id) = model_id {
model_id
} else {
"stabilityai/stablelm-zephyr-3b".to_string()
},
1 change: 0 additions & 1 deletion src/main.rs
@@ -4,7 +4,6 @@ use axum::{
Router,
};
use candle_core::{DType, Device};
use candle_examples;
use candle_vllm::openai::openai_server::chat_completions;
use candle_vllm::openai::pipelines::llm_engine::LLMEngine;
use candle_vllm::openai::pipelines::pipeline::DefaultModelPaths;
13 changes: 6 additions & 7 deletions src/openai/conversation/default_conversation.rs
@@ -235,7 +235,7 @@ impl Conversation for DefaultConversation {
if let Some(message) = message {
accum += &format!("[INST] {message} [/INST]");
} else {
accum += &format!("[INST] [/INST]");
accum += "[INST] [/INST]";
}
} else if _role.clone() == self.roles.1 {
//assistant message
@@ -260,8 +260,7 @@ impl Conversation for DefaultConversation {
"<|start_header_id|>user<|end_header_id|>\n\n {message} <|eot_id|>"
);
} else {
accum +=
&format!("<|start_header_id|>user<|end_header_id|>\n\n <|eot_id|>");
accum += "<|start_header_id|>user<|end_header_id|>\n\n <|eot_id|>";
}
} else if _role.clone() == self.roles.1 {
//assistant message
@@ -284,7 +283,7 @@ impl Conversation for DefaultConversation {
if let Some(message) = message {
accum += &format!("<|user|> {message}<|end|>");
} else {
accum += &format!("<|user|> <|end|");
accum += "<|user|> <|end|";
}
} else if _role.clone() == self.roles.1 {
//assistant message
@@ -307,7 +306,7 @@ impl Conversation for DefaultConversation {
if let Some(message) = message {
accum += &format!("<|im_start|>user\n {message} <|im_end|>");
} else {
accum += &format!("<|im_start|> <|im_end|>");
accum += "<|im_start|> <|im_end|>";
}
} else if _role.clone() == self.roles.1 {
//assistant message
@@ -323,7 +322,7 @@ impl Conversation for DefaultConversation {

SeparatorStyle::Gemma => {
let mut accum = "".to_string();
for (_, message) in self.messages.iter().enumerate() {
for message in self.messages.iter() {
let Message((_role, message)) = message;
if let Some(message) = message {
accum +=
@@ -345,7 +344,7 @@ impl Conversation for DefaultConversation {
if let Some(message) = message {
accum += &format!("<|user|>user\n {message}<|endoftext|>");
} else {
accum += &format!("<|user|> <|endoftext|>");
accum += "<|user|> <|endoftext|>";
}
} else if _role.clone() == self.roles.1 {
//assistant message
12 changes: 6 additions & 6 deletions src/openai/models/gemma.rs
@@ -61,7 +61,7 @@ impl GemmaConfig {
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings.unwrap_or(4096),
sliding_window: None,
hidden_act: hidden_act,
hidden_act,
tie_word_embeddings: false,
rope_scaling: None,
original_max_position_embeddings: None,
@@ -111,7 +111,7 @@ impl RotaryEmbedding {
&self,
q: &Tensor,
k: &Tensor,
input_positions: &Vec<Vec<usize>>,
input_positions: &[Vec<usize>],
) -> Result<(Tensor, Tensor)> {
let (b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let mut q_embeds = Vec::new();
@@ -255,7 +255,7 @@ impl Attention {
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
input_positions: &Vec<Vec<usize>>,
input_positions: &[Vec<usize>],
cache: Option<(&Tensor, &Tensor)>,
input_metadata: &mut InputMetadata,
) -> Result<Tensor> {
@@ -350,7 +350,7 @@ impl DecoderLayer {
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
input_positions: &Vec<Vec<usize>>,
input_positions: &[Vec<usize>],
cache: Option<(&Tensor, &Tensor)>,
input_metadata: &mut InputMetadata,
) -> Result<Tensor> {
@@ -401,7 +401,7 @@ impl Gemma {
norm,
lm_head,
device: device.clone(),
dtype: dtype,
dtype,
hidden_size: cfg.hidden_size,
cfg: cfg.clone(),
})
@@ -419,7 +419,7 @@ impl Gemma {
pub fn forward(
&mut self,
input_ids: &Tensor,
input_positions: &Vec<Vec<usize>>,
input_positions: &[Vec<usize>],
kv_caches: Option<&Vec<(Tensor, Tensor)>>,
input_metadata: &mut InputMetadata,
) -> Result<Tensor> {
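Several signatures in the gemma.rs hunks above change &Vec<Vec<usize>> parameters to &[Vec<usize>], which is what clippy's ptr_arg lint asks for: a slice borrow accepts a &Vec as well as other contiguous borrows without tying the callee to a concrete container. A minimal sketch of the pattern (the helper function is hypothetical, not part of candle-vllm):

// Hypothetical helper, not actual model code. Taking &[Vec<usize>] instead of
// &Vec<Vec<usize>> (clippy::ptr_arg) keeps the signature container-agnostic.
fn total_positions(input_positions: &[Vec<usize>]) -> usize {
    input_positions.iter().map(|seq| seq.len()).sum()
}

fn main() {
    let positions: Vec<Vec<usize>> = vec![vec![0, 1, 2], vec![0, 1]];
    // A &Vec<Vec<usize>> coerces to &[Vec<usize>] at the call site.
    assert_eq!(total_positions(&positions), 5);
}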
7 changes: 2 additions & 5 deletions src/openai/models/linear.rs
@@ -202,7 +202,7 @@ impl QLinear {
Self {
inner: QMatMul::QTensor(Arc::new(w)),
bias: bx,
dtype: dtype,
dtype,
}
}

@@ -225,10 +225,7 @@ impl QLinear {
_ => panic!("Unsupported GGML data type!"),
};
let qtensor = QTensor::quantize(weight, ggml_dtype).unwrap();
let qbias = match linear.bias() {
Some(b) => Some(b.clone()),
_ => None,
};
let qbias = linear.bias().cloned();

QLinear::from_qparts_x(qtensor, qbias, dtype)
}