Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(train): in case of last batch <=2, move to validation if possible #3036

Merged
merged 21 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
a7cbfd8
In case of last batch==1 , we moved one sample from validation to tra…
ori-kron-wis Oct 31, 2024
7208f46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
ae3f9fc
changed to auto moving up to 2 cells from train to valid if needed. b…
ori-kron-wis Nov 3, 2024
41416dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2024
21f4fcf
typo
ori-kron-wis Nov 3, 2024
b6e760e
Merge remote-tracking branch 'origin/Ori-3035-fix_min_bathc_size_less…
ori-kron-wis Nov 3, 2024
1e8550f
typo
ori-kron-wis Nov 3, 2024
a38423c
fix empty kwargs element case
ori-kron-wis Nov 3, 2024
7fda257
added release note
ori-kron-wis Nov 3, 2024
ad1e1b0
added release note
ori-kron-wis Nov 3, 2024
a247b95
Merge branch 'main' into Ori-3035-fix_min_bathc_size_less_than1
ori-kron-wis Nov 13, 2024
e278ce2
fix comments. we might want to update tutorials as well.
ori-kron-wis Nov 13, 2024
fb8a949
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
8accba4
fix pre-commit bug
ori-kron-wis Nov 13, 2024
90e15a7
Merge remote-tracking branch 'origin/Ori-3035-fix_min_bathc_size_less…
ori-kron-wis Nov 13, 2024
f2bd7cd
Merge remote-tracking branch 'origin/main' into Ori-3035-fix_min_bath…
ori-kron-wis Nov 14, 2024
5c80cc8
fix comments, revert logic
ori-kron-wis Nov 14, 2024
12a4719
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2024
5d921f6
fix comments
ori-kron-wis Nov 18, 2024
9f78358
fix comments
ori-kron-wis Nov 19, 2024
19bb8bf
fix comments
ori-kron-wis Nov 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- Added adaptive handling for a last training minibatch of 1-2 cells in case of
  `datasplitter_kwargs={"drop_last": False}` and `train_size = None` by moving them into the
  validation set, if available.
  {pr}`3036`.

#### Fixed

- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
Expand Down
114 changes: 92 additions & 22 deletions src/scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@
from scvi.utils._docstrings import devices_dsp


def validate_data_split(n_samples: int, train_size: float, validation_size: float | None = None):
def validate_data_split(
n_samples: int,
train_size: float,
validation_size: float | None = None,
batch_size: int | None = None,
drop_last: bool | int = False,
train_size_is_none: bool | int = True,
):
"""Check data splitting parameters and return n_train and n_val.

Parameters
Expand All @@ -32,21 +39,18 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: floa
Size of train set. Need to be: 0 < train_size <= 1.
validation_size
Size of validation set. Need to be 0 <= validation_size < 1
batch_size
batch size of each iteration. If `None`, do not minibatch
drop_last
drops last non-full batch
train_size_is_none
Whether the user did not explicitly input train_size
"""
if train_size > 1.0 or train_size <= 0.0:
raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")

n_train = ceil(train_size * n_samples)

