fix(train): in case of last batch <=2, move to validation if possible #3036
@@ -21,7 +21,14 @@
 from scvi.utils._docstrings import devices_dsp


-def validate_data_split(n_samples: int, train_size: float, validation_size: float | None = None):
+def validate_data_split(
+    n_samples: int,
+    train_size: float,
+    validation_size: float | None = None,
+    batch_size: int | None = None,
+    drop_last: bool | int = False,
+    train_size_is_none: bool | int = True,
+):
     """Check data splitting parameters and return n_train and n_val.

     Parameters
@@ -32,21 +39,18 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: float | None = None):
         Size of train set. Need to be: 0 < train_size <= 1.
     validation_size
         Size of validation set. Need to be 0 <= validation_size < 1
+    batch_size
+        batch size of each iteration. If `None`, do not minibatch
+    drop_last
+        drops last non-full batch
+    train_size_is_none
+        Whether the user did not explicitly input train_size
     """
     if train_size > 1.0 or train_size <= 0.0:
         raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")

     n_train = ceil(train_size * n_samples)

-    if n_train % settings.batch_size < 3 and n_train % settings.batch_size > 0:
-        warnings.warn(
-            f"Last batch will have a small size of {n_train % settings.batch_size}"
-            f"samples. Consider changing settings.batch_size or batch_size in model.train"
-            f"currently {settings.batch_size} to avoid errors during model training.",
-            UserWarning,
-            stacklevel=settings.warnings_stacklevel,
-        )
-
     if validation_size is None:
         n_val = n_samples - n_train
     elif validation_size >= 1.0 or validation_size < 0.0:
@@ -59,16 +63,41 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: float | None = None):
     if n_train == 0:
         raise ValueError(
             f"With n_samples={n_samples}, train_size={train_size} and "
-            f"validation_size={validation_size}, the resulting train set will be empty. Adjust"
+            f"validation_size={validation_size}, the resulting train set will be empty. Adjust "
             "any of the aforementioned parameters."
         )

+    if batch_size is not None:
+        num_of_cells = n_train % batch_size
+        if (num_of_cells < 3 and num_of_cells > 0) and not (
+            num_of_cells == 1 and drop_last is True
+        ):
+            warnings.warn(
+                f"Last batch will have a small size of {num_of_cells} "
+                f"samples. Consider changing settings.batch_size or batch_size in model.train "
+                f"from currently {batch_size} to avoid errors during model training, "
+                f"or use drop_last parameter if there is 1 cell left",
+                UserWarning,
+                stacklevel=settings.warnings_stacklevel,
+            )

[Review thread on this warning message]
Review comment: This one is wrong.
Reply: What's wrong?

+            if train_size_is_none:
+                n_train -= num_of_cells
+                if n_val > 0:
+                    n_val += num_of_cells
+                    warnings.warn(
+                        f"{num_of_cells} cells moved from training set to validation set",
+                        UserWarning,
+                        stacklevel=settings.warnings_stacklevel,
+                    )

[Review thread on this warning message]
Review comment: Add to the message: "to avoid error during training. Set train_size to a fixed size to avoid this."
Reply: It is 0.9 if given as None during init. Did you mean something else?
Review comment: I think we want a message on how to avoid moving cells from training to validation. The user avoids it by stating train_size=0.9.
Reply: OK.

     return n_train, n_val


 def validate_data_split_with_external_indexing(
     n_samples: int,
     external_indexing: list[np.array, np.array, np.array] | None = None,
+    batch_size: int | None = None,
+    drop_last: bool | int = False,
 ):
     """Check data splitting parameters and return n_train and n_val.
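For orientation, here is a minimal, self-contained sketch of the splitting behaviour this hunk introduces. It is not the library code: the real validate_data_split also validates validation_size, pulls defaults from scvi's settings, and respects drop_last; the function name split_counts and its defaults below are made up for illustration.

```python
import warnings
from math import ceil


def split_counts(n_samples: int, train_size: float | None = None, batch_size: int = 128):
    # When train_size is not given, fall back to 0.9 and allow a tiny trailing
    # batch (1-2 cells) to be shifted into the validation set.
    train_size_is_none = train_size is None
    train_frac = 0.9 if train_size_is_none else float(train_size)
    n_train = ceil(train_frac * n_samples)
    n_val = n_samples - n_train

    num_of_cells = n_train % batch_size
    if 0 < num_of_cells < 3:
        warnings.warn(
            f"Last batch will have a small size of {num_of_cells} samples.",
            UserWarning,
        )
        if train_size_is_none and n_val > 0:
            # Rebalance only when the user did not pin train_size explicitly.
            n_train -= num_of_cells
            n_val += num_of_cells
    return n_train, n_val


print(split_counts(1000))  # (900, 100): 900 % 128 == 4, nothing moves
print(split_counts(996))   # (896, 100): the single leftover cell moves to validation
```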
@@ -79,6 +108,10 @@ def validate_data_split_with_external_indexing(
     external_indexing
         A list of data split indices in the order of training, validation, and test sets.
         Validation and test set are not required and can be left empty.
+    batch_size
+        batch size of each iteration. If `None`, do not minibatch
+    drop_last
+        drops last non-full batch
     """
     if not isinstance(external_indexing, list):
         raise ValueError("External indexing is not of list type")
@@ -132,6 +165,21 @@ def validate_data_split_with_external_indexing(
     n_train = len(external_indexing[0])
     n_val = len(external_indexing[1])

+    if batch_size is not None:
+        num_of_cells = n_train % batch_size
+        if (num_of_cells < 3 and num_of_cells > 0) and not (
+            num_of_cells == 1 and drop_last is True
+        ):

[Review comment] Update drop_last here.

+            warnings.warn(
+                f"Last batch will have a small size of {num_of_cells} "
+                f"samples. Consider changing settings.batch_size or batch_size in model.train "
+                f"from currently {settings.batch_size} to avoid errors during model training "
+                f"or change the given external indices accordingly or use drop_last parameter if "
+                f"there is 1 cell left",
+                UserWarning,
+                stacklevel=settings.warnings_stacklevel,
+            )

[Review thread on this warning message]
Review comment: This one too.
Review comment: This is confusing. See the comment above; line 179 should be gone.
Reply: Yes.

     return n_train, n_val
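As a usage illustration only (the array sizes and variable names here are made up, not taken from the PR), external_indexing is a list of three index arrays in train/validation/test order, and the new check looks at the length of the first array:

```python
import numpy as np

n_obs = 1_000
permutation = np.random.permutation(n_obs)
external_indexing = [
    permutation[:898],     # training indices: 898 % 128 == 2, so the warning would fire
    permutation[898:950],  # validation indices
    permutation[950:],     # test indices
]
n_train = len(external_indexing[0])
print(n_train % 128)  # 2 cells would be left in the last training batch
```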
@@ -145,7 +193,8 @@ class DataSplitter(pl.LightningDataModule):
     adata_manager
         :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
     train_size
-        float, or None (default is 0.9)
+        float, or None (default is None, which is practically 0.9 and potentially adding a
+        small last batch to the validation cells)
     validation_size
         float, or None (default is None)
     shuffle_set_split
@@ -182,7 +231,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
     def __init__(
         self,
         adata_manager: AnnDataManager,
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         load_sparse_tensor: bool = False,
@@ -192,7 +241,8 @@ def __init__(
     ):
         super().__init__()
         self.adata_manager = adata_manager
-        self.train_size = float(train_size)
+        self.train_size_is_none = not bool(train_size)
+        self.train_size = 0.9 if self.train_size_is_none else float(train_size)
         self.validation_size = validation_size
         self.shuffle_set_split = shuffle_set_split
         self.load_sparse_tensor = load_sparse_tensor
@@ -205,10 +255,17 @@ def __init__(
             self.n_train, self.n_val = validate_data_split_with_external_indexing(
                 self.adata_manager.adata.n_obs,
                 self.external_indexing,
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),

[Review conversation on this line marked as resolved by ori-kron-wis]

+                self.drop_last,
             )
         else:
             self.n_train, self.n_val = validate_data_split(
-                self.adata_manager.adata.n_obs, self.train_size, self.validation_size
+                self.adata_manager.adata.n_obs,
+                self.train_size,
+                self.validation_size,
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
+                self.train_size_is_none,
             )

     def setup(self, stage: str | None = None):
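From the caller's side, the two constructor calls below behave differently after this change. This is a hedged usage sketch against the version in this PR: scvi.data.synthetic_iid and scvi.model.SCVI are existing scvi helpers used only to obtain an AnnDataManager, and batch_size is assumed to be forwarded through the data-loader kwargs as the hunk above shows.

```python
import scvi
from scvi.dataloaders import DataSplitter

# Small synthetic dataset shipped with scvi, just to have an AnnDataManager;
# whether the small-last-batch warning fires depends on its number of cells.
adata = scvi.data.synthetic_iid()
scvi.model.SCVI.setup_anndata(adata)
adata_manager = scvi.model.SCVI(adata).adata_manager

# train_size=None (the new default): 90/10 split, and a trailing batch of 1-2
# cells would be shifted into validation with a warning.
splitter = DataSplitter(adata_manager, train_size=None, batch_size=128)
splitter.setup()
print(splitter.n_train, splitter.n_val)

# train_size=0.9 given explicitly: the sizes are honored as given; only the
# warning about a small last batch is emitted, nothing is moved.
splitter_fixed = DataSplitter(adata_manager, train_size=0.9, batch_size=128)
splitter_fixed.setup()
```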
@@ -298,7 +355,8 @@ class SemiSupervisedDataSplitter(pl.LightningDataModule):
     adata_manager
         :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
     train_size
-        float, or None (default is 0.9)
+        float, or None (default is None, which is practically 0.9 and potentially adding a
+        small last batch to the validation cells)
     validation_size
         float, or None (default is None)
     shuffle_set_split
@@ -333,7 +391,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
     def __init__(
         self,
         adata_manager: AnnDataManager,
-        train_size: float = 0.9,
+        train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
         n_samples_per_label: int | None = None,
@@ -343,7 +401,8 @@ def __init__(
     ):
         super().__init__()
         self.adata_manager = adata_manager
-        self.train_size = float(train_size)
+        self.train_size_is_none = not bool(train_size)
+        self.train_size = 0.9 if train_size is None else float(train_size)
         self.validation_size = validation_size
         self.shuffle_set_split = shuffle_set_split
         self.drop_last = kwargs.pop("drop_last", False)
@@ -379,10 +438,17 @@ def setup(self, stage: str | None = None):
             n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
                 n_labeled_idx,
                 [labeled_idx_train, labeled_idx_val, labeled_idx_test],
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
             )
         else:
             n_labeled_train, n_labeled_val = validate_data_split(
-                n_labeled_idx, self.train_size, self.validation_size
+                n_labeled_idx,
+                self.train_size,
+                self.validation_size,
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
+                self.train_size_is_none,
             )

         labeled_permutation = self._labeled_indices
@@ -413,10 +479,17 @@ def setup(self, stage: str | None = None):
             n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
                 n_unlabeled_idx,
                 [unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
             )
         else:
             n_unlabeled_train, n_unlabeled_val = validate_data_split(
-                n_unlabeled_idx, self.train_size, self.validation_size
+                n_unlabeled_idx,
+                self.train_size,
+                self.validation_size,
+                self.data_loader_kwargs.pop("batch_size", settings.batch_size),
+                self.drop_last,
+                self.train_size_is_none,
             )

         unlabeled_permutation = self._unlabeled_indices
@@ -508,7 +581,8 @@ class DeviceBackedDataSplitter(DataSplitter):
     adata_manager
         :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
     train_size
-        float, or None (default is 0.9)
+        float, or None (default is None, which is practically 0.9 and potentially adding a
+        small last batch to the validation cells)
     validation_size
         float, or None (default is None)
     %(param_accelerator)s
@@ -536,7 +610,7 @@ class DeviceBackedDataSplitter(DataSplitter):
     def __init__(
         self,
        adata_manager: AnnDataManager,
-        train_size: float = 1.0,
+        train_size: float | None = None,
         validation_size: float | None = None,
         accelerator: str = "auto",
         device: int | str = "auto",
[Review thread on the small-last-batch check]
Review comment: This one is confusing: if drop_last is set, the last batch is dropped no matter how many cells it contains.
Reply: True, but this part is not about the drop_last logic; it is about when to show the warning. The warning is shown whenever there are 1-2 cells in the last batch, except when the user selected drop_last and exactly 1 cell is left: whoever did that does not need to see it, because training will not fail for them. Only if train_size_is_none as well are cells adaptively transferred to the validation set, if one exists.
Review comment: Again, the check doesn't need to be num_of_cells == 1 and drop_last is True. It should be: (num_of_cells < 3 and num_of_cells > 0) and drop_last is False.
Reply: OK.
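To make the difference concrete, here is a small, self-contained comparison of the two guards; the function names are made up for illustration and this is not library code.

```python
def warn_needed_pr(num_of_cells: int, drop_last: bool) -> bool:
    # Guard as written in this diff: suppress the warning only for the
    # one-leftover-cell case when drop_last was requested.
    return (num_of_cells < 3 and num_of_cells > 0) and not (num_of_cells == 1 and drop_last)


def warn_needed_suggested(num_of_cells: int, drop_last: bool) -> bool:
    # Guard as the reviewer suggests: drop_last silences the warning for both
    # the 1- and 2-cell cases, since a dropped last batch cannot break training.
    return (num_of_cells < 3 and num_of_cells > 0) and drop_last is False


# The two only disagree when 2 cells are left over and drop_last is enabled:
for cells in (0, 1, 2, 3):
    for drop in (False, True):
        print(cells, drop, warn_needed_pr(cells, drop), warn_needed_suggested(cells, drop))
```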