From 041887942e8339af05a3e9aac3b1d02f7e31d546 Mon Sep 17 00:00:00 2001
From: Jiayu Liu
Date: Wed, 24 Jan 2024 13:30:16 +0800
Subject: [PATCH] add segformer

---
 candle-examples/examples/segformer/README.md |   9 +
 candle-examples/examples/segformer/main.rs   |  44 ++
 candle-transformers/src/models/mod.rs        |   1 +
 candle-transformers/src/models/segformer.rs  | 617 +++++++++++++++++++
 4 files changed, 671 insertions(+)
 create mode 100644 candle-examples/examples/segformer/README.md
 create mode 100644 candle-examples/examples/segformer/main.rs
 create mode 100644 candle-transformers/src/models/segformer.rs

diff --git a/candle-examples/examples/segformer/README.md b/candle-examples/examples/segformer/README.md
new file mode 100644
index 0000000000..c60684784f
--- /dev/null
+++ b/candle-examples/examples/segformer/README.md
@@ -0,0 +1,9 @@
+# candle-segformer
+
+- [HuggingFace Segformer Model Card][segformer]
+- [`mit-b0` - An encoder only pretrained model][encoder]
+- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]
+
+[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
+[encoder]: https://huggingface.co/nvidia/mit-b0
+[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
diff --git a/candle-examples/examples/segformer/main.rs b/candle-examples/examples/segformer/main.rs
new file mode 100644
index 0000000000..68ed65c37e
--- /dev/null
+++ b/candle-examples/examples/segformer/main.rs
@@ -0,0 +1,44 @@
+use std::path::PathBuf;
+
+use candle::Module;
+use candle_nn::VarBuilder;
+use candle_transformers::models::segformer;
+use clap::Parser;
+
+#[derive(Parser)]
+#[clap(about, version, long_about = None)]
+struct Args {
+    #[arg(
+        long,
+        help = "name of the huggingface hub model",
+        default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
+    )]
+    model_name: String,
+    #[arg(long, help = "path to image as input")]
+    image: PathBuf,
+    #[arg(long, help = "use cpu")]
+    cpu: bool,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    let device = candle_examples::device(args.cpu)?;
+    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
+    println!("loaded image {image:?}");
+    let api = hf_hub::api::sync::Api::new()?;
+    let api = api.model(args.model_name);
+    let model_file = api.get("model.safetensors")?;
+    println!("model downloaded and loaded");
+    let vb =
+        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, &device)? };
+    let config = Default::default();
+    let num_labels = 150;
+    let model = segformer::SemanticSegmentationModel::new(&config, num_labels, vb)?;
+    let input = image.unsqueeze(0)?;
+    let segmentations = model.forward(&input)?;
+    println!(
+        "segmentation result shape {:?} which should match [1, num_labels, height/4, width/4]",
+        segmentations.shape()
+    );
+    Ok(())
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a94fd07a06..0b091f134b 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -30,6 +30,7 @@ pub mod quantized_stable_lm;
 pub mod quantized_t5;
 pub mod repvgg;
 pub mod resnet;
+pub mod segformer;
 pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs
new file mode 100644
index 0000000000..5ebb68ae1d
--- /dev/null
+++ b/candle-transformers/src/models/segformer.rs
@@ -0,0 +1,617 @@
+use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
+use candle::{Module, ModuleT, Result, Tensor, D};
+use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
+
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
+#[derive(Debug, Clone)]
+pub struct Config {
+    pub num_channels: usize,
+    pub num_encoder_blocks: usize,
+    pub depths: Vec<usize>,
+    pub sr_ratios: Vec<usize>,
+    pub hidden_sizes: Vec<usize>,
+    pub patch_sizes: Vec<usize>,
+    pub strides: Vec<usize>,
+    pub num_attention_heads: Vec<usize>,
+    pub mlp_ratios: Vec<usize>,
+    pub hidden_act: candle_nn::Activation,
+    pub layer_norm_eps: f64,
+    pub decoder_hidden_size: usize,
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            num_channels: 3,
+            num_encoder_blocks: 4,
+            depths: vec![2, 2, 2, 2],
+            sr_ratios: vec![8, 4, 2, 1],
+            hidden_sizes: vec![32, 64, 160, 256],
+            patch_sizes: vec![7, 3, 3, 3],
+            strides: vec![4, 2, 2, 2],
+            num_attention_heads: vec![1, 2, 5, 8],
+            mlp_ratios: vec![4, 4, 4, 4],
+            hidden_act: candle_nn::Activation::Gelu,
+            layer_norm_eps: 1e-6,
+            decoder_hidden_size: 256,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerOverlapPatchEmbeddings {
+    projection: Conv2d,
+    layer_norm: candle_nn::LayerNorm,
+}
+
+impl SegformerOverlapPatchEmbeddings {
+    fn new(
+        config: &Config,
+        patch_size: usize,
+        stride: usize,
+        num_channels: usize,
+        hidden_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let projection = conv2d(
+            num_channels,
+            hidden_size,
+            patch_size,
+            Conv2dConfig {
+                stride,
+                padding: patch_size / 2,
+                ..Default::default()
+            },
+            vb.pp("proj"),
+        )?;
+        let layer_norm =
+            candle_nn::layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm"))?;
+        Ok(Self {
+            projection,
+            layer_norm,
+        })
+    }
+}
+
+impl Module for SegformerOverlapPatchEmbeddings {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let embeddings = self.projection.forward(x)?;
+        let shape = embeddings.shape();
+        // [B, C, H, W] -> [B, H * W, C]
+        let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        // [B, H * W, C] -> [B, C, H, W]
+        let embeddings = embeddings.transpose(1, 2)?.reshape(shape)?;
+        Ok(embeddings)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerEfficientSelfAttention {
+    num_attention_heads: usize,
+    attention_head_size: usize,
+    query: Linear,
+    key: Linear,
+    value: Linear,
+    sr: Option<Conv2d>,
+    layer_norm: Option<candle_nn::LayerNorm>,
+}
+
+impl SegformerEfficientSelfAttention {
+    fn new(
+        config: &Config,
+        hidden_size: usize,
+        num_attention_heads: usize,
+        sequence_reduction_ratio: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        if hidden_size % num_attention_heads != 0 {
+            candle::bail!(
+                "The hidden size {} is not a multiple of the number of attention heads {}",
+                hidden_size,
+                num_attention_heads
+            )
+        }
+        let attention_head_size = hidden_size.div_ceil(num_attention_heads);
+        let all_head_size = num_attention_heads * attention_head_size;
+        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
+        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
+        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
+        let (sr, layer_norm) = if sequence_reduction_ratio > 1 {
+            let kernel_size = sequence_reduction_ratio;
+            (
+                Some(conv2d(
+                    hidden_size,
+                    hidden_size,
+                    kernel_size,
+                    Conv2dConfig {
+                        stride: sequence_reduction_ratio,
+                        ..Default::default()
+                    },
+                    vb.pp("sr"),
+                )?),
+                Some(candle_nn::layer_norm(
+                    hidden_size,
+                    config.layer_norm_eps,
+                    vb.pp("layer_norm"),
+                )?),
+            )
+        } else {
+            (None, None)
+        };
+        Ok(Self {
+            num_attention_heads,
+            attention_head_size,
+            query,
+            key,
+            value,
+            sr,
+            layer_norm,
+        })
+    }
+
+    fn transpose_for_scores(&self, hidden_states: Tensor) -> Result<Tensor> {
+        let (batch, seq_length, _) = hidden_states.shape().dims3()?;
+        let new_shape = &[
+            batch,
+            seq_length,
+            self.num_attention_heads,
+            self.attention_head_size,
+        ];
+        let hidden_states = hidden_states.reshape(new_shape)?;
+        let hidden_states = hidden_states.permute((0, 2, 1, 3))?;
+        Ok(hidden_states)
+    }
+}
+
+impl Module for SegformerEfficientSelfAttention {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
+        let query = self
+            .transpose_for_scores(self.query.forward(&hidden_states)?)?
+            .contiguous()?;
+        let hidden_states = if let (Some(sr), Some(layer_norm)) = (&self.sr, &self.layer_norm) {
+            let hidden_states = sr.forward(x)?;
+            let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
+            layer_norm.forward(&hidden_states)?
+        } else {
+            hidden_states
+        };
+        // standard self-attention
+        let key = self
+            .transpose_for_scores(self.key.forward(&hidden_states)?)?
+            .contiguous()?;
+        let value = self
+            .transpose_for_scores(self.value.forward(&hidden_states)?)?
+            .contiguous()?;
+        let attention_scores =
+            (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
+        let attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+        let result = attention_scores.matmul(&value)?;
+        result
+            .permute((0, 2, 1, 3))?
+            .contiguous()?
+            .flatten_from(D::Minus2)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerSelfOutput {
+    dense: Linear,
+}
+
+impl SegformerSelfOutput {
+    fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
+        Ok(Self { dense })
+    }
+}
+
+impl Module for SegformerSelfOutput {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        self.dense.forward(x)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerAttention {
+    attention: SegformerEfficientSelfAttention,
+    output: SegformerSelfOutput,
+}
+
+impl SegformerAttention {
+    fn new(
+        config: &Config,
+        hidden_size: usize,
+        num_attention_heads: usize,
+        sequence_reduction_ratio: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let attention = SegformerEfficientSelfAttention::new(
+            config,
+            hidden_size,
+            num_attention_heads,
+            sequence_reduction_ratio,
+            vb.pp("self"),
+        )?;
+        let output = SegformerSelfOutput::new(hidden_size, vb.pp("output"))?;
+        Ok(Self { attention, output })
+    }
+}
+
+impl Module for SegformerAttention {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let attention_output = self.attention.forward(x)?;
+        self.output.forward(&attention_output)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerDWConv {
+    dw_conv: Conv2d,
+}
+
+impl SegformerDWConv {
+    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+        let dw_conv = conv2d(
+            dim,
+            dim,
+            3,
+            Conv2dConfig {
+                stride: 1,
+                padding: 1,
+                groups: dim,
+                ..Default::default()
+            },
+            vb.pp("dwconv"),
+        )?;
+        Ok(Self { dw_conv })
+    }
+}
+
+impl Module for SegformerDWConv {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        self.dw_conv.forward(x)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerMixFFN {
+    dense1: Linear,
+    dw_conv: SegformerDWConv,
+    act: Activation,
+    dense2: Linear,
+}
+
+impl SegformerMixFFN {
+    fn new(
+        config: &Config,
+        in_features: usize,
+        hidden_features: usize,
+        out_features: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let dense1 = linear(in_features, hidden_features, vb.pp("dense1"))?;
+        let dw_conv = SegformerDWConv::new(hidden_features, vb.pp("dwconv"))?;
+        let act = config.hidden_act.into();
+        let dense2 = linear(hidden_features, out_features, vb.pp("dense2"))?;
+        Ok(Self {
+            dense1,
+            dw_conv,
+            act,
+            dense2,
+        })
+    }
+}
+
+impl Module for SegformerMixFFN {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let (batch, _, height, width) = x.shape().dims4()?;
+        let hidden_states = self
+            .dense1
+            .forward(&x.flatten_from(2)?.permute((0, 2, 1))?)?;
+        let channels = hidden_states.dim(2)?;
+        let hidden_states = self.dw_conv.forward(
+            &hidden_states
+                .permute((0, 2, 1))?
+                .reshape((batch, channels, height, width))?,
+        )?;
+        let hidden_states = self.act.forward(&hidden_states)?;
+        let hidden_states = self
+            .dense2
+            .forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
+        let channels = hidden_states.dim(2)?;
+        Ok(hidden_states
+            .permute((0, 2, 1))?
+            .reshape((batch, channels, height, width))?)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerLayer {
+    layer_norm_1: candle_nn::LayerNorm,
+    attention: SegformerAttention,
+    layer_norm_2: candle_nn::LayerNorm,
+    mlp: SegformerMixFFN,
+}
+
+impl SegformerLayer {
+    fn new(
+        config: &Config,
+        hidden_size: usize,
+        num_attention_heads: usize,
+        sequence_reduction_ratio: usize,
+        mlp_ratio: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let layer_norm_1 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_1"))?;
+        let attention = SegformerAttention::new(
+            config,
+            hidden_size,
+            num_attention_heads,
+            sequence_reduction_ratio,
+            vb.pp("attention"),
+        )?;
+        let layer_norm_2 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_2"))?;
+        let mlp = SegformerMixFFN::new(
+            config,
+            hidden_size,
+            hidden_size * mlp_ratio,
+            hidden_size,
+            vb.pp("mlp"),
+        )?;
+        Ok(Self {
+            layer_norm_1,
+            attention,
+            layer_norm_2,
+            mlp,
+        })
+    }
+}
+
+impl Module for SegformerLayer {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let (batch, channels, height, width) = x.shape().dims4()?;
+        let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
+        let hidden_states = self.layer_norm_1.forward(&hidden_states)?;
+        let attention_output = self.attention.forward(
+            &hidden_states
+                .permute((0, 2, 1))?
+                .reshape(&[batch, channels, height, width])?,
+        )?;
+        let hidden_states = (hidden_states + attention_output)?;
+        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
+        let hidden_states = hidden_states
+            .permute((0, 2, 1))?
+            .reshape((batch, channels, height, width))?;
+        let mlp_output = self.mlp.forward(&hidden_states)?;
+        hidden_states + mlp_output
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerEncoder {
+    /// config file
+    config: Config,
+    /// a list of embeddings
+    patch_embeddings: Vec<SegformerOverlapPatchEmbeddings>,
+    /// a list of attention blocks, each consisting of layers
+    blocks: Vec<Vec<SegformerLayer>>,
+    /// a final list of layer norms
+    layer_norms: Vec<candle_nn::LayerNorm>,
+}
+
+impl SegformerEncoder {
+    fn new(config: Config, vb: VarBuilder) -> Result<Self> {
+        let mut patch_embeddings = Vec::with_capacity(config.num_encoder_blocks);
+        let mut blocks = Vec::with_capacity(config.num_encoder_blocks);
+        let mut layer_norms = Vec::with_capacity(config.num_encoder_blocks);
+        for i in 0..config.num_encoder_blocks {
+            let patch_size = config.patch_sizes[i];
+            let stride = config.strides[i];
+            let hidden_size = config.hidden_sizes[i];
+            let num_channels = if i == 0 {
+                config.num_channels
+            } else {
+                config.hidden_sizes[i - 1]
+            };
+            patch_embeddings.push(SegformerOverlapPatchEmbeddings::new(
+                &config,
+                patch_size,
+                stride,
+                num_channels,
+                hidden_size,
+                vb.pp(&format!("patch_embeddings.{}", i)),
+            )?);
+            let mut layers = Vec::with_capacity(config.depths[i]);
+            for j in 0..config.depths[i] {
+                let sequence_reduction_ratio = config.sr_ratios[i];
+                let num_attention_heads = config.num_attention_heads[i];
+                let mlp_ratio = config.mlp_ratios[i];
+                layers.push(SegformerLayer::new(
+                    &config,
+                    hidden_size,
+                    num_attention_heads,
+                    sequence_reduction_ratio,
+                    mlp_ratio,
+                    vb.pp(&format!("block.{}.{}", i, j)),
+                )?);
+            }
+            blocks.push(layers);
+            layer_norms.push(layer_norm(
+                hidden_size,
+                config.layer_norm_eps,
+                vb.pp(&format!("layer_norm.{}", i)),
+            )?);
+        }
+        Ok(Self {
+            config,
+            patch_embeddings,
+            blocks,
+            layer_norms,
+        })
+    }
+}
+
+impl ModuleWithHiddenStates for SegformerEncoder {
+    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
+        let mut all_hidden_states = Vec::with_capacity(self.config.num_encoder_blocks);
+        let mut hidden_states = x.clone();
+        for i in 0..self.config.num_encoder_blocks {
+            hidden_states = self.patch_embeddings[i].forward(&hidden_states)?;
+            for layer in &self.blocks[i] {
+                hidden_states = layer.forward(&hidden_states)?;
+            }
+            let (batch, channels, height, width) = hidden_states.shape().dims4()?;
+            hidden_states =
+                self.layer_norms[i].forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
+            hidden_states = hidden_states
+                .permute((0, 2, 1))?
+                .reshape((batch, channels, height, width))?;
+            all_hidden_states.push(hidden_states.clone());
+        }
+        Ok(all_hidden_states)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerModel {
+    encoder: SegformerEncoder,
+}
+
+impl SegformerModel {
+    fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
+        let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
+        Ok(Self { encoder })
+    }
+}
+
+impl ModuleWithHiddenStates for SegformerModel {
+    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
+        self.encoder.forward(x)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerMLP {
+    proj: Linear,
+}
+
+impl SegformerMLP {
+    fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
+        let proj = linear(input_dim, config.decoder_hidden_size, vb.pp("proj"))?;
+        Ok(Self { proj })
+    }
+}
+
+impl Module for SegformerMLP {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        self.proj.forward(x)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerDecodeHead {
+    linear_c: Vec<SegformerMLP>,
+    linear_fuse: candle_nn::Conv2d,
+    batch_norm: candle_nn::BatchNorm,
+    classifier: candle_nn::Conv2d,
+}
+
+impl SegformerDecodeHead {
+    fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
+        let mut linear_c = Vec::with_capacity(config.num_encoder_blocks);
+        for i in 0..config.num_encoder_blocks {
+            let hidden_size = config.hidden_sizes[i];
+            linear_c.push(SegformerMLP::new(
+                config,
+                hidden_size,
+                vb.pp(&format!("linear_c.{}", i)),
+            )?);
+        }
+        let linear_fuse = conv2d_no_bias(
+            config.decoder_hidden_size * config.num_encoder_blocks,
+            config.decoder_hidden_size,
+            1,
+            Conv2dConfig::default(),
+            vb.pp("linear_fuse"),
+        )?;
+        let batch_norm = candle_nn::batch_norm(
+            config.decoder_hidden_size,
+            config.layer_norm_eps,
+            vb.pp("batch_norm"),
+        )?;
+        let classifier = conv2d_no_bias(
+            config.decoder_hidden_size,
+            num_labels,
+            1,
+            Conv2dConfig::default(),
+            vb.pp("classifier"),
+        )?;
+        Ok(Self {
+            linear_c,
+            linear_fuse,
+            batch_norm,
+            classifier,
+        })
+    }
+
+    fn forward(&self, encoder_hidden_states: &Vec<Tensor>) -> Result<Tensor> {
+        if encoder_hidden_states.len() != self.linear_c.len() {
+            candle::bail!(
+                "The number of encoder hidden states {} is not equal to the number of linear layers {}",
+                encoder_hidden_states.len(),
+                self.linear_c.len()
+            )
+        }
+        // most fine layer
+        let (_, _, upsample_height, upsample_width) = encoder_hidden_states[0].shape().dims4()?;
+        let mut hidden_states = Vec::with_capacity(self.linear_c.len());
+        for (hidden_state, mlp) in encoder_hidden_states.iter().zip(&self.linear_c) {
+            let (batch, _, height, width) = hidden_state.shape().dims4()?;
+            let hidden_state = mlp.forward(&hidden_state.flatten_from(2)?.permute((0, 2, 1))?)?;
+            let hidden_state = hidden_state.permute((0, 2, 1))?.reshape((
+                batch,
+                hidden_state.dim(2)?,
+                height,
+                width,
+            ))?;
+            let hidden_state = hidden_state.upsample_nearest2d(upsample_height, upsample_width)?;
+            hidden_states.push(hidden_state);
+        }
+        let hidden_states = Tensor::cat(&hidden_states, 1)?;
+        let hidden_states = self.linear_fuse.forward(&hidden_states)?;
+        let hidden_states = self.batch_norm.forward_t(&hidden_states, false)?;
+        let hidden_states = hidden_states.relu()?;
+        self.classifier.forward(&hidden_states)
+    }
+}
+
+trait ModuleWithHiddenStates {
+    fn forward(&self, xs: &Tensor) -> Result<Vec<Tensor>>;
+}
+
+#[derive(Debug, Clone)]
+pub struct SemanticSegmentationModel {
+    segformer: SegformerModel,
+    decode_head: SegformerDecodeHead,
+}
+
+impl SemanticSegmentationModel {
+    pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
+        let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
+        let decode_head = SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?;
+        Ok(Self {
+            segformer,
+            decode_head,
+        })
+    }
+}
+
+impl Module for SemanticSegmentationModel {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.segformer.forward(x)?;
+        self.decode_head.forward(&hidden_states)
+    }
+}
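
Usage note (not part of the patch): the decode head returns raw logits of shape [1, num_labels, height/4, width/4], so a caller typically takes an argmax over the label axis to obtain a per-pixel class map. Below is a minimal sketch of that post-processing step; the helper name is hypothetical and it assumes `Tensor::argmax` is available in the candle revision being targeted.

use candle::{Result, Tensor};

// Hypothetical helper: collapse [1, num_labels, height/4, width/4] logits into a
// [height/4, width/4] map of class ids by taking the argmax over the label axis.
fn logits_to_label_map(segmentations: &Tensor) -> Result<Tensor> {
    let logits = segmentations.squeeze(0)?; // [num_labels, height/4, width/4]
    logits.argmax(0) // [height/4, width/4], one class id per pixel
}

In the example's main.rs this would slot in right after `model.forward(&input)?`; upsampling the ids back to the input resolution and mapping them to the ADE20K palette for visualization is left to the caller.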