diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7dd7c35377..4e912fec6e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,11 @@ to [Semantic Versioning]. Full commit history is available in the

 #### Added

+- Added adaptive handling of a last training minibatch of 1-2 cells when
+  `datasplitter_kwargs={"drop_last": False}` and `train_size = None`: such cells are moved
+  into the validation set, if one is available.
+  {pr}`3036`.
+
 #### Fixed

 - Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
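For context, a minimal sketch of the user-facing effect described in this entry, assuming the public scvi-tools API (`scvi.data.synthetic_iid`, `scvi.model.SCVI`); the dataset helper and cell counts are illustrative only:

```python
import scvi

# 143 cells with the defaults batch_size=128 and train_size=None:
# ceil(0.9 * 143) = 129 training cells, so the last minibatch would hold a single cell.
adata = scvi.data.synthetic_iid()[:143].copy()
scvi.model.SCVI.setup_anndata(adata)
model = scvi.model.SCVI(adata)

# With drop_last=False (the default) and train_size left as None, the splitter is expected
# to move the leftover cell into the validation set and emit a warning, rather than
# train on a 1-cell batch.
model.train(max_epochs=1, batch_size=128)
```

Passing an explicit `train_size` opts out of this behavior and keeps the previous warning-only handling.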
+ f" if you want to avoid it please use train_size parameter during train.", + UserWarning, + stacklevel=settings.warnings_stacklevel, + ) + return n_train, n_val def validate_data_split_with_external_indexing( n_samples: int, external_indexing: list[np.array, np.array, np.array] | None = None, + batch_size: int | None = None, + drop_last: bool | int = False, ): """Check data splitting parameters and return n_train and n_val. @@ -79,6 +107,10 @@ def validate_data_split_with_external_indexing( external_indexing A list of data split indices in the order of training, validation, and test sets. Validation and test set are not required and can be left empty. + batch_size + batch size of each iteration. If `None`, do not minibatch + drop_last + drops last non-full batch """ if not isinstance(external_indexing, list): raise ValueError("External indexing is not of list type") @@ -132,6 +164,18 @@ def validate_data_split_with_external_indexing( n_train = len(external_indexing[0]) n_val = len(external_indexing[1]) + if batch_size is not None: + num_of_cells = n_train % batch_size + if (num_of_cells < 3 and num_of_cells > 0) and drop_last is False: + warnings.warn( + f"Last batch will have a small size of {num_of_cells} " + f"samples. Consider changing settings.batch_size or batch_size in model.train " + f"from currently {settings.batch_size} to avoid errors during model training " + f"or change the given external indices accordingly.", + UserWarning, + stacklevel=settings.warnings_stacklevel, + ) + return n_train, n_val @@ -145,7 +189,8 @@ class DataSplitter(pl.LightningDataModule): adata_manager :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``. train_size - float, or None (default is 0.9) + float, or None (default is None, which is practicaly 0.9 and potentially adding small last + batch to validation cells) validation_size float, or None (default is None) shuffle_set_split @@ -182,7 +227,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`, def __init__( self, adata_manager: AnnDataManager, - train_size: float = 0.9, + train_size: float | None = None, validation_size: float | None = None, shuffle_set_split: bool = True, load_sparse_tensor: bool = False, @@ -192,7 +237,8 @@ def __init__( ): super().__init__() self.adata_manager = adata_manager - self.train_size = float(train_size) + self.train_size_is_none = not bool(train_size) + self.train_size = 0.9 if self.train_size_is_none else float(train_size) self.validation_size = validation_size self.shuffle_set_split = shuffle_set_split self.load_sparse_tensor = load_sparse_tensor @@ -205,10 +251,17 @@ def __init__( self.n_train, self.n_val = validate_data_split_with_external_indexing( self.adata_manager.adata.n_obs, self.external_indexing, + self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.drop_last, ) else: self.n_train, self.n_val = validate_data_split( - self.adata_manager.adata.n_obs, self.train_size, self.validation_size + self.adata_manager.adata.n_obs, + self.train_size, + self.validation_size, + self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.drop_last, + self.train_size_is_none, ) def setup(self, stage: str | None = None): @@ -298,7 +351,8 @@ class SemiSupervisedDataSplitter(pl.LightningDataModule): adata_manager :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``. 
@@ -298,7 +351,8 @@ class SemiSupervisedDataSplitter(pl.LightningDataModule):
     adata_manager
         :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
     train_size
-        float, or None (default is 0.9)
+        float, or None (default is None; when None, 0.9 is used and a trailing minibatch of
+        1-2 cells may be moved to the validation set)
     validation_size
         float, or None (default is None)
     shuffle_set_split
@@ -333,7 +387,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
     def __init__(
         self,
         adata_manager: AnnDataManager,
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         n_samples_per_label: int | None = None,
@@ -343,7 +397,8 @@ def __init__(
     ):
         super().__init__()
         self.adata_manager = adata_manager
-        self.train_size = float(train_size)
+        self.train_size_is_none = not bool(train_size)
+        self.train_size = 0.9 if self.train_size_is_none else float(train_size)
         self.validation_size = validation_size
         self.shuffle_set_split = shuffle_set_split
         self.drop_last = kwargs.pop("drop_last", False)
@@ -379,10 +434,17 @@ def setup(self, stage: str | None = None):
             n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
                 n_labeled_idx,
                 [labeled_idx_train, labeled_idx_val, labeled_idx_test],
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
             )
         else:
             n_labeled_train, n_labeled_val = validate_data_split(
-                n_labeled_idx, self.train_size, self.validation_size
+                n_labeled_idx,
+                self.train_size,
+                self.validation_size,
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
+                self.train_size_is_none,
             )

         labeled_permutation = self._labeled_indices
@@ -413,10 +475,17 @@ def setup(self, stage: str | None = None):
             n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
                 n_unlabeled_idx,
                 [unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
             )
         else:
             n_unlabeled_train, n_unlabeled_val = validate_data_split(
-                n_unlabeled_idx, self.train_size, self.validation_size
+                n_unlabeled_idx,
+                self.train_size,
+                self.validation_size,
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
+                self.train_size_is_none,
             )

         unlabeled_permutation = self._unlabeled_indices
@@ -508,7 +577,8 @@ class DeviceBackedDataSplitter(DataSplitter):
     adata_manager
         :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
     train_size
-        float, or None (default is 0.9)
+        float, or None (default is None; when None, 0.9 is used and a trailing minibatch of
+        1-2 cells may be moved to the validation set)
     validation_size
         float, or None (default is None)
     %(param_accelerator)s
@@ -536,7 +606,7 @@ class DeviceBackedDataSplitter(DataSplitter):
     def __init__(
         self,
         adata_manager: AnnDataManager,
-        train_size: float = 1.0,
+        train_size: float | None = None,
         validation_size: float | None = None,
         accelerator: str = "auto",
         device: int | str = "auto",
diff --git a/src/scvi/external/cellassign/_model.py b/src/scvi/external/cellassign/_model.py
index a6ecaeb5ef..6fa0d1baaa 100644
--- a/src/scvi/external/cellassign/_model.py
+++ b/src/scvi/external/cellassign/_model.py
@@ -143,7 +143,7 @@ def train(
         lr: float = 3e-3,
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         batch_size: int = 1024,
diff --git a/src/scvi/external/contrastivevi/_contrastive_data_splitting.py b/src/scvi/external/contrastivevi/_contrastive_data_splitting.py
index 659a91f406..aef13b8b6a 100644
--- a/src/scvi/external/contrastivevi/_contrastive_data_splitting.py
+++ b/src/scvi/external/contrastivevi/_contrastive_data_splitting.py
@@ -53,7 +53,7 @@ def __init__(
         adata_manager: AnnDataManager,
         background_indices: list[int],
         target_indices: list[int],
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         load_sparse_tensor: bool = False,
@@ -78,10 +78,20 @@ def __init__(
         self.n_target = len(target_indices)
         if external_indexing is None:
             self.n_background_train, self.n_background_val = validate_data_split(
-                self.n_background, self.train_size, self.validation_size
+                self.n_background,
+                self.train_size,
+                self.validation_size,
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
+                self.train_size_is_none,
             )
             self.n_target_train, self.n_target_val = validate_data_split(
-                self.n_target, self.train_size, self.validation_size
+                self.n_target,
+                self.train_size,
+                self.validation_size,
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
+                self.train_size_is_none,
             )
         else:
             # we need to intersect the external indexing given with the bg/target indices
@@ -93,6 +103,8 @@ def __init__(
                 validate_data_split_with_external_indexing(
                     self.n_background,
                     [self.background_train_idx, self.background_val_idx, self.background_test_idx],
+                    self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                    self.drop_last,
                 )
             )
             self.background_train_idx, self.background_val_idx, self.background_test_idx = (
@@ -107,6 +119,8 @@ def __init__(
             self.n_target_train, self.n_target_val = validate_data_split_with_external_indexing(
                 self.n_target,
                 [self.target_train_idx, self.target_val_idx, self.target_test_idx],
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
             )
             self.target_train_idx, self.target_val_idx, self.target_test_idx = (
                 self.target_train_idx.tolist(),
diff --git a/src/scvi/external/contrastivevi/_model.py b/src/scvi/external/contrastivevi/_model.py
index c5622d4588..fb31e14604 100644
--- a/src/scvi/external/contrastivevi/_model.py
+++ b/src/scvi/external/contrastivevi/_model.py
@@ -136,7 +136,7 @@ def train(
         max_epochs: int | None = None,
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         load_sparse_tensor: bool = False,
diff --git a/src/scvi/external/gimvi/_model.py b/src/scvi/external/gimvi/_model.py
index 8cb78eaa9f..ac7b80f508 100644
--- a/src/scvi/external/gimvi/_model.py
+++ b/src/scvi/external/gimvi/_model.py
@@ -172,7 +172,7 @@ def train(
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
         kappa: int = 5,
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         batch_size: int = 128,
diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py
index 853e1a4a63..e7e7b7c2fa 100644
--- a/src/scvi/external/mrvi/_model.py
+++ b/src/scvi/external/mrvi/_model.py
@@ -201,7 +201,7 @@ def train(
         max_epochs: int | None = None,
         accelerator: str | None = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         batch_size: int = 128,
         early_stopping: bool = False,
diff --git a/src/scvi/external/scbasset/_model.py b/src/scvi/external/scbasset/_model.py
index c93ed9a254..daabbe6e36 100644
--- a/src/scvi/external/scbasset/_model.py
+++ b/src/scvi/external/scbasset/_model.py
@@ -106,7 +106,7 @@ def train(
         lr: float = 0.01,
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         batch_size: int = 128,
diff --git a/src/scvi/external/solo/_model.py b/src/scvi/external/solo/_model.py
index 69f29bade4..e7ee325d24 100644
--- a/src/scvi/external/solo/_model.py
+++ b/src/scvi/external/solo/_model.py
@@ -293,7 +293,7 @@ def train(
         lr: float = 1e-3,
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         batch_size: int = 128,
diff --git a/src/scvi/external/velovi/_model.py b/src/scvi/external/velovi/_model.py
index 39c5bc9c5f..3b5b339099 100644
--- a/src/scvi/external/velovi/_model.py
+++ b/src/scvi/external/velovi/_model.py
@@ -129,7 +129,7 @@ def train(
         weight_decay: float = 1e-2,
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         batch_size: int = 256,
         early_stopping: bool = True,
diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py
index e17fa012ca..85d307a25c 100644
--- a/src/scvi/model/_multivi.py
+++ b/src/scvi/model/_multivi.py
@@ -231,7 +231,7 @@ def train(
         lr: float = 1e-4,
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         batch_size: int = 128,
diff --git a/src/scvi/model/_peakvi.py b/src/scvi/model/_peakvi.py
index a7204c17f3..1c5e011bba 100644
--- a/src/scvi/model/_peakvi.py
+++ b/src/scvi/model/_peakvi.py
@@ -151,7 +151,7 @@ def train(
         lr: float = 1e-4,
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         batch_size: int = 128,
diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py
index dfab56bb74..b596e48afa 100644
--- a/src/scvi/model/_scanvi.py
+++ b/src/scvi/model/_scanvi.py
@@ -356,7 +356,7 @@ def train(
         max_epochs: int | None = None,
         n_samples_per_label: float | None = None,
         check_val_every_n_epoch: int | None = None,
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         batch_size: int = 128,
diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py
index 929ee1ad6c..a50c56e3ee 100644
--- a/src/scvi/model/_totalvi.py
+++ b/src/scvi/model/_totalvi.py
@@ -197,7 +197,7 @@ def train(
         lr: float = 4e-3,
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         batch_size: int = 256,
diff --git a/src/scvi/model/base/_jaxmixin.py b/src/scvi/model/base/_jaxmixin.py
index 02a17a4418..bcaabbad0d 100644
--- a/src/scvi/model/base/_jaxmixin.py
+++ b/src/scvi/model/base/_jaxmixin.py
@@ -24,7 +24,7 @@ def train(
         max_epochs: int | None = None,
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         batch_size: int = 128,
diff --git a/src/scvi/model/base/_pyromixin.py b/src/scvi/model/base/_pyromixin.py
index 8623764be3..f02de87343 100755
--- a/src/scvi/model/base/_pyromixin.py
+++ b/src/scvi/model/base/_pyromixin.py
@@ -92,7 +92,7 @@ def train(
         max_epochs: int | None = None,
         accelerator: str = "auto",
         device: int | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         batch_size: int = 128,
diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py
index 82e83704fd..ebace98445 100644
--- a/src/scvi/model/base/_training_mixin.py
+++ b/src/scvi/model/base/_training_mixin.py
@@ -24,7 +24,7 @@ def train(
         max_epochs: int | None = None,
         accelerator: str = "auto",
         devices: int | list[int] | str = "auto",
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         load_sparse_tensor: bool = False,
diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py
index 36a3b7f7cd..49fe18e531 100644
--- a/tests/model/test_scvi.py
+++ b/tests/model/test_scvi.py
@@ -477,10 +477,21 @@ def test_scvi_n_obs_error(n_latent: int = 5):
     model = SCVI(adata, n_latent=n_latent)
     with pytest.raises(ValueError):
         model.train(1, train_size=1.0)
-    with pytest.warns(UserWarning):
-        # Warning is emitted if last batch less than 3 cells.
+    with pytest.raises(ValueError):
+        # A warning is emitted if the last batch has fewer than 3 cells, and training then fails.
         model.train(1, train_size=1.0, batch_size=127)
     model.train(1, train_size=1.0, datasplitter_kwargs={"drop_last": True})
+
+    adata = synthetic_iid()
+    adata = adata[0:143].copy()
+    SCVI.setup_anndata(adata)
+    model = SCVI(adata, n_latent=n_latent)
+    with pytest.raises(ValueError):
+        model.train(1, train_size=0.9)  # np.ceil(n_cells * 0.9) % 128 == 1
+    model.train(
+        1, train_size=0.9, datasplitter_kwargs={"drop_last": True}
+    )  # np.ceil(n_cells * 0.9) % 128 == 1
+    model.train(1)
     assert model.is_trained is True
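As the updated test suggests, the adaptive move only applies to the default `train_size=None`; a short usage note, reusing the hypothetical 143-cell `model` from the earlier sketch:

```python
# Explicit train_size keeps the previous behavior: the splitter only warns, and the
# 1-cell last batch may still make training fail (the test above expects a ValueError).
# model.train(1, train_size=0.9)

# Dropping the last non-full minibatch avoids the tiny batch entirely.
model.train(1, train_size=0.9, datasplitter_kwargs={"drop_last": True})

# Default train_size=None: the leftover cell is moved to the validation set instead.
model.train(1)
```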