Skip to content

Commit

Permalink
fix: 🐛 add context to set_grad_enable
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Apr 15, 2023
1 parent fe95923 commit e027618
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions pyiqa/models/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,23 @@ def to(self, device):

def forward(self, target, ref=None, **kwargs):

torch.set_grad_enabled(self.as_loss)
with torch.set_grad_enabled(self.as_loss):

if 'fid' in self.metric_name:
output = self.net(target, ref, device=self.device, **kwargs)
else:
if not torch.is_tensor(target):
target = imread2tensor(target)
target = target.unsqueeze(0)
if self.metric_mode == 'FR':
assert ref is not None, 'Please specify reference image for Full Reference metric'
ref = imread2tensor(ref)
ref = ref.unsqueeze(0)

if 'fid' in self.metric_name:
output = self.net(target, ref, device=self.device, **kwargs)
else:
if not torch.is_tensor(target):
target = imread2tensor(target)
target = target.unsqueeze(0)
if self.metric_mode == 'FR':
assert ref is not None, 'Please specify reference image for Full Reference metric'
ref = imread2tensor(ref)
ref = ref.unsqueeze(0)

if self.metric_mode == 'FR':
output = self.net(target.to(self.device), ref.to(self.device), **kwargs)
elif self.metric_mode == 'NR':
output = self.net(target.to(self.device), **kwargs)
output = self.net(target.to(self.device), ref.to(self.device), **kwargs)
elif self.metric_mode == 'NR':
output = self.net(target.to(self.device), **kwargs)

if self.as_loss:
return weight_reduce_loss(output, self.loss_weight, self.loss_reduction)
Expand Down

0 comments on commit e027618

Please sign in to comment.