diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ff2b0d277..996fe9519f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,10 +10,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `deeplabv3`, `lraspp`, and `unet` backbones for the `SemanticSegmentation` task ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370)) -### Fixed - -- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343)) - ### Changed - Changed the installation command for extra features ([#346](https://github.com/PyTorchLightning/lightning-flash/pull/346)) @@ -26,6 +22,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `SemanticSegmentation` backbone names `torchvision/fcn_resnet50` and `torchvision/fcn_resnet101`, use `fc_resnet50` and `fcn_resnet101` instead ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370)) +### Fixed + +- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343)) +- Fixed a bug where using `val_split` with `overfit_batches` would give an infinite recursion ([#375](https://github.com/PyTorchLightning/lightning-flash/pull/375)) + ## [0.3.0] - 2021-05-20 diff --git a/flash/core/data/splits.py b/flash/core/data/splits.py index 8c09ad2290..451a658d92 100644 --- a/flash/core/data/splits.py +++ b/flash/core/data/splits.py @@ -13,7 +13,7 @@ class SplitDataset(Dataset): dataset: A dataset to be splitted indices: List of indices to expose from the dataset - use_duplicated_indices: Wether to allow duplicated indices. + use_duplicated_indices: Whether to allow duplicated indices. Example:: @@ -41,9 +41,9 @@ def __init__(self, dataset: Any, indices: List[int] = [], use_duplicated_indices self.indices = indices def __getattr__(self, key: str): - if key in self._INTERNAL_KEYS: - return getattr(self, key) - return getattr(self.dataset, key) + if key not in self._INTERNAL_KEYS: + return self.dataset.__getattribute__(key) + raise AttributeError def __setattr__(self, name: str, value: Any) -> None: if name in self._INTERNAL_KEYS: diff --git a/tests/core/data/test_split_dataset.py b/tests/core/data/test_splits.py similarity index 82% rename from tests/core/data/test_split_dataset.py rename to tests/core/data/test_splits.py index 0a450d7ad6..14e7f12993 100644 --- a/tests/core/data/test_split_dataset.py +++ b/tests/core/data/test_splits.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy + import numpy as np import pytest from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -19,25 +21,12 @@ from flash.core.data.splits import SplitDataset -def test_split_dataset(tmpdir): - +def test_split_dataset(): train_ds, val_ds = DataModule._split_train_val(range(100), val_split=0.1) assert len(train_ds) == 90 assert len(val_ds) == 10 assert len(np.unique(train_ds.indices)) == len(train_ds.indices) - with pytest.raises(MisconfigurationException, match="[0, 99]"): - SplitDataset(range(100), indices=[100]) - - with pytest.raises(MisconfigurationException, match="[0, 49]"): - SplitDataset(range(50), indices=[-1]) - - with pytest.raises(MisconfigurationException, match="[0, 49]"): - SplitDataset(list(range(50)) + list(range(50)), indices=[-1]) - - with pytest.raises(MisconfigurationException, match="[0, 99]"): - SplitDataset(list(range(50)) + list(range(50)), indices=[-1], use_duplicated_indices=True) - class Dataset: def __init__(self): @@ -57,3 +46,27 @@ def __len__(self): split_dataset.is_passed_down = True assert split_dataset.dataset.is_passed_down + + +def test_misconfiguration(): + with pytest.raises(MisconfigurationException, match="[0, 99]"): + SplitDataset(range(100), indices=[100]) + + with pytest.raises(MisconfigurationException, match="[0, 49]"): + SplitDataset(range(50), indices=[-1]) + + with pytest.raises(MisconfigurationException, match="[0, 49]"): + SplitDataset(list(range(50)) + list(range(50)), indices=[-1]) + + with pytest.raises(MisconfigurationException, match="[0, 99]"): + SplitDataset(list(range(50)) + list(range(50)), indices=[-1], use_duplicated_indices=True) + + with pytest.raises(MisconfigurationException, match="indices should be a list"): + SplitDataset(list(range(100)), indices="not a list") + + +def test_deepcopy(): + """Tests that deepcopy works with the ``SplitDataset``.""" + dataset = list(range(100)) + train_ds, val_ds = DataModule._split_train_val(dataset, val_split=0.1) + deepcopy(train_ds)