Skip to content

Commit

Permalink
Readme
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobgil committed Aug 18, 2024
1 parent 7ff4d2e commit da83f1f
Showing 1 changed file with 41 additions and 24 deletions.
65 changes: 41 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ This can be useful if you're not sure what layer will perform best.

----------

# Using from code as a library
# Usage examples

```python
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
Expand All @@ -124,34 +124,27 @@ target_layers = [model.layer4[-1]]
input_tensor = # Create an input tensor image for your model..
# Note: input_tensor can be a batch tensor with several images!

# Construct the CAM object once, and then re-use it on many images:
cam = GradCAM(model=model, target_layers=target_layers)

# You can also use it within a with statement, to make sure it is freed,
# In case you need to re-create it inside an outer loop:
# with GradCAM(model=model, target_layers=target_layers) as cam:
# ...

# We have to specify the target we want to generate
# the Class Activation Maps for.
# If targets is None, the highest scoring category
# will be used for every image in the batch.
# Here we use ClassifierOutputTarget, but you can define your own custom targets
# That are, for example, combinations of categories, or specific outputs in a non standard model.

targets = [ClassifierOutputTarget(281)]

# You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
# Construct the CAM object once, and then re-use it on many images.
with GradCAM(model=model, target_layers=target_layers) as cam:

# In this example grayscale_cam has only one image in the batch:
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
# You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

# You can also get the model outputs without having to re-inference
model_outputs = cam.outputs
# In this example grayscale_cam has only one image in the batch:
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

# You can also get the model outputs without having to redo inference
model_outputs = cam.outputs
```

Cam.py has a more detailed usage example.

----------

# Metrics and evaluating the explanations
Expand Down Expand Up @@ -179,18 +172,42 @@ from pytorch_grad_cam.metrics.road import ROADMostRelevantFirstAverage,
cam_metric = ROADCombined(percentiles=[20, 40, 60, 80])
scores = cam_metric(input_tensor, grayscale_cams, targets, model)
```

----------


# Advanced use cases and tutorials:

You can use this package for "custom" deep learning models, for example Object Detection or Semantic Segmentation.
Methods like GradCAM were designed for and were originally mostly applied on classification models,
and specifically CNN classification models.
However you can also use this package on new architectures like Vision Transformers, and on non classification tasks like Object Detection or Semantic Segmentation.

The be able to adapt to non standard cases, we have two concepts.
- The reshape transform - how do we convert activations to represent spatial images ?
- The model targets - What exactly should the explainability method try to explain ?

## The reshape transform
In a CNN the intermediate activations in the model are a mult-channel image that have the dimensions channel x rows x cols,
and the various explainabiltiy methods work with these to produce a new image.

In case of another architecture, like the Vision Transformer, the shape might be different, like (rows x cols + 1) x channels, or something else.
The reshape transform converts the activations back into a multi-channel image, for example by removing the class token in a vision transformer.
For examples, check [here](https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/reshape_transforms.py)

## Model Targets
The model target is just a callable that is able to get the model output, and filter it out for the specific scalar output we want to explain.

For classification tasks, the model target will typically be the output from a specific category.
The `targets` parameter passed to the CAM method can then use `ClassifierOutputTarget`:
```python
targets = [ClassifierOutputTarget(281)]
```

However more advanced cases, you might want another behaviour.
Check [here](https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py) for more examples.

You will have to define objects that you can then pass to the CAM algorithms:
1. A reshape_transform, that aggregates the layer outputs into 2D tensors that will be displayed.
2. Model Targets, that define what target do you want to compute the visualizations for, for example a specific category, or a list of bounding boxes.

# Tutorials
Here you can find detailed examples of how to use this for various custom use cases like object detection:

These point to the new documentation jupter-book for fast rendering.
Expand Down

0 comments on commit da83f1f

Please sign in to comment.