Add Gaudi hpu accelerator option to BaseCAM (#547)
* add gaudi hpu option

Signed-off-by: Daniel Deleon <daniel.de.leon@intel.com>

* add try except block

Signed-off-by: Daniel Deleon <daniel.de.leon@intel.com>

---------

Signed-off-by: Daniel Deleon <daniel.de.leon@intel.com>
daniel-de-leon-user293 authored Dec 17, 2024
1 parent 5cef718 commit a2a23f8
Showing 3 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
@@ -281,7 +281,7 @@ two smoothing methods are supported:
 Usage: `python cam.py --image-path <path_to_image> --method <method> --output-dir <output_dir_path> `
 
 
-To use with a specific device, like cpu, cuda, cuda:0 or mps:
+To use with a specific device, like cpu, cuda, cuda:0, mps or hpu:
 `python cam.py --image-path <path_to_image> --device cuda --output-dir <output_dir_path> `
 
 ----------
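With this change, the same entry point can also target an Intel Gaudi card, for example: `python cam.py --image-path <path_to_image> --device hpu --output-dir <output_dir_path>`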
3 changes: 3 additions & 0 deletions cam.py
@@ -77,6 +77,9 @@ def get_args():
         'kpcacam': KPCA_CAM
     }
 
+    if args.device=='hpu':
+        import habana_frameworks.torch.core as htcore
+
     model = models.resnet50(pretrained=True).to(torch.device(args.device)).eval()
 
     # Choose the target layer you want to compute the visualization for.
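For context, the cam.py hunk gates the Habana import on the requested device, so the `hpu` backend is only pulled in when it is actually asked for. Below is a minimal standalone sketch of that pattern; the argument wiring is illustrative rather than the full script, and it assumes the Habana PyTorch bridge is installed when `--device hpu` is passed.

```python
# A minimal standalone sketch of the device handling added to cam.py.
# Assumes the Habana PyTorch bridge (habana_frameworks) is installed when
# --device hpu is requested; the argument wiring here is illustrative.
import argparse

import torch
from torchvision import models

parser = argparse.ArgumentParser()
parser.add_argument('--device', default='cpu',
                    help='Device to use: cpu, cuda, cuda:0, mps or hpu')
args = parser.parse_args()

if args.device == 'hpu':
    # Importing the bridge makes the 'hpu' device available to PyTorch.
    import habana_frameworks.torch.core as htcore  # noqa: F401

model = models.resnet50(pretrained=True).to(torch.device(args.device)).eval()
```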
9 changes: 9 additions & 0 deletions pytorch_grad_cam/base_cam.py
@@ -25,6 +25,13 @@ def __init__(
 
         # Use the same device as the model.
         self.device = next(self.model.parameters()).device
+        if 'hpu' in str(self.device):
+            try:
+                import habana_frameworks.torch.core as htcore
+            except ImportError as error:
+                error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."
+                raise error
+            self.__htcore = htcore
         self.reshape_transform = reshape_transform
         self.compute_input_gradient = compute_input_gradient
         self.uses_gradients = uses_gradients
@@ -97,6 +104,8 @@ def forward(
         self.model.zero_grad()
         loss = sum([target(output) for target, output in zip(targets, outputs)])
         loss.backward(retain_graph=True)
+        if 'hpu' in str(self.device):
+            self.__htcore.mark_step()
 
         # In most of the saliency attribution papers, the saliency is
         # computed with a single target layer.
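In `BaseCAM`, the first hunk imports `habana_frameworks.torch.core` only when the model already lives on an HPU, and the second hunk calls `mark_step()` right after `loss.backward()`; in Gaudi's lazy execution mode this flushes the accumulated graph so the gradients are materialized before the CAM weights are computed. From the caller's side nothing changes, as in the sketch below. It follows the library's usual `GradCAM` usage; the random input tensor and the ImageNet class index 281 are placeholders, and the Habana bridge is assumed to be installed.

```python
# A sketch of running an existing CAM method on a Gaudi HPU after this change.
import torch
from torchvision import models

import habana_frameworks.torch.core as htcore  # noqa: F401  (enables the 'hpu' device)
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

device = torch.device('hpu')
model = models.resnet50(pretrained=True).to(device).eval()
target_layers = [model.layer4[-1]]
input_tensor = torch.rand(1, 3, 224, 224, device=device)  # placeholder input

# BaseCAM reads the device from the model's parameters, imports htcore itself,
# and calls mark_step() after the backward pass, so no extra HPU code is needed.
with GradCAM(model=model, target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=[ClassifierOutputTarget(281)])
```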
