Skip to content

Commit

Permalink
Fix tests when PyTorch built with MPS support (#1188)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored and calebrob6 committed Apr 10, 2023
1 parent 22d633c commit d2902a5
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 18 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ filterwarnings = [
"ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:pretrainedmodels.datasets.utils",
# https://github.com/pytorch/vision/pull/5898
"ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:torchvision.transforms.functional_pil",
"ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:torchvision.transforms._functional_pil",
# https://github.com/rwightman/pytorch-image-models/pull/1256
"ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:timm.data",
# https://github.com/pytorch/pytorch/issues/72906
Expand Down Expand Up @@ -101,8 +102,11 @@ filterwarnings = [
# Expected warnings
# Lightning warns us about using num_workers=0, but it's faster on macOS
"ignore:The dataloader, .*, does not have many workers which may be a bottleneck:UserWarning",
# Lightning warns us about using the CPU when a GPU is available
# Lightning warns us about using the CPU when GPU/MPS is available
"ignore:GPU available but not used.:UserWarning",
"ignore:MPS available but not used.:UserWarning",
# Lightning warns us if TensorBoard is not installed
"ignore:Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package:UserWarning",
# https://github.com/kornia/kornia/pull/1611
"ignore:`ColorJitter` is now following Torchvision implementation.:DeprecationWarning:kornia.augmentation._2d.intensity.color_jitter",
# https://github.com/kornia/kornia/pull/1663
Expand Down
14 changes: 13 additions & 1 deletion tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def __init__(self) -> None:


class TestGeoDataModule:
@pytest.fixture(params=[SamplerGeoDataModule, BatchSamplerGeoDataModule])
def datamodule(self, request: SubRequest) -> CustomGeoDataModule:
dm: CustomGeoDataModule = request.param()
dm.trainer = Trainer(accelerator="cpu", max_epochs=1)
return dm

@pytest.mark.parametrize("stage", ["fit", "validate", "test"])
def test_setup(self, stage: str) -> None:
dm = CustomGeoDataModule()
Expand Down Expand Up @@ -102,7 +108,13 @@ def test_no_datasets(self) -> None:


class TestNonGeoDataModule:
@pytest.mark.parametrize("stage", ["fit", "validate", "test"])
@pytest.fixture
def datamodule(self) -> CustomNonGeoDataModule:
dm = CustomNonGeoDataModule()
dm.trainer = Trainer(accelerator="cpu", max_epochs=1)
return dm

@pytest.mark.parametrize("stage", ["fit", "validate", "test", "predict"])
def test_setup(self, stage: str) -> None:
dm = CustomNonGeoDataModule()
dm.prepare_data()
Expand Down
2 changes: 1 addition & 1 deletion tests/datamodules/test_oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def datamodule(self, request: SubRequest) -> OSCDDataModule:
num_workers=0,
)
dm.prepare_data()
dm.trainer = Trainer(max_epochs=1)
dm.trainer = Trainer(accelerator="cpu", max_epochs=1)
return dm

def test_train_dataloader(self, datamodule: OSCDDataModule) -> None:
Expand Down
7 changes: 6 additions & 1 deletion tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,12 @@ def test_trainer(
model.backbone = SegmentationTestModel(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down
42 changes: 36 additions & 6 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ def test_trainer(
model = ClassificationTask(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down Expand Up @@ -192,15 +197,25 @@ def test_no_rgb(
root="tests/data/eurosat", batch_size=1, num_workers=0
)
model = ClassificationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None:
datamodule = PredictClassificationDataModule(
root="tests/data/eurosat", batch_size=1, num_workers=0
)
model = ClassificationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)


Expand Down Expand Up @@ -234,7 +249,12 @@ def test_trainer(
model = MultiLabelClassificationTask(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down Expand Up @@ -269,13 +289,23 @@ def test_no_rgb(
root="tests/data/bigearthnet", batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None:
datamodule = PredictMultiLabelClassificationDataModule(
root="tests/data/bigearthnet", batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)
21 changes: 18 additions & 3 deletions tests/trainers/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ def test_trainer(
model = ObjectDetectionTask(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down Expand Up @@ -131,13 +136,23 @@ def test_no_rgb(
root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0
)
model = ObjectDetectionTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None:
datamodule = PredictObjectDetectionDataModule(
root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0
)
model = ObjectDetectionTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)
21 changes: 18 additions & 3 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,12 @@ def test_trainer(
model.model = RegressionTestModel()

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down Expand Up @@ -160,13 +165,23 @@ def test_no_rgb(
root="tests/data/cyclone", batch_size=1, num_workers=0
)
model = RegressionTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None:
datamodule = PredictRegressionDataModule(
root="tests/data/cyclone", batch_size=1, num_workers=0
)
model = RegressionTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)
14 changes: 12 additions & 2 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ def test_trainer(
model = SemanticSegmentationTask(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down Expand Up @@ -147,5 +152,10 @@ def test_no_rgb(
root="tests/data/sen12ms", batch_size=1, num_workers=0
)
model = SemanticSegmentationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

0 comments on commit d2902a5

Please sign in to comment.