Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

from_tensors support for VideoClassification #1389

Merged
merged 48 commits into from
Sep 1, 2022

Conversation

krshrimali
Copy link
Contributor

@krshrimali krshrimali commented Jul 14, 2022

What does this PR do?

Addresses #1356. A sample script to try it out:

import torch
from flash.video import VideoClassifier, VideoClassificationData
import flash

# 5 number of frames, 3 channels, height = 10 and width = 10
mock_tensors = torch.randint(size=(3, 5, 10, 10), low=0, high=255)
datamodule = VideoClassificationData.from_tensors(
    train_data=[mock_tensors, mock_tensors],  # can also stack: torch.stack((mock_tensors, mock_tensors))
    train_targets=["patient", "doctor"],
    predict_data=[mock_tensors],
    batch_size=1,
)

model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50", labels=datamodule.labels)
trainer = flash.Trainer(fast_dev_run=True)
trainer.finetune(model, datamodule=datamodule)

Output:

  | Name          | Type       | Params
---------------------------------------------
0 | train_metrics | ModuleDict | 0
1 | val_metrics   | ModuleDict | 0
2 | test_metrics  | ModuleDict | 0
3 | backbone      | Net        | 32.5 M
4 | head          | Sequential | 802
---------------------------------------------
32.5 M    Trainable params
0         Non-trainable params
32.5 M    Total params
129.820   Total estimated model params size (MB)
/home/krshrimali/anaconda3/envs/pl-17/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 24 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/krshrimali/anaconda3/envs/pl-17/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1938: PossibleUserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 0: 100%|| 1/1.0 [00:01<00:00,  1.60s/it, loss=0.744, v_num=, train_accuracy_step=0.000, train_cross_entropy_step=0.744,

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? [not needed for typos/docs]
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@codecov
Copy link

codecov bot commented Jul 14, 2022

Codecov Report

Merging #1389 (1cf3d75) into master (0253d71) will decrease coverage by 0.06%.
The diff coverage is 82.92%.

@@            Coverage Diff             @@
##           master    #1389      +/-   ##
==========================================
- Coverage   92.90%   92.83%   -0.07%     
==========================================
  Files         286      287       +1     
  Lines       12874    12969      +95     
==========================================
+ Hits        11960    12040      +80     
- Misses        914      929      +15     
Flag Coverage Δ
unittests 92.83% <82.92%> (-0.07%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
flash/video/classification/input.py 90.37% <76.74%> (-6.48%) ⬇️
flash/video/classification/utils.py 87.87% <87.87%> (ø)
flash/video/classification/data.py 100.00% <100.00%> (ø)
flash/core/finetuning.py 88.23% <0.00%> (-2.47%) ⬇️
flash/core/utilities/imports.py 91.47% <0.00%> (+0.04%) ⬆️
flash/core/classification.py 95.75% <0.00%> (+0.60%) ⬆️
flash/core/serve/dag/task.py 97.88% <0.00%> (+1.05%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@krshrimali krshrimali marked this pull request as draft July 14, 2022 09:24
Copy link
Collaborator

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good, few comments 😃

flash/video/classification/utils.py Outdated Show resolved Hide resolved
flash/video/classification/utils.py Outdated Show resolved Hide resolved
flash/video/classification/utils.py Outdated Show resolved Hide resolved
@ethanwharris ethanwharris self-assigned this Aug 30, 2022
krshrimali and others added 4 commits August 30, 2022 15:43
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
…m:Lightning-AI/lightning-flash into video/feature/classification/from_tensors
Copy link
Collaborator

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, small suggestion

flash/video/classification/utils.py Outdated Show resolved Hide resolved
krshrimali and others added 2 commits August 30, 2022 15:56
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
Copy link
Collaborator

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, just a comment on the API

Comment on lines 643 to 649
>>> datamodule = VideoClassificationData.from_tensors(
... input_field="data",
... target_field="targets",
... train_data=train_data,
... predict_data=predict_data,
... batch_size=1,
... )
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be better to make the API consistent with what we have for image classification. E.g. like this:

            >>> datamodule = VideoClassificationData.from_tensors(
            ...     train_data=[input_video, input_video, input_video],
            ...     train_targets=[1, 2, 3],
            ...     predict_data=predict_data,
            ...     batch_size=1,
            ... )

@ethanwharris ethanwharris self-requested a review August 30, 2022 17:03
Copy link
Collaborator

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just requesting changes so we don't accidentally merge it before agreeing on the API 😃

Copy link
Collaborator

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM 😃

@ethanwharris ethanwharris merged commit 4674aba into master Sep 1, 2022
@ethanwharris ethanwharris deleted the video/feature/classification/from_tensors branch September 1, 2022 16:46
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants