Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(train): in case of last batch <=2, move to validation if possible #3036

Merged
merged 21 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
a7cbfd8
In case of last batch==1 , we moved one sample from validation to tra…
ori-kron-wis Oct 31, 2024
7208f46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
ae3f9fc
changed to auto moving up to 2 cells from train to valid if needed. b…
ori-kron-wis Nov 3, 2024
41416dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2024
21f4fcf
typo
ori-kron-wis Nov 3, 2024
b6e760e
Merge remote-tracking branch 'origin/Ori-3035-fix_min_bathc_size_less…
ori-kron-wis Nov 3, 2024
1e8550f
typo
ori-kron-wis Nov 3, 2024
a38423c
fix empty kwargs element case
ori-kron-wis Nov 3, 2024
7fda257
added release note
ori-kron-wis Nov 3, 2024
ad1e1b0
added release note
ori-kron-wis Nov 3, 2024
a247b95
Merge branch 'main' into Ori-3035-fix_min_bathc_size_less_than1
ori-kron-wis Nov 13, 2024
e278ce2
fix comments. we might want to update tutorials as well.
ori-kron-wis Nov 13, 2024
fb8a949
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
8accba4
fix pre-commit bug
ori-kron-wis Nov 13, 2024
90e15a7
Merge remote-tracking branch 'origin/Ori-3035-fix_min_bathc_size_less…
ori-kron-wis Nov 13, 2024
f2bd7cd
Merge remote-tracking branch 'origin/main' into Ori-3035-fix_min_bath…
ori-kron-wis Nov 14, 2024
5c80cc8
fix comments, revert logic
ori-kron-wis Nov 14, 2024
12a4719
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2024
5d921f6
fix comments
ori-kron-wis Nov 18, 2024
9f78358
fix comments
ori-kron-wis Nov 19, 2024
19bb8bf
fix comments
ori-kron-wis Nov 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ to [Semantic Versioning]. Full commit history is available in the
#### Added

- Added adaptive handling for last training minibatch of 1-2 cells in case of
`datasplitter_kwargs={"drop_last": False}` by moving them into validation set, if available.
`datasplitter_kwargs={"drop_last": False}` and `train_size = None` by moving them into validation set, if available.
{pr}`3036`.
-

#### Fixed

Expand Down
73 changes: 25 additions & 48 deletions src/scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def validate_data_split(
n_samples: int,
train_size: float,
validation_size: float | None = None,
batch_size: int | None = None,
drop_last: bool | int = False,
batch_size_for_adaptive_last_batch: int | None = None,
):
"""Check data splitting parameters and return n_train and n_val.

Expand All @@ -38,10 +37,8 @@ def validate_data_split(
Size of train set. Need to be: 0 < train_size <= 1.
validation_size
Size of validation set. Need to be 0 <= validation_size < 1
batch_size
batch size of each iteration. If `None`, do not minibatch
drop_last
drops last non-full batch
batch_size_for_adaptive_last_batch
batch size of each iteration. If `None`, do not do adaptive last batch sizing
"""
if train_size > 1.0 or train_size <= 0.0:
raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")
Expand All @@ -64,14 +61,15 @@ def validate_data_split(
"any of the aforementioned parameters."
)

