diff --git a/src/model/model.py b/src/model/model.py index 6a87b34..b60c59a 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -157,7 +157,10 @@ def forward(self, img, temperature=1.0): caption: generated caption [str] tokens: generated tokens [torch.Tensor] ''' - # only one image at a time + + if temperature <= 0.0: + temperature = 1.0 + print('Temperature must be positive. Setting it to 1.0') with torch.no_grad(): img_embedded = self.ie(img)