
Commit

update args.
TianyuDu committed Jun 27, 2023
1 parent eb1dd51 commit afbedfb
Showing 1 changed file with 4 additions and 5 deletions.
bemb/utils/run_helper_lightning.py (4 additions & 5 deletions)
@@ -23,12 +23,11 @@ def run(model: "LitBEMBFlex",
         dataset_train: ChoiceDataset,
         dataset_val: Optional[ChoiceDataset]=None,
         dataset_test: Optional[ChoiceDataset]=None,
-        # model_optimizer: str='Adam',
         batch_size: int=-1,
-        # learning_rate: float=0.01,
         num_epochs: int=10,
         num_workers: int=0,
         device: Optional[str]=None,
+        check_val_every_n_epoch: Optional[int]=None,
         **kwargs) -> "LitBEMBFlex":
     """_summary_
@@ -37,13 +36,13 @@ def run(model: "LitBEMBFlex",
         dataset_train (ChoiceDataset): the dataset for training.
         dataset_val (ChoiceDataset): an optional dataset for validation.
         dataset_test (ChoiceDataset): an optional dataset for testing.
-        model_optimizer (str): the optimizer used to estimate the model. Defaults to 'Adam'.
         batch_size (int, optional): batch size for model training. Defaults to -1.
-        learning_rate (float, optional): learning rate for model training. Defaults to 0.01.
         num_epochs (int, optional): number of epochs for the training. Defaults to 10.
         num_workers (int, optional): number of parallel workers for data loading. Defaults to 0.
         device (Optional[str], optional): the device that trains the model, if None is specified, the function will
             use the current device of the provided model. Defaults to None.
+        check_val_every_n_epoch (Optional[int], optional): the frequency of validation, if None is specified,
+            validation will be performed every 10% of total epochs. Defaults to None.
         **kwargs: other keyword arguments for the pytorch lightning trainer, this is for users with experience in
             pytorch lightning and wish to customize the training process.
@@ -93,7 +92,7 @@ def run(model: "LitBEMBFlex",
     trainer = pl.Trainer(accelerator="cuda" if "cuda" in device else device,  # note: "cuda:0" is not a accelerator name.
                          devices="auto",
                          max_epochs=num_epochs,
-                         check_val_every_n_epoch=num_epochs // 10,
+                         check_val_every_n_epoch=num_epochs // 10 if check_val_every_n_epoch is None else check_val_every_n_epoch,
                          log_every_n_steps=1,
                          callbacks=callbacks,
                          **kwargs)
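
For context, a minimal usage sketch of run() after this change. The model, datasets, and the batch size / epoch values below are hypothetical placeholders, not part of the commit; only the check_val_every_n_epoch argument is new. Leaving check_val_every_n_epoch as None keeps the previous behaviour of validating every num_epochs // 10 epochs; passing an explicit value overrides it.

# Hypothetical usage sketch; `model` (a LitBEMBFlex) and the ChoiceDataset
# objects `dataset_train` / `dataset_val` are assumed to be built elsewhere.
from bemb.utils.run_helper_lightning import run

trained = run(
    model,
    dataset_train,
    dataset_val=dataset_val,
    batch_size=1024,
    num_epochs=50,
    check_val_every_n_epoch=5,  # new argument: run validation every 5 epochs;
                                # None (the default) falls back to num_epochs // 10
)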
