Skip to content

Commit

Permalink
feat: 🔧 set seed for every forward in test mode
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Apr 30, 2024
1 parent 7d1d30f commit da26e6d
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pyiqa/models/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,16 @@ def __init__(
self.net = self.net.to(self.device)
self.net.eval()

self.seed = seed
if not as_loss:
set_random_seed(seed)

self.dummy_param = torch.nn.Parameter(torch.empty(0)).to(self.device)

def forward(self, target, ref=None, **kwargs):
device = self.dummy_param.device
if not self.as_loss:
set_random_seed(self.seed)

with torch.set_grad_enabled(self.as_loss):

Expand Down

0 comments on commit da26e6d

Please sign in to comment.