-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
838 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# candle-segformer | ||
|
||
- [HuggingFace Segformer Model Card][segformer] | ||
- [`mit-b0` - An encoder only pretrained model][encoder] | ||
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512] | ||
|
||
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer | ||
[encoder]: https://huggingface.co/nvidia/mit-b0 | ||
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512 | ||
|
||
## How to run | ||
|
||
```bash | ||
# run image classification task | ||
cargo run --example segformer classify candle-examples/examples/segformer/assets/burger.jpg | ||
``` | ||
|
||
Example output: | ||
|
||
```text | ||
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6] | ||
label: hamburger | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
use candle::Device; | ||
use candle::Module; | ||
use candle_nn::VarBuilder; | ||
use candle_transformers::models::segformer::{ | ||
Config, ImageClassificationModel, SemanticSegmentationModel, | ||
}; | ||
use clap::{Args, Parser, Subcommand}; | ||
use std::path::PathBuf; | ||
|
||
#[derive(Parser)] | ||
#[clap(about, version, long_about = None)] | ||
struct CliArgs { | ||
#[arg(long, help = "use cpu")] | ||
cpu: bool, | ||
#[command(subcommand)] | ||
command: Commands, | ||
} | ||
#[derive(Args, Debug)] | ||
struct SegmentationArgs { | ||
#[arg( | ||
long, | ||
help = "name of the huggingface hub model", | ||
default_value = "nvidia/segformer-b0-finetuned-ade-512-512" | ||
)] | ||
model_name: String, | ||
#[arg(help = "path to image as input")] | ||
image: PathBuf, | ||
} | ||
|
||
#[derive(Args, Debug)] | ||
struct ClassificationArgs { | ||
#[arg( | ||
long, | ||
help = "name of the huggingface hub model", | ||
default_value = "paolinox/segformer-finetuned-food101" | ||
)] | ||
model_name: String, | ||
#[arg(help = "path to image as input")] | ||
image: PathBuf, | ||
} | ||
|
||
#[derive(Subcommand, Debug)] | ||
enum Commands { | ||
Segment(SegmentationArgs), | ||
Classify(ClassificationArgs), | ||
} | ||
|
||
fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> { | ||
println!("loading model {} via huggingface hub", model_name); | ||
let api = hf_hub::api::sync::Api::new()?; | ||
let api = api.model(model_name.clone()); | ||
let model_file = api.get("model.safetensors")?; | ||
println!("model {} downloaded and loaded", model_name); | ||
let vb = | ||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? }; | ||
let config = std::fs::read_to_string(api.get("config.json")?)?; | ||
let config: Config = serde_json::from_str(&config)?; | ||
println!("{:?}", config); | ||
Ok((vb, config)) | ||
} | ||
|
||
fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> { | ||
let image = candle_examples::imagenet::load_image224(args.image)? | ||
.unsqueeze(0)? | ||
.to_device(device)?; | ||
let (vb, config) = get_vb_and_config(args.model_name, device)?; | ||
let num_labels = 150; | ||
let model = SemanticSegmentationModel::new(&config, num_labels, vb)?; | ||
let segmentations = model.forward(&image)?; | ||
println!( | ||
"segmentation result shape {:?} which should match [1, num_labels, height/4, width/4]", | ||
segmentations.shape() | ||
); | ||
let labels = segmentations.squeeze(0)?.argmax(0)?; | ||
let labels = labels.to_vec2::<u32>()?; | ||
println!("labels {:?}", labels); | ||
Ok(()) | ||
} | ||
|
||
fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> { | ||
let image = candle_examples::imagenet::load_image224(args.image)? | ||
.unsqueeze(0)? | ||
.to_device(device)?; | ||
let (vb, config) = get_vb_and_config(args.model_name, device)?; | ||
let num_labels = 7; | ||
let model = ImageClassificationModel::new(&config, num_labels, vb)?; | ||
let classification = model.forward(&image)?; | ||
let classification = candle_nn::ops::softmax_last_dim(&classification)?; | ||
let classification = classification.squeeze(0)?; | ||
println!( | ||
"classification logits {:?}", | ||
classification.to_vec1::<f32>()? | ||
); | ||
let label_id = classification.argmax(0)?.to_scalar::<u32>()?; | ||
let label_id = format!("{}", label_id); | ||
println!("label: {}", config.id2label[&label_id]); | ||
Ok(()) | ||
} | ||
|
||
pub fn main() -> anyhow::Result<()> { | ||
let args = CliArgs::parse(); | ||
let device = candle_examples::device(args.cpu)?; | ||
if let Commands::Segment(args) = args.command { | ||
segmentation_task(args, &device)? | ||
} else if let Commands::Classify(args) = args.command { | ||
classification_task(args, &device)? | ||
} | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.