Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix Video DDP (#1189)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Feb 23, 2022
1 parent 0cd7bb6 commit de4e856
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 10 deletions.
1 change: 1 addition & 0 deletions .azure-pipelines/gpu-example-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
- "image"
- "text"
- "tabular"
+ - "video"
gpu_inds:
- "0"
- "0,1"
2 changes: 1 addition & 1 deletion .azure-pipelines/testing-template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- bash: |
python -c "import torch; print(f'found GPUs: {torch.cuda.device_count()}')"
- python -m coverage run --source flash -m pytest flash tests/examples/test_scripts.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30
+ python -m coverage run --source flash -m pytest tests/examples/test_scripts.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30
env:
CUDA_VISIBLE_DEVICES: ${{gids}}
displayName: 'Testing'
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where DDP would not work with Flash tasks ([#1182](https://github.com/PyTorchLightning/lightning-flash/pull/1182))

+ - Fixed DDP support for `VideoClassifier` ([#1189](https://github.com/PyTorchLightning/lightning-flash/pull/1189))

## [0.7.0] - 2022-02-15

### Added
Expand Down
1 change: 0 additions & 1 deletion flash/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __init__(self, *args, **kwargs):
if flash._IS_TESTING:
if torch.cuda.is_available():
kwargs["gpus"] = -1
- kwargs["max_epochs"] = 3
kwargs["limit_train_batches"] = 1.0
kwargs["limit_val_batches"] = 1.0
kwargs["limit_test_batches"] = 1.0
Expand Down
9 changes: 2 additions & 7 deletions flash/video/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def __init__(

def on_train_start(self) -> None:
if accelerator_connector(self.trainer).is_distributed:
- encoded_dataset = self.trainer.train_dataloader.loaders.dataset.dataset
+ encoded_dataset = self.trainer.train_dataloader.loaders.dataset.data
encoded_dataset._video_sampler = DistributedSampler(encoded_dataset._labeled_videos)
super().on_train_start()

def on_train_epoch_start(self) -> None:
if accelerator_connector(self.trainer).is_distributed:
- encoded_dataset = self.trainer.train_dataloader.loaders.dataset.dataset
+ encoded_dataset = self.trainer.train_dataloader.loaders.dataset.data
encoded_dataset._video_sampler.set_epoch(self.trainer.current_epoch)
super().on_train_epoch_start()

Expand All @@ -147,8 +147,3 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
def modules_to_freeze(self) -> Union[nn.Module, Iterable[Union[nn.Module, Iterable]]]:
"""Return the module attributes of the model to be frozen."""
return list(self.backbone.children())

- @staticmethod
- def _ci_benchmark_fn(history: List[Dict[str, Any]]):
- """This function is used only for debugging usage with CI."""
- assert history[-1]["val_accuracy"] > 0.70
4 changes: 3 additions & 1 deletion flash_examples/video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
model = VideoClassifier(backbone="x3d_xs", labels=datamodule.labels, pretrained=False)

# 3. Create the trainer and finetune the model
- trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count(), fast_dev_run=True)
+ trainer = flash.Trainer(
+ max_epochs=1, gpus=torch.cuda.device_count(), strategy="ddp" if torch.cuda.device_count() > 1 else None
+ )
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Make a prediction
Expand Down

0 comments on commit de4e856

Please sign in to comment.