Skip to content

Commit

Permalink
Remove default augs
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 committed Jul 9, 2022
1 parent 8178d74 commit 1072fd5
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions torchgeo/datamodules/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@
from ..samplers.utils import _to_tuple
from .utils import dataset_split

DEFAULT_AUGS = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=["input", "mask"],
)


def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Flatten wrapper."""
Expand Down Expand Up @@ -55,7 +49,6 @@ def __init__(
test_split_pct: float = 0.1,
patch_size: Union[int, Tuple[int, int]] = 512,
num_patches_per_tile: int = 32,
augmentations: K.AugmentationSequential = DEFAULT_AUGS,
predict_on: str = "test",
) -> None:
"""Initialize a LightningDataModule for InriaAerialImageLabeling based DataLoaders.
Expand All @@ -70,7 +63,6 @@ def __init__(
test_split_pct: What percentage of the dataset to use as a test set
patch_size: Size of random patch from image and mask (height, width)
num_patches_per_tile: Number of random patches per sample
augmentations: Default augmentations applied
predict_on: Directory/Dataset of images to run inference on
"""
super().__init__() # type: ignore[no-untyped-call]
Expand All @@ -81,7 +73,11 @@ def __init__(
self.test_split_pct = test_split_pct
self.patch_size = cast(Tuple[int, int], _to_tuple(patch_size))
self.num_patches_per_tile = num_patches_per_tile
self.augmentations = augmentations
self.augmentations = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=["input", "mask"],
)
self.predict_on = predict_on
self.random_crop = K.AugmentationSequential(
K.RandomCrop(self.patch_size, p=1.0, keepdim=False),
Expand Down

0 comments on commit 1072fd5

Please sign in to comment.