if batch_size is not None and not drop_last:
if n_train % batch_size < 3 and n_train % batch_size > 0:
num_of_cells = n_train % batch_size
if batch_size_for_adaptive_last_batch is not None:
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
if (n_train % batch_size_for_adaptive_last_batch < 3 and
n_train % batch_size_for_adaptive_last_batch > 0):
num_of_cells = n_train % batch_size_for_adaptive_last_batch
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {batch_size} to avoid errors during model training. Those cells "
f"will be removed from the training set automatically",
f"from currently {batch_size_for_adaptive_last_batch} to avoid errors during model"
f" training. Those cells will be removed from the training set automatically",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
Expand All @@ -90,8 +88,6 @@ def validate_data_split(
def validate_data_split_with_external_indexing(
n_samples: int,
external_indexing: list[np.array, np.array, np.array] | None = None,
batch_size: int | None = None,
drop_last: bool | int = False,
):
"""Check data splitting parameters and return n_train and n_val.

Expand All @@ -102,10 +98,6 @@ def validate_data_split_with_external_indexing(
external_indexing
A list of data split indices in the order of training, validation, and test sets.
Validation and test set are not required and can be left empty.
batch_size
batch size of each iteration. If `None`, do not minibatch
drop_last
drops last non-full batch
"""
if not isinstance(external_indexing, list):
raise ValueError("External indexing is not of list type")
Expand Down Expand Up @@ -159,17 +151,6 @@ def validate_data_split_with_external_indexing(
n_train = len(external_indexing[0])
n_val = len(external_indexing[1])

if batch_size is not None and not drop_last:
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
if n_train % batch_size < 3 and n_train % batch_size > 0:
warnings.warn(
f"Last batch will have a small size of {n_train % batch_size} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {settings.batch_size} to avoid errors during model training "
f"or change the given external indices accordingly",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

return n_train, n_val


Expand All @@ -183,7 +164,7 @@ class DataSplitter(pl.LightningDataModule):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9 + adaptive last batch)
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
validation_size
float, or None (default is None)
shuffle_set_split
Expand Down Expand Up @@ -220,7 +201,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand All @@ -230,7 +211,8 @@ def __init__(
):
super().__init__()
self.adata_manager = adata_manager
self.train_size = float(train_size)
self.train_size_was_none = not bool(train_size)
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
self.train_size = 0.9 if self.train_size_was_none else float(train_size)
self.validation_size = validation_size
self.shuffle_set_split = shuffle_set_split
self.load_sparse_tensor = load_sparse_tensor
Expand All @@ -243,16 +225,14 @@ def __init__(
self.n_train, self.n_val = validate_data_split_with_external_indexing(
self.adata_manager.adata.n_obs,
self.external_indexing,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
self.n_train, self.n_val = validate_data_split(
self.adata_manager.adata.n_obs,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.data_loader_kwargs.pop("batch_size", settings.batch_size) if
self.train_size_was_none and not self.drop_last else None,
)

def setup(self, stage: str | None = None):
Expand Down Expand Up @@ -342,7 +322,7 @@ class SemiSupervisedDataSplitter(pl.LightningDataModule):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9 + adaptive last batch)
validation_size
float, or None (default is None)
shuffle_set_split
Expand Down Expand Up @@ -377,7 +357,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
n_samples_per_label: int | None = None,
Expand All @@ -387,7 +367,8 @@ def __init__(
):
super().__init__()
self.adata_manager = adata_manager
self.train_size = float(train_size)
self.train_size_was_none = not bool(train_size)
self.train_size = 0.9 if train_size is None else float(train_size)
self.validation_size = validation_size
self.shuffle_set_split = shuffle_set_split
self.drop_last = kwargs.pop("drop_last", False)
Expand Down Expand Up @@ -423,16 +404,14 @@ def setup(self, stage: str | None = None):
n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
n_labeled_idx,
[labeled_idx_train, labeled_idx_val, labeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_labeled_train, n_labeled_val = validate_data_split(
n_labeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.data_loader_kwargs.pop("batch_size", settings.batch_size) if
self.train_size_was_none and not self.drop_last else None,
)

labeled_permutation = self._labeled_indices
Expand Down Expand Up @@ -463,16 +442,14 @@ def setup(self, stage: str | None = None):
n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
n_unlabeled_idx,
[unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_unlabeled_train, n_unlabeled_val = validate_data_split(
n_unlabeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.data_loader_kwargs.pop("batch_size",settings.batch_size) if
self.train_size_was_none and not self.drop_last else None,
)

unlabeled_permutation = self._unlabeled_indices
Expand Down Expand Up @@ -564,7 +541,7 @@ class DeviceBackedDataSplitter(DataSplitter):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9 + adaptive last batch)
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
validation_size
float, or None (default is None)
%(param_accelerator)s
Expand Down Expand Up @@ -592,7 +569,7 @@ class DeviceBackedDataSplitter(DataSplitter):
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 1.0,
train_size: float | None = None,
validation_size: float | None = None,
accelerator: str = "auto",
device: int | str = "auto",
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/cellassign/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def train(
lr: float = 3e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 1024,
Expand Down
14 changes: 5 additions & 9 deletions src/scvi/external/contrastivevi/_contrastive_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
adata_manager: AnnDataManager,
background_indices: list[int],
target_indices: list[int],
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand Down Expand Up @@ -81,15 +81,15 @@ def __init__(
self.n_background,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.data_loader_kwargs.pop("batch_size", settings.batch_size) if
self.train_size_was_none and not self.drop_last else None,
)
self.n_target_train, self.n_target_val = validate_data_split(
self.n_target,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.data_loader_kwargs.pop("batch_size", settings.batch_size) if
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
self.train_size_was_none and not self.drop_last else None,
)
else:
# we need to intersect the external indexing given with the bg/target indices
Expand All @@ -101,8 +101,6 @@ def __init__(
validate_data_split_with_external_indexing(
self.n_background,
[self.background_train_idx, self.background_val_idx, self.background_test_idx],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
)
self.background_train_idx, self.background_val_idx, self.background_test_idx = (
Expand All @@ -117,8 +115,6 @@ def __init__(
self.n_target_train, self.n_target_val = validate_data_split_with_external_indexing(
self.n_target,
[self.target_train_idx, self.target_val_idx, self.target_test_idx],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
self.target_train_idx, self.target_val_idx, self.target_test_idx = (
self.target_train_idx.tolist(),
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/contrastivevi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def train(
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
kappa: int = 5,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def train(
max_epochs: int | None = None,
accelerator: str | None = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
batch_size: int = 128,
early_stopping: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/scbasset/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def train(
lr: float = 0.01,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/solo/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def train(
lr: float = 1e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/velovi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def train(
weight_decay: float = 1e-2,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
batch_size: int = 256,
early_stopping: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def train(
lr: float = 1e-4,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_peakvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def train(
lr: float = 1e-4,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def train(
max_epochs: int | None = None,
n_samples_per_label: float | None = None,
check_val_every_n_epoch: int | None = None,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def train(
lr: float = 4e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 256,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/base/_jaxmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/base/_pyromixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
device: int | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/base/_training_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand Down
Loading
Loading