diff --git a/bemb/utils/run_helper_lightning.py b/bemb/utils/run_helper_lightning.py index 8df4131..bc4e402 100644 --- a/bemb/utils/run_helper_lightning.py +++ b/bemb/utils/run_helper_lightning.py @@ -23,12 +23,11 @@ def run(model: "LitBEMBFlex", dataset_train: ChoiceDataset, dataset_val: Optional[ChoiceDataset]=None, dataset_test: Optional[ChoiceDataset]=None, - # model_optimizer: str='Adam', batch_size: int=-1, - # learning_rate: float=0.01, num_epochs: int=10, num_workers: int=0, device: Optional[str]=None, + check_val_every_n_epoch: Optional[int]=None, **kwargs) -> "LitBEMBFlex": """_summary_ @@ -37,13 +36,13 @@ def run(model: "LitBEMBFlex", dataset_train (ChoiceDataset): the dataset for training. dataset_val (ChoiceDataset): an optional dataset for validation. dataset_test (ChoiceDataset): an optional dataset for testing. - model_optimizer (str): the optimizer used to estimate the model. Defaults to 'Adam'. batch_size (int, optional): batch size for model training. Defaults to -1. - learning_rate (float, optional): learning rate for model training. Defaults to 0.01. num_epochs (int, optional): number of epochs for the training. Defaults to 10. num_workers (int, optional): number of parallel workers for data loading. Defaults to 0. device (Optional[str], optional): the device that trains the model, if None is specified, the function will use the current device of the provided model. Defaults to None. + check_val_every_n_epoch (Optional[int], optional): the frequency of validation, if None is specified, + validation will be performed every 10% of total epochs. Defaults to None. **kwargs: other keyword arguments for the pytorch lightning trainer, this is for users with experience in pytorch lightning and wish to customize the training process. @@ -93,7 +92,7 @@ def run(model: "LitBEMBFlex", trainer = pl.Trainer(accelerator="cuda" if "cuda" in device else device, # note: "cuda:0" is not a accelerator name. devices="auto", max_epochs=num_epochs, - check_val_every_n_epoch=num_epochs // 10, + check_val_every_n_epoch=num_epochs // 10 if check_val_every_n_epoch is None else check_val_every_n_epoch, log_every_n_steps=1, callbacks=callbacks, **kwargs)