if n_train % settings.batch_size < 3 and n_train % settings.batch_size > 0:
warnings.warn(
f"Last batch will have a small size of {n_train % settings.batch_size}"
f"samples. Consider changing settings.batch_size or batch_size in model.train"
f"currently {settings.batch_size} to avoid errors during model training.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

if validation_size is None:
n_val = n_samples - n_train
elif validation_size >= 1.0 or validation_size < 0.0:
Expand All @@ -59,16 +63,40 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: floa
if n_train == 0:
raise ValueError(
f"With n_samples={n_samples}, train_size={train_size} and "
f"validation_size={validation_size}, the resulting train set will be empty. Adjust"
f"validation_size={validation_size}, the resulting train set will be empty. Adjust "
"any of the aforementioned parameters."
)

if batch_size is not None:
num_of_cells = n_train % batch_size
if (num_of_cells < 3 and num_of_cells > 0) and drop_last is False:
if not train_size_is_none:
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {batch_size} to avoid errors during model training.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
else:
n_train -= num_of_cells
if n_val > 0:
n_val += num_of_cells
warnings.warn(
f"{num_of_cells} cells moved from training set to validation set."
f" if you want to avoid it please use train_size parameter during train.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

return n_train, n_val


def validate_data_split_with_external_indexing(
n_samples: int,
external_indexing: list[np.array, np.array, np.array] | None = None,
batch_size: int | None = None,
drop_last: bool | int = False,
):
"""Check data splitting parameters and return n_train and n_val.

Expand All @@ -79,6 +107,10 @@ def validate_data_split_with_external_indexing(
external_indexing
A list of data split indices in the order of training, validation, and test sets.
Validation and test set are not required and can be left empty.
batch_size
batch size of each iteration. If `None`, do not minibatch
drop_last
drops last non-full batch
"""
if not isinstance(external_indexing, list):
raise ValueError("External indexing is not of list type")
Expand Down Expand Up @@ -132,6 +164,18 @@ def validate_data_split_with_external_indexing(
n_train = len(external_indexing[0])
n_val = len(external_indexing[1])

if batch_size is not None:
num_of_cells = n_train % batch_size
if (num_of_cells < 3 and num_of_cells > 0) and drop_last is False:
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {settings.batch_size} to avoid errors during model training "
f"or change the given external indices accordingly.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

return n_train, n_val


Expand All @@ -145,7 +189,8 @@ class DataSplitter(pl.LightningDataModule):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9, potentially moving a small last
batch into the validation set)
validation_size
float, or None (default is None)
shuffle_set_split
Expand Down Expand Up @@ -182,7 +227,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand All @@ -192,7 +237,8 @@ def __init__(
):
super().__init__()
self.adata_manager = adata_manager
self.train_size = float(train_size)
self.train_size_is_none = not bool(train_size)
self.train_size = 0.9 if self.train_size_is_none else float(train_size)
self.validation_size = validation_size
self.shuffle_set_split = shuffle_set_split
self.load_sparse_tensor = load_sparse_tensor
Expand All @@ -205,10 +251,17 @@ def __init__(
self.n_train, self.n_val = validate_data_split_with_external_indexing(
self.adata_manager.adata.n_obs,
self.external_indexing,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
self.drop_last,
)
else:
self.n_train, self.n_val = validate_data_split(
self.adata_manager.adata.n_obs, self.train_size, self.validation_size
self.adata_manager.adata.n_obs,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)

def setup(self, stage: str | None = None):
Expand Down Expand Up @@ -298,7 +351,8 @@ class SemiSupervisedDataSplitter(pl.LightningDataModule):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9, potentially moving a small last
batch into the validation set)
validation_size
float, or None (default is None)
shuffle_set_split
Expand Down Expand Up @@ -333,7 +387,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
n_samples_per_label: int | None = None,
Expand All @@ -343,7 +397,8 @@ def __init__(
):
super().__init__()
self.adata_manager = adata_manager
self.train_size = float(train_size)
self.train_size_is_none = not bool(train_size)
self.train_size = 0.9 if self.train_size_is_none else float(train_size)
self.validation_size = validation_size
self.shuffle_set_split = shuffle_set_split
self.drop_last = kwargs.pop("drop_last", False)
Expand Down Expand Up @@ -379,10 +434,17 @@ def setup(self, stage: str | None = None):
n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
n_labeled_idx,
[labeled_idx_train, labeled_idx_val, labeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_labeled_train, n_labeled_val = validate_data_split(
n_labeled_idx, self.train_size, self.validation_size
n_labeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)

labeled_permutation = self._labeled_indices
Expand Down Expand Up @@ -413,10 +475,17 @@ def setup(self, stage: str | None = None):
n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
n_unlabeled_idx,
[unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_unlabeled_train, n_unlabeled_val = validate_data_split(
n_unlabeled_idx, self.train_size, self.validation_size
n_unlabeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)

unlabeled_permutation = self._unlabeled_indices
Expand Down Expand Up @@ -508,7 +577,8 @@ class DeviceBackedDataSplitter(DataSplitter):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9, potentially moving a small last
batch into the validation set)
validation_size
float, or None (default is None)
%(param_accelerator)s
Expand Down Expand Up @@ -536,7 +606,7 @@ class DeviceBackedDataSplitter(DataSplitter):
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 1.0,
train_size: float | None = None,
validation_size: float | None = None,
accelerator: str = "auto",
device: int | str = "auto",
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/cellassign/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def train(
lr: float = 3e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 1024,
Expand Down
20 changes: 17 additions & 3 deletions src/scvi/external/contrastivevi/_contrastive_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
adata_manager: AnnDataManager,
background_indices: list[int],
target_indices: list[int],
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand All @@ -78,10 +78,20 @@ def __init__(
self.n_target = len(target_indices)
if external_indexing is None:
self.n_background_train, self.n_background_val = validate_data_split(
self.n_background, self.train_size, self.validation_size
self.n_background,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
self.n_target_train, self.n_target_val = validate_data_split(
self.n_target, self.train_size, self.validation_size
self.n_target,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
else:
# we need to intersect the external indexing given with the bg/target indices
Expand All @@ -93,6 +103,8 @@ def __init__(
validate_data_split_with_external_indexing(
self.n_background,
[self.background_train_idx, self.background_val_idx, self.background_test_idx],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
)
self.background_train_idx, self.background_val_idx, self.background_test_idx = (
Expand All @@ -107,6 +119,8 @@ def __init__(
self.n_target_train, self.n_target_val = validate_data_split_with_external_indexing(
self.n_target,
[self.target_train_idx, self.target_val_idx, self.target_test_idx],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
self.target_train_idx, self.target_val_idx, self.target_test_idx = (
self.target_train_idx.tolist(),
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/contrastivevi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def train(
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
kappa: int = 5,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def train(
max_epochs: int | None = None,
accelerator: str | None = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
batch_size: int = 128,
early_stopping: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/scbasset/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def train(
lr: float = 0.01,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/solo/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def train(
lr: float = 1e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/velovi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def train(
weight_decay: float = 1e-2,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
batch_size: int = 256,
early_stopping: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def train(
lr: float = 1e-4,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_peakvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def train(
lr: float = 1e-4,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
Loading