Skip to content

Commit

Permalink
Merge pull request #27 from Laughing-q/instance_seg
Browse files Browse the repository at this point in the history
update F.interpolate&&clean up
  • Loading branch information
AyushExel authored Sep 12, 2022
2 parents 6ddcf5a + 5a9d410 commit 5905c76
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 4 deletions.
2 changes: 1 addition & 1 deletion segment/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def run(
vid_writer[i].write(im0)

# Print time (inference-only)
# LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

# Print results
t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
Expand Down
2 changes: 0 additions & 2 deletions segment/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,6 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
# return

# Mosaic plots
if mask_ratio != 1:
masks = F.interpolate(masks[None].float(), (imgsz, imgsz), mode="bilinear", align_corners=False)[0]
if plots:
if ni < 3:
plot_images_and_masks(imgs, targets, masks, paths, save_dir / f"train_batch{ni}.jpg")
Expand Down
2 changes: 1 addition & 1 deletion utils/segment/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __call__(self, preds, targets, masks): # predictions, targets, model

# Mask regression
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="bilinear", align_corners=False)[0]
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
marea = xywhn[i][:, 2:].prod(1) # mask width, height normalized
mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device))
for bi in b.unique():
Expand Down

0 comments on commit 5905c76

Please sign in to comment.