
A distilled Segment Anything (SAM) model capable of running real-time with NVIDIA TensorRT and edge devices


binh234/nanosam


NanoSAM

👍 Usage - ⏱️ Performance - 🛠️ Setup - 🤸 Examples - 🏋️ Training - 🧐 Evaluation - 👏 Acknowledgment - 🔗 See also

NanoSAM is a Segment Anything (SAM) and EfficientViT-SAM model variant developed to target 🔥 CPU, mobile, and edge 🔥 deployment, such as NVIDIA Jetson Xavier platforms with NVIDIA TensorRT.

A demo of NanoSAM running on CPU is available on Hugging Face. On our own i5-8265U CPU, inference takes only about 0.3 s. On the Hugging Face demo, the interface overhead and slower CPUs make it somewhat slower, but it still works fine. You can also run the NanoSAM demo on your local PC.

NanoSAM is trained by distilling the EfficientViT-SAM-L0 image encoder on unlabeled images, using 6% of the SA-1B dataset. For an introduction to knowledge distillation, we recommend checking out this tutorial.
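To make the distillation idea concrete, here is a minimal NumPy sketch of feature-level distillation: the student encoder is pushed to reproduce the frozen teacher's (here, EfficientViT-SAM-L0's) image embedding for the same unlabeled image. The mean-squared-error objective, the function name, and the toy embedding shapes are illustrative assumptions; this README does not state the actual training loss, and the training code is not yet released.

```python
import numpy as np


def feature_distillation_loss(student_emb: np.ndarray, teacher_emb: np.ndarray) -> float:
    """Mean squared error between student and teacher image embeddings.

    Both arrays must share a shape such as (batch, channels, height, width).
    In a real training loop the teacher output is a fixed target: only the
    student's weights are updated. (Hypothetical helper, not NanoSAM's code.)
    """
    if student_emb.shape != teacher_emb.shape:
        raise ValueError(f"shape mismatch: {student_emb.shape} vs {teacher_emb.shape}")
    diff = student_emb.astype(np.float64) - teacher_emb.astype(np.float64)
    return float(np.mean(diff ** 2))


rng = np.random.default_rng(0)
# Toy embeddings standing in for teacher and student encoder outputs on one image
teacher = rng.standard_normal((1, 256, 8, 8)).astype(np.float32)
student = teacher + 0.1 * rng.standard_normal(teacher.shape).astype(np.float32)

print(feature_distillation_loss(student, teacher))  # small positive value
print(feature_distillation_loss(teacher, teacher))  # 0.0
```

Because no labels are involved, any image the teacher can encode is usable training data, which is why an unlabeled subset of SA-1B suffices.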

👍 Usage

Using NanoSAM from Python looks like this. You can also see the notebook here for more details.

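import numpy as np
import PIL.Image

from nanosam.utils.predictor import Predictor

image_encoder_cfg = {
    "path": "data/sam_hgv2_b4_ln_nonorm_image_encoder.onnx",
    "name": "OnnxModel",
    "provider": "cpu",
    "normalize_input": False,
}
mask_decoder_cfg = {
    "path": "data/efficientvit_l0_mask_decoder.onnx",
    "name": "OnnxModel",
    "provider": "cpu",
}
# Build the predictor from the image encoder and mask decoder configs
predictor = Predictor(image_encoder_cfg, mask_decoder_cfg)

image = PIL.Image.open("assets/dogs.jpg")

# Encode the image once; subsequent prompts reuse this embedding
predictor.set_image(image)

# Prompt with a single foreground point (label 1) at pixel (x, y).
# The coordinates below are placeholder values; replace them with your own.
x, y = 100, 100
mask, _, _ = predictor.predict(np.array([[x, y]]), np.array([1]))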

Note: The point labels can take the following values:

| Point Label | Description |
|---|---|
| 0 | Background point |
| 1 | Foreground point |
| 2 | Bounding box top-left |
| 3 | Bounding box bottom-right |
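The label values above can be combined in a single prompt. The sketch below assembles foreground points, background points, and an optional box into the `(points, labels)` pair passed to `predictor.predict` in the usage example. The helper name `build_prompt` and the float32/int64 dtypes are assumptions for illustration; the usage example itself passes plain `np.array` values.

```python
import numpy as np

# Label values from the table above
BACKGROUND, FOREGROUND, BOX_TOP_LEFT, BOX_BOTTOM_RIGHT = 0, 1, 2, 3


def build_prompt(foreground=(), background=(), box=None):
    """Return (points, labels) arrays for a NanoSAM prompt (hypothetical helper).

    foreground / background: sequences of (x, y) pixel coordinates.
    box: optional (x0, y0, x1, y1), with (x0, y0) top-left and (x1, y1) bottom-right.
    """
    points, labels = [], []
    for x, y in foreground:
        points.append([x, y])
        labels.append(FOREGROUND)
    for x, y in background:
        points.append([x, y])
        labels.append(BACKGROUND)
    if box is not None:
        x0, y0, x1, y1 = box
        # A box is encoded as two corner points with labels 2 and 3
        points += [[x0, y0], [x1, y1]]
        labels += [BOX_TOP_LEFT, BOX_BOTTOM_RIGHT]
    if not points:
        raise ValueError("at least one point or a box is required")
    return np.array(points, dtype=np.float32), np.array(labels, dtype=np.int64)


points, labels = build_prompt(box=(50, 40, 300, 260))
print(points.tolist())  # [[50.0, 40.0], [300.0, 260.0]]
print(labels.tolist())  # [2, 3]
```

With a predictor set up as in the usage example, the pair would be passed as `predictor.predict(points, labels)`. If you use the TensorRT mask decoder, keep in mind that the trtexec command in this README builds it for at most 4 points, so a box plus two clicks is the limit unless you rebuild the engine.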

Follow the instructions below to build the TensorRT engine files.

⏱️ Performance

NanoSAM performance on edge devices.

| Model † | CPU Image Encoder (ms) | CPU Full Pipeline (ms) | Jetson Xavier NX Image Encoder (ms) | Jetson Xavier NX Full Pipeline (ms) | T4 Image Encoder (ms) | T4 Full Pipeline (ms) | Model Size | Link |
|---|---|---|---|---|---|---|---|---|
| PPHGV2-SAM-B1 | 110 | 180 | 9.6 | 17 | 2.4 | 5.8 | 12.1 MB | Link |
| PPHGV2-SAM-B2 | 200 | 270 | 12.4 | 19.8 | 3.2 | 6.4 | 28.1 MB | Link |
| PPHGV2-SAM-B4 | 300 | 370 | 17.3 | 24.7 | 4.1 | 7.5 | 58.6 MB | Link |
| NanoSAM (ResNet18) | 500 | 570 | 22.4 | 29.8 | 5.8 | 9.2 | 60.4 MB | Link |
| EfficientViT-SAM-L0 | 1000 | 1070 | 31.6 | 38 | 6 | 9.4 | 117.4 MB | |
Notes
  • For the CPU benchmark, latency is measured on an Intel(R) Core(TM) i5-8265U CPU @ 1.60GHz using the ONNX Runtime CPUExecutionProvider.

  • For the GPU benchmarks, latency is measured on an NVIDIA Jetson Xavier NX and an NVIDIA T4 GPU with TensorRT in FP16. Data transfer time is included.
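The table supports some quick arithmetic. The sketch below assumes that "Full Pipeline" equals the image encoder plus everything else (mask decoder, pre/post-processing, and data transfer), and derives that non-encoder share and the full-pipeline speedup relative to EfficientViT-SAM-L0. The dictionary just transcribes the table; the function names are hypothetical.

```python
# (image encoder ms, full pipeline ms) per device, transcribed from the table above
LATENCY_MS = {
    "PPHGV2-SAM-B1": {"cpu": (110, 180), "xavier_nx": (9.6, 17.0), "t4": (2.4, 5.8)},
    "PPHGV2-SAM-B2": {"cpu": (200, 270), "xavier_nx": (12.4, 19.8), "t4": (3.2, 6.4)},
    "PPHGV2-SAM-B4": {"cpu": (300, 370), "xavier_nx": (17.3, 24.7), "t4": (4.1, 7.5)},
    "NanoSAM (ResNet18)": {"cpu": (500, 570), "xavier_nx": (22.4, 29.8), "t4": (5.8, 9.2)},
    "EfficientViT-SAM-L0": {"cpu": (1000, 1070), "xavier_nx": (31.6, 38.0), "t4": (6.0, 9.4)},
}


def non_encoder_ms(model: str, device: str) -> float:
    """Full-pipeline latency minus image-encoder latency, rounded to 0.1 ms."""
    encoder, full = LATENCY_MS[model][device]
    return round(full - encoder, 1)


def speedup(model: str, device: str, baseline: str = "EfficientViT-SAM-L0") -> float:
    """How many times faster the model's full pipeline is than the baseline's."""
    return round(LATENCY_MS[baseline][device][1] / LATENCY_MS[model][device][1], 2)


for device in ("cpu", "xavier_nx", "t4"):
    print(device, non_encoder_ms("PPHGV2-SAM-B4", device), speedup("PPHGV2-SAM-B4", device))
```

For PPHGV2-SAM-B4 this gives about 70 ms of non-encoder time on CPU and a roughly 2.9x full-pipeline speedup over EfficientViT-SAM-L0 there. The non-encoder time is nearly constant across the PPHGV2 and ResNet18 rows, which is consistent with all of them sharing the same mask decoder.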

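To take comparable CPU measurements on your own machine, a small timing harness like the one below may help. It is a sketch, not the script used for the table: it reports the median of several runs after warm-up, and `make_onnx_cpu_runner` assumes the model has a single float32 input with a fixed shape (dynamic dimensions are filled with 1 only as a fallback). Both function names are hypothetical; `onnxruntime.InferenceSession` with `providers=["CPUExecutionProvider"]` is the standard ONNX Runtime API.

```python
import statistics
import time
from typing import Callable

import numpy as np


def measure_latency_ms(fn: Callable[[], object], warmup: int = 5, iters: int = 20) -> float:
    """Median wall-clock latency of fn() in milliseconds, measured after warm-up runs."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


def make_onnx_cpu_runner(model_path: str) -> Callable[[], object]:
    """Return a zero-argument callable that runs an ONNX model once on CPU."""
    import onnxruntime as ort  # imported lazily so the timing helper works without it

    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    inp = session.get_inputs()[0]
    # Dynamic dimensions appear as strings or None; fall back to 1 for those
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    dummy = np.random.rand(*shape).astype(np.float32)
    return lambda: session.run(None, {inp.name: dummy})
```

For example, `measure_latency_ms(make_onnx_cpu_runner("data/sam_hgv2_b4_ln_nonorm_image_encoder.onnx"))` would time the CPU image encoder downloaded in the setup steps. Taking the median rather than the mean keeps a single slow run (for example, a background process waking up) from skewing the result.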
Zero-Shot Instance Segmentation on the COCO 2017 validation dataset

| Image Encoder | Mask mAP (50-95) | mIoU (all) | mIoU (large) | mIoU (medium) | mIoU (small) |
|---|---|---|---|---|---|
| ResNet18 | - | 70.6 | 79.6 | 73.8 | 62.4 |
| MobileSAM | - | 72.8 | 80.4 | 75.9 | 65.8 |
| PPHGV2-B1 | 41.2 | 75.6 | 81.2 | 77.4 | 70.8 |
| PPHGV2-B2 | 42.6 | 76.5 | 82.2 | 78.5 | 71.5 |
| PPHGV2-B4 | 44.0 | 77.3 | 83.0 | 79.7 | 72.1 |
| EfficientViT-L0 | 45.6 | 78.6 | 83.7 | 81.0 | 73.3 |
Notes
  • To conduct box-prompted instance segmentation, you must first obtain the source_json_file of detected bounding boxes. Follow the instructions for ViTDet, YOLOv8, or GroundingDINO to generate the source_json_file, or download the pre-generated files.

  • mIoU is computed by prompting SAM with ground-truth object bounding box annotations from the COCO 2017 validation dataset. The IoU is then computed between the mask output of the SAM model for the object and the ground-truth COCO segmentation mask for the object. The mIoU is the average IoU over all objects in the COCO 2017 validation set matching the target object size (small, medium, large).

  • Mask mAP (50-95) is computed by prompting SAM with ViTDet's predicted bounding boxes on the COCO 2017 validation dataset.
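The mIoU definition above can be sketched in NumPy: compute the IoU of each predicted mask against its ground-truth mask, then average over the objects in a size bucket. The bucket rule uses COCO's standard area thresholds (small below 32², medium below 96², large otherwise) applied to each object's ground-truth area; whether `nanosam.tools` applies exactly this rule is an assumption, as are the function names and the convention for two empty masks.

```python
import numpy as np

# Standard COCO object-size thresholds, in pixels of area
SMALL_MAX_AREA = 32 ** 2
MEDIUM_MAX_AREA = 96 ** 2


def mask_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """Intersection over union of two same-shaped boolean masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 1.0  # both masks empty: treated as a perfect match (assumed convention)
    return float(np.logical_and(pred, gt).sum() / union)


def size_bucket(area: float) -> str:
    if area < SMALL_MAX_AREA:
        return "small"
    if area < MEDIUM_MAX_AREA:
        return "medium"
    return "large"


def mean_iou(results, size: str = "all") -> float:
    """results: iterable of (iou, gt_area) pairs, one per object."""
    ious = [iou for iou, area in results if size == "all" or size_bucket(area) == size]
    return float(np.mean(ious)) if ious else float("nan")


pred = np.zeros((4, 4), dtype=bool)
pred[:2, :2] = True  # 4 pixels
gt = np.zeros((4, 4), dtype=bool)
gt[:2, :3] = True  # 6 pixels, containing all of pred
print(mask_iou(pred, gt))  # 4 / 6, about 0.667
```

Averaging per bucket explains the pattern in the table: small objects have fewer pixels, so a few misclassified boundary pixels cost proportionally more IoU.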

🛠️ Setup

NanoSAM is fairly easy to get started with.

  1. Install the dependencies

    1. Install ONNX and ONNX Runtime for ONNX deployment, or PyTorch and NVIDIA TensorRT for TensorRT deployment.

    2. (Optional) Install TRTPose for the pose example.

      git clone https://github.com/NVIDIA-AI-IOT/trt_pose
      cd trt_pose
      python3 setup.py develop --user
    3. (Optional) Install the Transformers library for the OWL-ViT example.

      python3 -m pip install transformers
  2. Install the NanoSAM Python package

    git clone https://github.com/NVIDIA-AI-IOT/nanosam
    cd nanosam
    pip install -e .
  3. Download the EfficientViT-SAM-L0 mask decoder ONNX file from here to the data folder.

  4. Download the image encoder sam_hgv2_b4_ln_nonorm_image_encoder.onnx to the data folder.

  5. Run the basic usage example

    python3 scripts/basic_usage.py \
        --encoder_cfg configs/inference/encoder.yaml \
        --decoder_cfg configs/inference/decoder.yaml

    This writes the result to data/basic_usage_out.jpg.

That's it! From there, you can read the example code to learn how to use NanoSAM from Python, or try running the more advanced examples below.
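If the basic usage example fails to import something, a quick way to see which of the optional dependencies from step 1 are installed is to ask Python's import machinery. The sketch below maps each feature to the module names it needs; those module names (for example `trt_pose` for TRTPose and `tensorrt` for the TensorRT Python bindings) are assumptions about how the packages install, and the script itself is a hypothetical helper, not part of NanoSAM.

```python
import importlib.util

# Module names assumed for the dependencies listed in the setup steps
BACKENDS = {
    "NanoSAM package": ["nanosam"],
    "ONNX deployment": ["onnx", "onnxruntime"],
    "TensorRT deployment": ["torch", "tensorrt"],
    "pose example": ["trt_pose"],
    "OWL-ViT example": ["transformers"],
}


def missing_modules(modules):
    """Return the module names that cannot be found on the current Python path."""
    return [m for m in modules if importlib.util.find_spec(m) is None]


def report():
    for feature, modules in BACKENDS.items():
        missing = missing_modules(modules)
        status = "ok" if not missing else "missing " + ", ".join(missing)
        print(f"{feature}: {status}")


if __name__ == "__main__":
    report()
```

`find_spec` only locates a module without importing it, so the check is fast and does not trigger heavy initialization such as loading CUDA libraries.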

Build TensorRT engines

To run NanoSAM with TensorRT in FP16, follow these steps:

  1. Mask decoder

    trtexec \
        --onnx=data/efficientvit_l0_mask_decoder.onnx \
        --saveEngine=data/efficientvit_l0_mask_decoder.engine \
        --minShapes=point_coords:1x1x2,point_labels:1x1 \
        --optShapes=point_coords:1x2x2,point_labels:1x2 \
        --maxShapes=point_coords:1x4x2,point_labels:1x4 \
        --fp16

    This assumes the mask decoder ONNX file has been downloaded to data/efficientvit_l0_mask_decoder.onnx.

    Note: This command builds an engine that supports up to 4 keypoints. You can raise this limit by specifying a larger max shape.
  2. Image encoder

    1. For TensorRT >= 8.6
    trtexec \
        --onnx=data/sam_hgv2_b4_ln_nonorm_image_encoder.onnx \
        --saveEngine=data/sam_hgv2_b4_ln_nonorm_image_encoder.engine \
        --fp16
    2. For TensorRT < 8.6

    The LayerNorm op overflows when run in FP16, so you need to force all of its operations to run in FP32 precision to preserve model accuracy.

    • Download the alternative models, exported with ONNX opset 11, from here
    • Install Polygraphy
    python -m pip install colored polygraphy --extra-index-url https://pypi.ngc.nvidia.com
    • Build TensorRT engine
    PPHGV2 B4
    polygraphy convert data/sam_hgv2_b4_nonorm_image_encoder.onnx \
        -o data/sam_pphgv2_b4_nonorm_image_encoder.engine \
        --precision-constraints obey \
        --layer-precisions p2o.ReduceMean.2:float32 p2o.Sub.0:float32 p2o.Mul.0:float32 p2o.Add.14:float32 p2o.Sqrt.0:float32 p2o.Div.0:float32 p2o.Mul.1:float32 p2o.Add.16:float32 \
        --fp16 --pool-limit workspace:2G
    PPHGV2 B2
    polygraphy convert data/sam_hgv2_b2_nonorm_image_encoder.onnx \
        -o data/sam_pphgv2_b2_nonorm_image_encoder.engine \
        --precision-constraints obey \
        --layer-precisions p2o.ReduceMean.2:float32 p2o.Sub.0:float32 p2o.Mul.82:float32 p2o.Add.96:float32 p2o.Sqrt.0:float32 p2o.Div.0:float32 p2o.Mul.83:float32 p2o.Add.98:float32 \
        --fp16 --pool-limit workspace:2G
    PPHGV2 B1
    polygraphy convert data/sam_hgv2_b1_nonorm_image_encoder.onnx \
        -o data/sam_pphgv2_b1_nonorm_image_encoder.engine \
        --precision-constraints obey \
        --layer-precisions p2o.ReduceMean.2:float32 p2o.Sub.0:float32 p2o.Mul.60:float32 p2o.Add.72:float32 p2o.Sqrt.0:float32 p2o.Div.0:float32 p2o.Mul.61:float32 p2o.Add.74:float32 \
        --fp16 --pool-limit workspace:2G
  3. Run inference

    Change the paths inside the config files to point to the correct engine files.

    python3 scripts/basic_usage.py \
        --encoder_cfg configs/inference/encoder_tensorrt.yaml \
        --decoder_cfg configs/inference/decoder_tensorrt.yaml
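The FP16 LayerNorm problem mentioned for TensorRT < 8.6 comes down to numeric range: FP16 cannot represent values above 65504, and LayerNorm squares its centered inputs when computing the variance. The NumPy sketch below shows that mechanism on a toy input; it only illustrates the range issue and does not reproduce TensorRT's kernels or the NanoSAM encoder's actual activations.

```python
import numpy as np


def variance(x: np.ndarray) -> np.ndarray:
    """Per-row variance as the mean of squared deviations (LayerNorm's statistics step)."""
    mean = x.mean(axis=-1, keepdims=True)
    return ((x - mean) ** 2).mean(axis=-1, keepdims=True)


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm over the last axis, without the learned scale and bias."""
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance(x) + eps)


x = np.array([[-400.0, 0.0, 400.0]])

with np.errstate(over="ignore"):  # silence the FP16 overflow warning
    # 400**2 = 160000 exceeds the FP16 maximum of 65504, so the square becomes inf
    var16 = variance(x.astype(np.float16))
    out16 = layer_norm(x.astype(np.float16))
var32 = variance(x.astype(np.float32))
out32 = layer_norm(x.astype(np.float32))

print(var16, var32)  # inf in FP16, finite in FP32
print(out16, out32)  # FP16 output collapses to zeros; FP32 stays near [-1.22, 0, 1.22]
```

Pinning only the LayerNorm's nodes to FP32 keeps this statistics computation in range while the rest of the network still benefits from FP16.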

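The long `--layer-precisions` values in the polygraphy commands above are node names from the exported ONNX graph, each suffixed with `:float32`. The helpers below sketch how such a list might be gathered and formatted for a re-exported model, assuming TensorRT keeps the ONNX node names as layer names (which those commands rely on). All function names are hypothetical, and `onnx` is imported lazily so the formatting helpers work without it.

```python
from typing import Iterable, List, Tuple

# ONNX op types that a decomposed LayerNorm is typically built from
LAYERNORM_OP_TYPES = {"ReduceMean", "Sub", "Pow", "Mul", "Add", "Sqrt", "Div"}


def list_onnx_nodes(model_path: str) -> List[Tuple[str, str]]:
    """Return (node name, op type) for every node in an ONNX graph."""
    import onnx  # only needed when reading a real model file

    model = onnx.load(model_path)
    return [(node.name, node.op_type) for node in model.graph.node]


def layernorm_candidates(nodes: Iterable[Tuple[str, str]],
                         op_types=LAYERNORM_OP_TYPES) -> List[str]:
    """Names of nodes whose op type can appear inside a decomposed LayerNorm."""
    return [name for name, op_type in nodes if op_type in op_types]


def layer_precisions_arg(names: Iterable[str], precision: str = "float32") -> str:
    """Format node names as the value of polygraphy's --layer-precisions option."""
    return " ".join(f"{name}:{precision}" for name in names)


# The PPHGV2 B4 node list from the command above, as plain names
b4_nodes = ["p2o.ReduceMean.2", "p2o.Sub.0", "p2o.Mul.0", "p2o.Add.14",
            "p2o.Sqrt.0", "p2o.Div.0", "p2o.Mul.1", "p2o.Add.16"]
print(layer_precisions_arg(b4_nodes))
```

`layernorm_candidates(list_onnx_nodes(path))` returns far more nodes than the LayerNorm itself (every Add and Mul in the network, for instance), so the final selection still has to be made by inspecting the graph, for example in Netron.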
🤸 Examples

NanoSAM can be applied in many creative ways.

Example 1 - Segment with bounding box

This example uses a known image with a fixed bounding box to control NanoSAM segmentation.

To run the example, call

python3 scripts/basic_usage.py \
    --encoder_cfg configs/inference/encoder.yaml \
    --decoder_cfg configs/inference/decoder.yaml

Example 2 - Segment with bounding box (using OWL-ViT detections)

This example demonstrates using OWL-ViT to detect objects from one or more text prompts, and then segmenting these objects with NanoSAM.

To run the example, call

python3 scripts/segment_from_owl.py \
    assets/john_1.jpg \
    --prompt="A tree" \
    --encoder_cfg configs/inference/encoder.yaml \
    --decoder_cfg configs/inference/decoder.yaml
Note: While OWL-ViT does not run in real time on Jetson Orin Nano (about 3 s/image), it is useful for experimentation because it can detect a wide variety of objects. You could substitute any real-time pre-trained object detector to take full advantage of NanoSAM's speed.
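The detector-to-NanoSAM handoff can be sketched as two steps: run OWL-ViT to get pixel-space boxes, then turn each box into a prompt with labels 2 and 3. This is not the code of scripts/segment_from_owl.py. The OWL-ViT calls follow the Hugging Face Transformers `OwlViTProcessor` / `OwlViTForObjectDetection` classes, the checkpoint `google/owlvit-base-patch32` is an assumption, and the helper names are hypothetical; check the API against your installed Transformers version.

```python
from typing import List, Sequence, Tuple

import numpy as np


def boxes_to_prompts(boxes: Sequence[Sequence[float]], scores: Sequence[float],
                     min_score: float = 0.1) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Turn (x0, y0, x1, y1) detections into one NanoSAM box prompt each.

    Detections below min_score are dropped; the rest come back highest score first.
    """
    prompts = []
    for i in sorted(range(len(scores)), key=lambda k: scores[k], reverse=True):
        if scores[i] < min_score:
            continue
        x0, y0, x1, y1 = boxes[i]
        points = np.array([[x0, y0], [x1, y1]], dtype=np.float32)
        labels = np.array([2, 3], dtype=np.int64)  # top-left, bottom-right
        prompts.append((points, labels))
    return prompts


def detect_with_owlvit(image, text: str, threshold: float = 0.1):
    """Run OWL-ViT on a PIL image; returns (boxes, scores) in pixel coordinates."""
    import torch
    from transformers import OwlViTForObjectDetection, OwlViTProcessor

    checkpoint = "google/owlvit-base-patch32"  # assumed checkpoint
    processor = OwlViTProcessor.from_pretrained(checkpoint)
    model = OwlViTForObjectDetection.from_pretrained(checkpoint)
    inputs = processor(text=[[text]], images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
    result = processor.post_process_object_detection(
        outputs=outputs, threshold=threshold, target_sizes=target_sizes
    )[0]
    return result["boxes"].tolist(), result["scores"].tolist()
```

With a predictor set up as in the usage example, each `(points, labels)` pair from `boxes_to_prompts(*detect_with_owlvit(image, "A tree"))` would then go to `predictor.predict`. Since NanoSAM runs once per box, sorting by score lets you cap the number of decoder calls when latency matters.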

Example 3 - Segment with keypoints (offline using TRTPose detections)

This example demonstrates how to use human pose keypoints from TRTPose to control NanoSAM segmentation.

To run the example, call

python3 scripts/segment_from_pose.py

This will save an output figure to data/segment_from_pose_out.png.
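One way to turn pose keypoints into a NanoSAM prompt is to pick some keypoints as foreground points (label 1) and others as background points (label 0), for example shoulders as foreground and the nose as background to segment a shirt. The sketch below is a hypothetical helper, not the logic of scripts/segment_from_pose.py; the keypoint names follow the COCO keypoint convention and the detections are made up.

```python
from typing import Dict, Optional, Sequence, Tuple

import numpy as np

Point = Tuple[float, float]


def keypoints_to_prompt(keypoints: Dict[str, Optional[Point]],
                        foreground_names: Sequence[str],
                        background_names: Sequence[str] = ()) -> Tuple[np.ndarray, np.ndarray]:
    """Build NanoSAM point prompts from named pose keypoints.

    Keypoints that are missing or mapped to None (not detected) are skipped.
    """
    points, labels = [], []
    for names, label in ((foreground_names, 1), (background_names, 0)):
        for name in names:
            xy = keypoints.get(name)
            if xy is not None:
                points.append(list(xy))
                labels.append(label)
    if not points:
        raise ValueError("none of the requested keypoints were detected")
    return np.array(points, dtype=np.float32), np.array(labels, dtype=np.int64)


# Hypothetical detections for one person; the left hip was not detected
pose = {"left_shoulder": (120, 80), "right_shoulder": (180, 82),
        "left_hip": None, "nose": (150, 40)}
points, labels = keypoints_to_prompt(
    pose, ["left_shoulder", "right_shoulder", "left_hip"], ["nose"])
print(labels.tolist())  # [1, 1, 0]
```

Skipping undetected keypoints rather than failing matters in practice, because pose estimators routinely miss occluded joints.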

Example 4 - Segment with keypoints (online using TRTPose detections)

This example demonstrates how to use human pose to control segmentation on a live camera feed. This example requires an attached display and camera.

To run the example, call

python3 scripts/demo_pose_tshirt.py

Example 5 - Segment and track (experimental)

This example demonstrates rudimentary segmentation tracking with NanoSAM. This example requires an attached display and camera.

To run the example, call

python3 scripts/demo_click_segment_track.py \
    --encoder_cfg configs/inference/encoder.yaml \
    --decoder_cfg configs/inference/decoder.yaml

Once the example is running, double-click the object you want to track.

Note: This tracking method is very simple and can lose the object easily. It is intended to demonstrate creative ways to use NanoSAM and could likely be improved with more work.
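A very simple tracking scheme, in the spirit of this demo, is to re-prompt each new frame with a foreground point at the centroid of the previous frame's mask. The sketch below shows that idea; it is an assumed approach for illustration, not necessarily what scripts/demo_click_segment_track.py does, and the function names are hypothetical.

```python
from typing import Optional, Tuple

import numpy as np


def mask_centroid(mask: np.ndarray) -> Optional[Tuple[float, float]]:
    """(x, y) centroid of a boolean mask, or None if the mask is empty."""
    ys, xs = np.nonzero(mask)  # row indices are y, column indices are x
    if xs.size == 0:
        return None
    return float(xs.mean()), float(ys.mean())


def next_prompt(prev_mask: np.ndarray) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Foreground point prompt for the next frame at the previous mask's centroid."""
    centroid = mask_centroid(prev_mask)
    if centroid is None:
        return None  # track lost: the user has to click the object again
    return np.array([centroid], dtype=np.float32), np.array([1], dtype=np.int64)


mask = np.zeros((6, 8), dtype=bool)
mask[2:4, 4:8] = True  # object occupies rows 2-3 and columns 4-7
print(mask_centroid(mask))  # (5.5, 2.5)
```

This also shows why such a tracker gets lost easily: the centroid of a non-convex mask (a ring or an L shape, say) can fall outside the object, and the next prompt then lands on the background.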

🏋️ Training

Training code will be available soon.

🧐 Evaluation

You can reproduce the accuracy results above by evaluating against the COCO ground-truth masks:

  1. Download and extract the COCO 2017 validation set.

    # mkdir -p data/coco  # uncomment if it doesn't exist
    cd data/coco
    wget http://images.cocodataset.org/zips/val2017.zip
    wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
    unzip val2017.zip
    unzip annotations_trainval2017.zip
    cd ../..
  2. Compute the IoU of NanoSAM mask predictions against the ground truth COCO mask annotation.

    python3 -m nanosam.tools.eval_coco \
        --data_root=data/coco \
        --img_dir=val2017 \
        --ann_file=annotations/instances_val2017.json \
        --encoder_cfg configs/inference/encoder.yaml \
        --decoder_cfg configs/inference/decoder.yaml \
        --output=data/hgv2_b4_coco_results.json

    This uses the COCO ground-truth bounding boxes as inputs to NanoSAM.

  3. Compute the average IoU over a selected category or size

    python3 -m nanosam.tools.compute_eval_coco_metrics \
        data/hgv2_b4_coco_results.json \
        --size="all"
    Note: For all options, run python3 -m nanosam.tools.compute_eval_coco_metrics --help.

    To compute the mIoU for a specific category ID (for example, ID 1 is "person" in COCO):

    python3 -m nanosam.tools.compute_eval_coco_metrics \
        data/hgv2_b4_coco_results.json \
        --category_id=1
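Step 2 uses the COCO ground-truth boxes as prompts. COCO stores each box as `[x, y, width, height]`, while the box prompt labels 2 and 3 mark the top-left and bottom-right corners, so a conversion is needed. The sketch below shows one way to do it; it is not taken from `nanosam.tools.eval_coco`, and the annotation is a made-up example in the format of instances_val2017.json.

```python
import numpy as np


def coco_bbox_to_prompt(bbox):
    """Convert a COCO [x, y, width, height] box into a NanoSAM box prompt.

    Returns (points, labels): the top-left and bottom-right corners, labeled 2 and 3.
    """
    x, y, w, h = bbox
    points = np.array([[x, y], [x + w, y + h]], dtype=np.float32)
    labels = np.array([2, 3], dtype=np.int64)
    return points, labels


# Hypothetical annotation in the instances_val2017.json format
ann = {"bbox": [10.0, 20.0, 30.0, 40.0], "area": 950.0, "category_id": 1}
points, labels = coco_bbox_to_prompt(ann["bbox"])
print(points.tolist())  # [[10.0, 20.0], [40.0, 60.0]]
print(labels.tolist())  # [2, 3]
```

Passing `[x, y, w, h]` directly as corners is an easy mistake to make: it yields a box anchored near the origin that no longer covers the object, so checking this conversion is worthwhile if your mIoU comes out unexpectedly low.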

👏 Acknowledgment

This project is enabled by the great projects below.

  • SAM - The original Segment Anything model.
  • EfficientViT-SAM - Fast and efficient Segment Anything model.
  • MobileSAM - The distilled Tiny ViT Segment Anything model.

🔗 See also
