Skip to content

Commit

Permalink
Make number of epochs "finetuning-equivalent" (#344)
Browse files Browse the repository at this point in the history
  • Loading branch information
lballes authored Jul 26, 2023
1 parent 38d9143 commit 740cdf6
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/renate/cli/parsing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def _standard_arguments() -> Dict[str, Dict[str, Any]]:
"max_epochs": {
"type": int,
"default": defaults.MAX_EPOCHS,
"help": f"Number of epochs trained at most. Default: {defaults.MAX_EPOCHS}",
"help": "Maximum number of (finetuning-equivalent) epochs. "
f"Default: {defaults.MAX_EPOCHS}",
"argument_group": OPTIONAL_ARGS_GROUP,
},
"task_id": {
Expand Down
2 changes: 1 addition & 1 deletion src/renate/cli/run_experiment_with_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def run(self):
"--max_epochs",
type=int,
default=defaults.MAX_EPOCHS,
help=f"Number of epochs trained at most. Default: {defaults.MAX_EPOCHS}",
help=f"Maximum number of (finetuning-equiv.) epochs. Default: {defaults.MAX_EPOCHS}",
)
argument_group.add_argument(
"--seed",
Expand Down
4 changes: 3 additions & 1 deletion src/renate/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def run_training_job(
metric: Name of metric to optimize.
backend: Whether to run jobs locally (`local`) or on SageMaker (`sagemaker`).
updater: Updater used for model update.
max_epochs: Maximum number of epochs the model is trained.
max_epochs: The maximum number of epochs used to train the model. For comparability between
methods, epochs are interpreted as "finetuning-equivalent". That is, one epoch is
defined as `ceil(len(current_task_dataset) / batch_size)` update steps.
task_id: Unique identifier for the current task.
chunk_id: Unique identifier for the current data chunk.
input_state_url: Path to the Renate model state.
Expand Down
8 changes: 7 additions & 1 deletion src/renate/updaters/model_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ class ModelUpdater(abc.ABC):
state available) or replace current arguments of the learner.
input_state_folder: Folder used by Renate to store files for current state.
output_state_folder: Folder used by Renate to store files for next state.
max_epochs: The maximum number of epochs used to train the model.
max_epochs: The maximum number of epochs used to train the model. For comparability between
methods, epochs are interpreted as "finetuning-equivalent". That is, one epoch is
defined as `ceil(len(current_task_dataset) / batch_size)` update steps.
train_transform: The transformation applied during training.
train_target_transform: The target transformation applied during training.
test_transform: The transformation at test time.
Expand Down Expand Up @@ -408,10 +410,14 @@ def _fit_learner(
)

strategy = create_strategy(self._devices, self._strategy)
# Finetuning-equivalent epochs: cap each epoch at ceil(len(train_dataset) / batch_size) batches.
num_batches = len(learner._train_dataset) // learner._batch_size
num_batches += min(len(learner._train_dataset) % learner._batch_size, 1)
trainer = Trainer(
accelerator=self._accelerator,
devices=self._devices,
max_epochs=self._max_epochs,
limit_train_batches=num_batches,
callbacks=callbacks,
logger=self._logger,
enable_progress_bar=False,
Expand Down
4 changes: 2 additions & 2 deletions test/integration_tests/configs/suites/quick/joint.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
"dataset": "fashionmnist.json",
"backend": "local",
"job_name": "iid-mlp-joint",
"expected_accuracy_linux": [[0.8639000058174133, 0.8639000058174133], [0.8618000149726868, 0.8618000149726868]],
"expected_accuracy_darwin": [[0.859499990940094, 0.859499990940094]]
"expected_accuracy_linux": [[0.8495000004768372, 0.8495000004768372], [0.8427000045776367, 0.8427000045776367]],
"expected_accuracy_darwin": [[0.84170001745224, 0.84170001745224]]
}

0 comments on commit 740cdf6

Please sign in to comment.