Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

write out prediction json at test step during training #47

Merged
32 changes: 32 additions & 0 deletions tcn_hpl/callbacks/plot_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytorch_lightning.utilities.types import STEP_OUTPUT
from sklearn.metrics import confusion_matrix
import torch
import kwcoco

try:
from aim import Image
Expand Down Expand Up @@ -345,6 +346,7 @@ def on_test_batch_end(
self._val_all_targets.append(outputs["targets"].cpu())
self._val_all_source_vids.append(outputs["source_vid"].cpu())
self._val_all_source_frames.append(outputs["source_frame"].cpu())
self._preds_dset_output_fpath = self.output_dir / "tcn_activity_predictions.kwcoco.json"

def on_test_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
Expand All @@ -362,6 +364,36 @@ def on_test_epoch_end(
test_acc = pl_module.test_acc.compute()
test_f1 = pl_module.test_f1.compute()

# Create activity predictions KWCOCO JSON
truth_dset_fpath = trainer.datamodule.hparams["coco_test_activities"]
truth_dset = kwcoco.CocoDataset(truth_dset_fpath)
acts_dset = kwcoco.CocoDataset()
acts_dset.fpath = self._preds_dset_output_fpath
acts_dset.dataset['videos'] = truth_dset.dataset['videos']
acts_dset.dataset['images'] = truth_dset.dataset['images']
acts_dset.dataset['categories'] = truth_dset.dataset['categories']
acts_dset.index.build(acts_dset)
# Create numpy lookup tables
for i in range(len(all_preds)):
frame_index = all_source_frames[i].item()
video_id = all_source_vids[i].item()
# Now get the image_id that matches the frame_index and video_id.
sorted_img_ids_for_one_video = acts_dset.index.vidid_to_gids[int(video_id)]
image_id = sorted_img_ids_for_one_video[frame_index]
# Sanity check: this image_id corresponds to the frame_index and video_id
assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index
assert acts_dset.index.imgs[image_id]['video_id'] == video_id

acts_dset.add_annotation(
image_id=image_id,
category_id=all_preds[i].item(),
score=all_probs[i][all_preds[i]].item(),
prob=all_probs[i].numpy().tolist(),
)
print(f"Dumping activities file to {acts_dset.fpath}")
acts_dset.dump(acts_dset.fpath, newlines=True)


#
# Plot per-video class predictions vs. GT across progressive frames in
# that video.
Expand Down
Loading