Skip to content

Commit

Permalink
'torch.no_grad()' is the only way to avoid backward/gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinDong committed Mar 23, 2024
1 parent 8d14ba5 commit 87949d0
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions CLIP/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ def __init__(self, model):
ids = enc.encode_ordinary(text)
ids = np.pad(ids, (0, (self.seq_len - len(ids))), "constant")
ids = torch.tensor(ids).unsqueeze(0)
cat_embds.append(self.model.txt_encoder(ids))
with torch.no_grad():
cat_embds.append(self.model.txt_encoder(ids))
self.cat_embds = torch.cat(cat_embds, dim=0)

def evaludate(self):
def evaluate(self):
correct = 0
processed = 0
idx = 0
Expand All @@ -56,7 +57,8 @@ def evaludate(self):
/ 255.0
)
image = torch.tensor(image).unsqueeze(0).permute(0, 3, 1, 2)
image_embd = self.model.img_encoder(image)
with torch.no_grad():
image_embd = self.model.img_encoder(image)
logits_per_image = image_embd @ self.cat_embds.T
# _, _max = torch.max(logits_per_image, dim=-1)
# if _max.item() == correct_index:
Expand All @@ -83,7 +85,7 @@ def __init__(self, ckpt_path, dataset):
self.dataset = self.class_map[dataset](model)

def evaluate(self):
self.dataset.evaludate()
self.dataset.evaluate()


if __name__ == "__main__":
Expand Down

0 comments on commit 87949d0

Please sign in to comment.