diff --git a/docs/source/index.rst b/docs/source/index.rst
index b089f5f..34cc1c1 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -13,6 +13,7 @@ Welcome to Helios's documentation!
    getting-started
    tutorial
    quick-ref
+   plugins
 
 .. include:: ../../README.rst
diff --git a/docs/source/plugins.rst b/docs/source/plugins.rst
new file mode 100644
index 0000000..02eb03e
--- /dev/null
+++ b/docs/source/plugins.rst
@@ -0,0 +1,380 @@
+Plug-Ins
+############
+
+Plug-in API
+===========
+
+Helios offers a plug-in system that allows users to override certain elements of the
+training loops. All plug-ins *must* derive from the main
+:py:class:`~helios.plugins.plugin.Plugin` interface. The list of functions that are
+available is similar to the ones offered by the :py:class:`~helios.model.model.Model`
+class and follows the same call order. For example, the training loop would look
+something like this:
+
+.. code-block:: python
+
+    plugin.on_training_start()
+    model.on_training_start()
+    model.train()
+    for epoch in epochs:
+        model.on_training_epoch_start()
+
+        for batch in dataloader:
+            plugin.process_training_batch(batch)
+            model.on_training_batch_start()
+            model.train_step()
+            model.on_training_batch_end()
+
+        model.on_training_epoch_end()
+    model.on_training_end()
+    plugin.on_training_end()
+
+Notice that the plug-in functions are always called *before* the corresponding model
+functions. This is to allow the plug-ins to override the model if necessary, or to set
+state that can later be accessed by the model. The model (and the dataloader) can access
+the plug-ins through their reference to the trainer:
+
+.. code-block:: python
+
+    def train_step(...):
+        self.trainer.plugins["foo"]  # <- Access the plug-in with name "foo".
+
+Batch Processing
+----------------
+
+The major difference between the functions of the model and those of the plug-ins is the
+lack of a :py:meth:`~helios.model.model.Model.train_step` function (and similarly for
+validation and testing). Instead, the plug-ins have three equivalent functions:
+
+* :py:meth:`~helios.plugins.plugin.Plugin.process_training_batch`
+* :py:meth:`~helios.plugins.plugin.Plugin.process_validation_batch`
+* :py:meth:`~helios.plugins.plugin.Plugin.process_testing_batch`
+
+These functions receive the batch as an argument and return the processed batch. They can
+be used for a variety of tasks such as moving tensors to a given device, filtering batch
+entries, converting values, etc. For example, suppose we wanted to reduce the training
+batch size by removing elements. We could do this as follows:
+
+.. code-block:: python
+
+    def process_training_batch(self, batch: list[torch.Tensor]) -> list[torch.Tensor]:
+        return batch[:2]  # <- Take the first two elements of the batch.
+
+When the model's :py:meth:`~helios.model.model.Model.train_step` function is called, it
+will only receive the first two tensors of the original batch.
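+
+Similarly, a plug-in can convert values before they reach the model. The following is
+only a sketch with assumed names (the ``NormalizePlugin`` class and the ``"image"`` key
+are hypothetical); the signature mirrors the example above:
+
+.. code-block:: python
+
+    import helios.plugins as hlp
+    import torch
+
+    class NormalizePlugin(hlp.Plugin):
+        def __init__(self) -> None:
+            super().__init__("normalize")
+            # Batch processing is a unique override (see "Unique Traits" below).
+            self.unique_overrides.training_batch = True
+
+        def process_training_batch(
+            self, batch: dict[str, torch.Tensor]
+        ) -> dict[str, torch.Tensor]:
+            # Convert uint8 images to floats in [0, 1] before train_step sees them.
+            batch["image"] = batch["image"].float() / 255.0
+            return batch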
+
+Plug-in Registration
+--------------------
+
+The trainer contains the :py:attr:`~helios.trainer.Trainer.plugins` table in which all
+plug-ins must be registered. To facilitate this, the plug-in base class requires a string
+that acts as the key under which the plug-in is added to the table. In addition, the base
+class provides a function that automatically registers the plug-in in the plug-in table.
+This function can be invoked from the
+:py:meth:`~helios.plugins.plugin.Plugin.configure_trainer` function as follows:
+
+.. code-block:: python
+
+    import helios.plugins as hlp
+    import helios.trainer as hlt
+
+    class MyPlugin(hlp.Plugin):
+        def __init__(self):
+            super().__init__("my_plugin")
+
+        def configure_trainer(self, trainer: hlt.Trainer) -> None:
+            self._register_in_trainer(trainer)  # <- Automatically registers the plug-in.
+
+.. note::
+    All plug-ins that ship with Helios contain a ``plugin_id`` field as a class variable
+    that can be used to easily access them from the trainer table. You are *encouraged*
+    to always use this instead of manually typing in the key. For example, with the
+    :py:class:`~helios.plugins.plugin.CUDAPlugin`, you could access it like this:
+
+    .. code-block:: python
+
+        import helios.plugins as hlp
+        import helios.trainer as hlt
+
+        trainer = hlt.Trainer(...)
+        plugin = hlp.CUDAPlugin()
+        plugin.configure_trainer(trainer)
+        trainer.plugins[hlp.CUDAPlugin.plugin_id]  # <- Access the plug-in like this.
+
+Unique Traits
+-------------
+
+In order to avoid conflicts, the plug-in API designates certain functions as *unique*. In
+this context, a plug-in with a *unique* override may appear exactly *once* in the
+:py:attr:`~helios.trainer.Trainer.plugins` table of the trainer. If a second plug-in with
+that specific override is added, an exception is raised. The full list of overrides can
+be found in the :py:class:`~helios.plugins.plugin.UniquePluginOverrides` struct. Each
+plug-in holds a copy under :py:attr:`~helios.plugins.plugin.Plugin.unique_overrides`,
+which *must* be filled in with the corresponding information for that plug-in.
+
+For example, suppose we want to build a new plug-in that can modify the training batch
+and cause training to stop early. We would then set the structure as follows:
+
+.. code-block:: python
+
+    import helios.plugins as hlp
+
+    class MyPlugin(hlp.Plugin):
+        def __init__(self):
+            super().__init__("my_plugin")
+
+            self.unique_overrides.training_batch = True
+            self.unique_overrides.should_training_stop = True
+
+        def process_training_batch(...):
+            ...
+
+        def should_training_stop(...):
+            ...
+
+.. warning::
+    Attempting to add two plug-ins with the same unique overrides **will** result in an
+    exception being raised.
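+
+For instance, the following sketch registers two plug-ins that both claim the
+``should_training_stop`` override, so the second registration is expected to raise. The
+``EarlyStopPlugin`` class and the names used here are hypothetical:
+
+.. code-block:: python
+
+    import helios.plugins as hlp
+    import helios.trainer as hlt
+
+    class EarlyStopPlugin(hlp.Plugin):
+        def __init__(self, name: str) -> None:
+            super().__init__(name)
+            self.unique_overrides.should_training_stop = True
+
+        def configure_trainer(self, trainer: hlt.Trainer) -> None:
+            self._register_in_trainer(trainer)
+
+        def should_training_stop(self):
+            ...
+
+    trainer = hlt.Trainer(...)
+    EarlyStopPlugin("stop_on_loss").configure_trainer(trainer)
+
+    # The second registration claims an already-taken unique override and raises.
+    EarlyStopPlugin("stop_on_time").configure_trainer(trainer)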
+
+Built-in Plug-ins
+=================
+
+Helios ships with the following built-in plug-ins, which are discussed in the following
+sections:
+
+* :py:class:`~helios.plugins.plugin.CUDAPlugin`
+* :py:class:`~helios.plugins.optuna.OptunaPlugin`
+
+CUDA Plug-in
+------------
+
+The :py:class:`~helios.plugins.plugin.CUDAPlugin` is designed to move tensors from the
+batches returned by the datasets to the current CUDA device. The device is determined by
+the trainer when training starts, using the same logic used to assign the device to the
+model. Specifically:
+
+* If training isn't distributed, the device is the GPU that is used for training.
+* If training is distributed, the device corresponds to the GPU assigned to the given
+  process (i.e. the local rank).
+
+.. warning::
+    As its name implies, the :py:class:`~helios.plugins.plugin.CUDAPlugin` **requires**
+    CUDA to be enabled in order to function. If it isn't, an exception is raised.
+
+The plug-in is designed to handle the following types of batches:
+
+* :py:class:`torch.Tensor`,
+* Lists of :py:class:`torch.Tensor`,
+* Tuples of :py:class:`torch.Tensor`, and
+* Dictionaries whose values are :py:class:`torch.Tensor`.
+
+.. note::
+    The contents of the containers need not be homogeneous. In other words, it is
+    perfectly valid for some entries in a dictionary to *not* be tensors. The plug-in
+    will automatically recognise tensors and move them to the device.
+
+.. warning::
+    The plug-in is **not** designed to handle nested containers. For instance, if your
+    batch is a dictionary containing lists of tensors, then the plug-in will **not**
+    recognise the tensors contained in the lists, and will not move them.
+
+In the event that your batch requires special handling, you can derive from the class and
+override the function that moves the tensors to the device. For example, suppose that our
+batch consists of a dictionary of lists of tensors. Then we would do the following:
+
+.. code-block:: python
+
+    import helios.plugins as hlp
+    import torch
+
+    class MyCUDAPlugin(hlp.CUDAPlugin):
+        # We only need to override this function. Everything else works automatically.
+        def _move_collection_to_device(
+            self, batch: dict[str, list[torch.Tensor]]
+        ) -> dict[str, list[torch.Tensor]]:
+            for key, value in batch.items():
+                for i in range(len(value)):
+                    value[i] = value[i].to(self.device)
+                batch[key] = value
+
+            return batch
+
+.. note::
+    The :py:class:`~helios.plugins.plugin.CUDAPlugin` is automatically registered in the
+    plug-in registry and can therefore be created through the
+    :py:func:`~helios.plugins.plugin.create_plugin` function.
+
+Optuna Plug-in
+--------------
+
+In order to use the Optuna plug-in, we first need to install
+`optuna <https://optuna.org/>`__::
+
+    pip install -U optuna
+
+.. warning::
+    Optuna is a **required** dependency for this plug-in. If it isn't installed, an
+    exception is raised.
+
+The plug-in automatically integrates with Optuna for hyper-parameter optimisation by
+performing the following tasks:
+
+* Register the :py:class:`optuna.TrialPruned` exception type with the trainer for correct
+  trial pruning.
+* Automatically update the :py:class:`~helios.model.model.Model` so the save name is
+  consistent, allowing trials to continue if they're interrupted.
+* Correctly handle reporting and pruning for regular and distributed training.
+
+A full example of how to use this plug-in can be found `here `__, but we will discuss
+the basics below. For the sake of simplicity, the code is identical to the
+`cifar10 `__ example, so we will only focus on the code necessary to use the plug-in.
+
+Plug-in Registration
+^^^^^^^^^^^^^^^^^^^^
+
+After the creation of the :py:class:`~helios.model.model.Model`, the
+:py:class:`~helios.data.datamodule.DataModule`, and the
+:py:class:`~helios.trainer.Trainer`, we can create the plug-in and do the following:
+
+.. code-block:: python
+
+    import helios.plugins.optuna as hlpo
+    import optuna
+
+    def objective(trial: optuna.Trial) -> float:
+        model = ...
+        datamodule = ...
+        trainer = ...
+
+        plugin = hlpo.OptunaPlugin(trial, "accuracy")
+        plugin.configure_trainer(trainer)
+        plugin.configure_model(model)
+
+The two ``configure_`` functions do the following:
+
+#. Configure the trainer so the plug-in is registered into the plug-in table, and ensure
+   that :py:class:`optuna.TrialPruned` is registered with the trainer's exception list so
+   pruned trials are handled correctly.
+#. Configure the name of the model to allow cancelled trials to continue. Specifically,
+   it will append ``_trial-`` followed by the trial number to the model name.
+
+.. note::
+    The call to :py:meth:`~helios.plugins.optuna.OptunaPlugin.configure_model` is
+    completely optional and only affects the ability to resume trials. You may choose to
+    handle this yourself if it makes sense for your use-case.
+
+Using the Trial
+^^^^^^^^^^^^^^^
+
+The trial instance is held by the plug-in and can be easily accessed through the trainer.
+For example, we can use it to configure the layers in the classifier network within the
+:py:meth:`~helios.model.model.Model.setup` function like this:
+
+.. code-block:: python
+
+    def setup(self, fast_init: bool = False) -> None:
+        plugin = self.trainer.plugins[OptunaPlugin.plugin_id]
+        assert isinstance(plugin, OptunaPlugin)
+
+        # Assign the tunable parameters so we can log them as hyper-parameters when
+        # training ends.
+        self._tune_params["l1"] = plugin.trial.suggest_categorical(
+            "l1", [2**i for i in range(9)]
+        )
+        self._tune_params["l2"] = plugin.trial.suggest_categorical(
+            "l2", [2**i for i in range(9)]
+        )
+        self._tune_params["lr"] = plugin.trial.suggest_float("lr", 1e-4, 1e-1, log=True)
+
+        self._net = Net(
+            l1=self._tune_params["l1"],  # type: ignore[arg-type]
+            l2=self._tune_params["l2"],  # type: ignore[arg-type]
+        ).to(self.device)
+
+Reporting Metrics
+^^^^^^^^^^^^^^^^^
+
+As the plug-in automatically handles the reporting of metrics to the trial, it is
+important for it to know which metric should be reported. This is accomplished by two
+things:
+
+#. The :py:attr:`~helios.model.model.Model.metrics` table, and
+#. The value of ``metric_name`` in the constructor of the
+   :py:class:`~helios.plugins.optuna.OptunaPlugin`.
+
+In order for the plug-in to work properly, it assumes that the ``metric_name`` key exists
+in the :py:attr:`~helios.model.model.Model.metrics` table. If it doesn't, nothing is
+reported to the trial. The plug-in will automatically handle distributed training
+correctly, so there's no need for the model to do extra work.
+
+.. warning::
+    In distributed training, it is your responsibility to ensure that the value of the
+    metric is correctly synced across processes (if applicable).
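+
+As a concrete illustration, the model can fill in the table at the end of validation.
+This is only a sketch: the hook used here and the ``self._correct``/``self._total``
+counters are assumptions, and the actual computation depends on your task. The key simply
+has to match the ``metric_name`` passed to the plug-in ("accuracy" in our example):
+
+.. code-block:: python
+
+    def on_validation_end(self) -> None:
+        # The "accuracy" key matches the metric_name given to the OptunaPlugin, so the
+        # plug-in can find the value and report it to the trial.
+        self.metrics["accuracy"] = self._correct / max(self._total, 1)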
+
+Trial Pruning and Returning Metrics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The plug-in will automatically detect if a trial is pruned by Optuna and gracefully
+request that training end. The exact behaviour depends on whether training is distributed
+or not. Specifically:
+
+* If training is not distributed, then the plug-in will raise a
+  :py:class:`optuna.TrialPruned` exception *after* calling
+  :py:meth:`~helios.model.model.Model.on_training_end` on the model. This ensures that
+  any metrics logged when training ends are still recorded even if the trial is pruned.
+* If training is distributed, then the plug-in requests that training terminate early and
+  the normal shutdown flow runs. Once the code exits the
+  :py:meth:`~helios.trainer.Trainer.fit` function, the user should call
+  :py:meth:`~helios.plugins.optuna.OptunaPlugin.check_pruned` to ensure that the
+  corresponding exception is correctly raised.
+
+In code, this can be handled as follows:
+
+.. code-block:: python
+
+    def objective(trial: optuna.Trial) -> float:
+        ...
+        plugin.configure_trainer(trainer)
+        plugin.configure_model(model)
+        trainer.fit(model, datamodule)
+        plugin.check_pruned()
+
+To correctly return metrics, there are two cases that need to be handled. If training
+isn't distributed, then the metrics can be grabbed directly from the
+:py:attr:`~helios.model.model.Model.metrics` table. If training is distributed, then the
+model needs to do a bit more work to ensure things get synchronized correctly. For our
+example, we will place the synchronization of the metrics in
+:py:meth:`~helios.model.model.Model.on_training_end`, but you may place it elsewhere if
+that is more convenient for you:
+
+.. code-block:: python
+
+    def on_training_end(self) -> None:
+        ...
+        # Push the metrics we want to save into the multi-processing queue.
+        if self.is_distributed and self.rank == 0:
+            assert self.trainer.queue is not None
+            self.trainer.queue.put(
+                {"accuracy": accuracy, "loss": self._loss_items["loss"].item()}
+            )
+
+The :py:attr:`~helios.trainer.Trainer.queue` ensures that the values get transferred to
+the primary process. Once that's done, we just need to add the following to our
+``objective`` function:
+
+.. code-block:: python
+
+    def objective(trial: optuna.Trial) -> float:
+        ...
+        plugin.configure_trainer(trainer)
+        plugin.configure_model(model)
+        trainer.fit(model, datamodule)
+        plugin.check_pruned()
+
+        if trainer.queue is None:
+            return model.metrics["accuracy"]
+
+        metrics = trainer.queue.get()
+        return metrics["accuracy"]
diff --git a/docs/source/quick-ref.rst b/docs/source/quick-ref.rst
index bdac8a4..92aeb46 100644
--- a/docs/source/quick-ref.rst
+++ b/docs/source/quick-ref.rst
@@ -603,7 +603,7 @@ The order of the testing functions is identical to the one shown for validation:
 
 .. code-block:: python
 
-    model.eval()()
+    model.eval()
     model.on_testing_start()
     for batch in dataloader:
         model.on_testing_batch_start()
@@ -611,3 +611,66 @@ The order of the testing functions is identical to the one shown for validation:
         model.on_testing_batch_end()
 
     model.on_testing_end()
+
+Exception Handling
+==================
+
+By default, the main functions of :py:class:`~helios.trainer.Trainer` (those being
+:py:meth:`~helios.trainer.Trainer.fit` and :py:meth:`~helios.trainer.Trainer.test`) will
+automatically catch any unhandled exceptions and re-raise them as
+:py:class:`RuntimeError`. Depending on the situation, it may be desirable for certain
+exceptions to be passed through untouched. In order to accommodate this, the trainer has
+two lists of exception types:
+
+* :py:attr:`~helios.trainer.Trainer.train_exceptions` and
+* :py:attr:`~helios.trainer.Trainer.test_exceptions`.
+
+If an exception is raised and said exception is found in the training list (for
+:py:meth:`~helios.trainer.Trainer.fit`) or the testing list (for
+:py:meth:`~helios.trainer.Trainer.test`), then the exception is passed through unchanged.
+Any other exceptions use the default behaviour.
+
+For example, suppose we had a custom exception called ``MyException`` and we wanted that
+exception to be passed through when training because we're going to handle it ourselves.
+We would then do the following:
+
+.. code-block:: python
+
+    import helios.trainer as hlt
+
+    trainer = hlt.Trainer(...)
+    trainer.train_exceptions.append(MyException)
+
+    try:
+        trainer.fit(...)
+    except MyException as e:
+        ...
+
+The same logic applies for testing. This functionality is particularly useful when paired
+with plug-ins.
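+
+For instance, a plug-in can whitelist its own exception type when it configures the
+trainer, which is what the Optuna plug-in does with :py:class:`optuna.TrialPruned`. The
+sketch below uses the hypothetical ``BudgetExceeded`` and ``BudgetPlugin`` names:
+
+.. code-block:: python
+
+    import helios.plugins as hlp
+    import helios.trainer as hlt
+
+    class BudgetExceeded(Exception):
+        """Raised when the compute budget for a run is exhausted."""
+
+    class BudgetPlugin(hlp.Plugin):
+        def __init__(self) -> None:
+            super().__init__("budget")
+
+        def configure_trainer(self, trainer: hlt.Trainer) -> None:
+            self._register_in_trainer(trainer)
+            # Let BudgetExceeded propagate out of fit() instead of being re-raised as a
+            # RuntimeError, so the caller can handle it directly.
+            trainer.train_exceptions.append(BudgetExceeded)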
+
+Synchronization
+===============
+
+Helios provides some synchronization wrappers found in the
+:py:mod:`~helios.core.distributed` module:
+
+* :py:func:`~helios.core.distributed.gather_into_tensor`,
+* :py:func:`~helios.core.distributed.all_reduce_tensors`.
+
+The trainer also provides another way to synchronize values through the multi-processing
+queue. When using distributed training that isn't through ``torchrun``, Helios uses
+``spawn`` to create the processes for each GPU. This triggers a copy of the arguments
+passed in to the handler, which in this case are the trainer, model, and datamodule. This
+presents a problem in the event that we need to return values back to the main process
+once training is complete. To facilitate this task, the trainer will create a `queue `__
+that can be accessed through :py:attr:`~helios.trainer.Trainer.queue`.
+
+.. note::
+    If training isn't distributed, or if it was started through ``torchrun``, then the
+    :py:attr:`~helios.trainer.Trainer.queue` is set to ``None``.
+
+The queue can then be used by either the :py:class:`~helios.model.model.Model`, the
+:py:class:`~helios.data.datamodule.DataModule`, or any plug-in through their reference to
+the trainer.
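+
+For reference, the model-side pattern from the Optuna example above condenses to a simple
+guard that is safe in every mode (non-distributed, ``torchrun``, and spawned distributed
+runs). The ``"accuracy"`` key is simply the metric used throughout this example:
+
+.. code-block:: python
+
+    def on_training_end(self) -> None:
+        ...
+        # queue is None for non-distributed and torchrun runs. Checking the rank avoids
+        # pushing duplicate entries from every process.
+        if self.trainer.queue is not None and self.rank == 0:
+            self.trainer.queue.put({"accuracy": self.metrics["accuracy"]})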