[Tests][Doc] Remove LightningTrainer PBT Examples and Test #36476

Merged
109 changes: 8 additions & 101 deletions doc/source/tune/examples/tune-pytorch-lightning.ipynb
@@ -355,6 +355,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -363,7 +364,13 @@
"In this example, we use an [Asynchronous Hyperband](https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/)\n",
"scheduler. This scheduler decides at each iteration which trials are likely to perform\n",
"badly, and stops these trials. This way we don't waste any resources on bad hyperparameter\n",
"configurations."
"configurations.\n",
"\n",
":::{note}\n",
"\n",
" Currently, `LightningTrainer` is not compatible with {class}`PopulationBasedTraining <ray.tune.schedulers.PopulationBasedTraining>` scheduler, which may mutate hyperparameters during training time. \n",
"\n",
":::"
]
},
{
@@ -498,106 +505,6 @@
"`layer_1_size=32`, `layer_2_size=64`, and `lr=0.000489046`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using Population Based Training to find the best parameters\n",
"\n",
"The `ASHAScheduler` terminates those trials early that show bad performance.\n",
"Sometimes, this stops trials that would get better after more training steps,\n",
"and which might eventually even show better performance than other configurations.\n",
"\n",
"Another popular method for hyperparameter tuning, called\n",
"[Population Based Training](https://deepmind.com/blog/article/population-based-training-neural-networks),\n",
"instead perturbs hyperparameters during the training run. Tune implements PBT, and\n",
"we only need to make some slight adjustments to our code."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def tune_mnist_pbt(num_samples=10):\n",
" # The range of hyperparameter perturbation.\n",
" lightning_config_mutations = (\n",
" LightningConfigBuilder()\n",
" .module(\n",
" config={\n",
" \"lr\": tune.loguniform(1e-4, 1e-1),\n",
" }\n",
" )\n",
" .build()\n",
" )\n",
"\n",
" # Create a PBT scheduler\n",
" scheduler = PopulationBasedTraining(\n",
" perturbation_interval=1,\n",
" time_attr=\"training_iteration\",\n",
" hyperparam_mutations={\"lightning_config\": lightning_config_mutations},\n",
" )\n",
"\n",
" tuner = tune.Tuner(\n",
" lightning_trainer,\n",
" param_space={\"lightning_config\": searchable_lightning_config},\n",
" tune_config=tune.TuneConfig(\n",
" metric=\"ptl/val_accuracy\",\n",
" mode=\"max\",\n",
" num_samples=num_samples,\n",
" scheduler=scheduler,\n",
" ),\n",
" run_config=air.RunConfig(\n",
" name=\"tune_mnist_pbt\",\n",
" ),\n",
" )\n",
" results = tuner.fit()\n",
" best_result = results.get_best_result(metric=\"ptl/val_accuracy\", mode=\"max\")\n",
" best_result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tune_mnist_pbt(num_samples=num_samples)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An example output of a run could look like this:\n",
"\n",
"```bash\n",
":emphasize-lines: 12\n",
"\n",
" +------------------------------+------------+-------+----------------+----------------+---------------------+-----------+--------------------+----------------------+\n",
" | Trial name | status | loc | layer_1_size | layer_2_size | lr | loss | ptl/val_accuracy | training_iteration |\n",
" |------------------------------+------------+-------+----------------+----------------+---------------------+-----------+--------------------+----------------------|\n",
" | LightningTrainer_85489_00000 | TERMINATED | | 64 | 64 | 0.0030@perturbed... | 0.108734 | 0.984954 | 5 |\n",
" | LightningTrainer_85489_00001 | TERMINATED | | 32 | 256 | 0.0010@perturbed... | 0.093577 | 0.983411 | 5 |\n",
" | LightningTrainer_85489_00002 | TERMINATED | | 128 | 64 | 0.0233@perturbed... | 0.0922348 | 0.983989 | 5 |\n",
" | LightningTrainer_85489_00003 | TERMINATED | | 64 | 128 | 0.0002@perturbed... | 0.124648 | 0.98206\t | 5 |\n",
" | LightningTrainer_85489_00004 | TERMINATED | | 128 | 256 | 0.0021 | 0.101717 | 0.993248 | 5 |\n",
" | LightningTrainer_85489_00005 | TERMINATED | | 32 | 128 | 0.0003@perturbed... | 0.121467 | 0.984182 | 5 |\n",
" | LightningTrainer_85489_00006 | TERMINATED | | 128 | 64 | 0.0020@perturbed... | 0.053446 | 0.984375 | 5 |\n",
" | LightningTrainer_85489_00007 | TERMINATED | | 64 | 64 | 0.0063@perturbed... | 0.129804 | 0.98669\t | 5 |\n",
" | LightningTrainer_85489_00008 | TERMINATED | | 128 | 256 | 0.0436@perturbed... | 0.363236 | 0.982253 | 5 |\n",
" | LightningTrainer_85489_00009 | TERMINATED | | 128 | 256 | 0.001 | 0.150946 | 0.985147 | 5 |\n",
" +------------------------------+------------+-------+----------------+----------------+---------------------+-----------+--------------------+----------------------+\n",
"```\n",
"\n",
"As you can see, each sample ran the full number of 5 iterations.\n",
"All trials ended with quite good parameter combinations and showed relatively good performances (above `0.98`).\n",
"In some runs, the parameters have been perturbed. And the best configuration even reached a mean validation accuracy of `0.993248`!\n",
"\n",
"In summary, AIR LightningTrainer is easy to extend to use with Tune. It only required adding a few lines of code to integrate with Ray Tuner to get great performing parameter configurations."
]
},
{
"attachments": {},
"cell_type": "markdown",
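The notebook cells removed above relied on PBT's defining behavior: perturbing hyperparameters while trials are still training, which is exactly what the new note says `LightningTrainer` cannot currently handle. As a rough illustration only (a toy, synchronous sketch, not Ray Tune's actual implementation; the population layout, `val_accuracy` key, and perturbation factors here are invented for the example), PBT's exploit/explore step looks roughly like this:

```python
import random

def pbt_perturb(population, metric_key="val_accuracy", factors=(0.8, 1.2)):
    """Toy exploit/explore step: the bottom quartile of trials copies the
    hyperparameters of the top quartile, then perturbs them in place."""
    ranked = sorted(population, key=lambda t: t[metric_key], reverse=True)
    quartile = max(1, len(ranked) // 4)
    top, bottom = ranked[:quartile], ranked[-quartile:]
    for trial in bottom:
        source = random.choice(top)                          # exploit: clone a strong trial
        trial["lr"] = source["lr"] * random.choice(factors)  # explore: perturb nearby
    return ranked

population = [
    {"lr": 1e-3, "val_accuracy": 0.98},
    {"lr": 1e-2, "val_accuracy": 0.91},
    {"lr": 1e-4, "val_accuracy": 0.95},
    {"lr": 5e-3, "val_accuracy": 0.89},
]
pbt_perturb(population)
# The worst trial now continues training with a perturbed copy of the
# best trial's learning rate -- a mid-run change of `lr`.
```

Because `lr` changes while a trial is still running, the trainer must be able to pick up new hyperparameter values mid-run, which is the capability this PR's note flags as unsupported.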
8 changes: 2 additions & 6 deletions release/lightning_tests/workloads/test_tuner.py
@@ -7,7 +7,7 @@
import ray.tune as tune
from ray.air.config import CheckpointConfig, ScalingConfig
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.schedulers import ASHAScheduler

from lightning_test_utils import MNISTClassifier, MNISTDataModule

@@ -65,11 +65,7 @@
metric="val_accuracy",
mode="max",
num_samples=2,
scheduler=PopulationBasedTraining(
time_attr="training_iteration",
hyperparam_mutations={"lightning_config": mutation_config},
perturbation_interval=1,
),
scheduler=ASHAScheduler(max_t=5, grace_period=1, reduction_factor=2),
),
)
results = tuner.fit()
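The release test now exercises `ASHAScheduler` in place of PBT. ASHA is built on successive halving: at each rung, only the top fraction of trials is promoted and the rest are stopped early, so no hyperparameter is ever mutated mid-run. A toy, synchronous sketch of that idea (illustrative only; the scoring function and configs are invented here, and Ray Tune's real scheduler is asynchronous and event-driven):

```python
def successive_halving(trials, score, reduction_factor=2, grace_period=1, max_t=5):
    """Keep only the top 1/reduction_factor of trials at each rung, loosely
    mirroring ASHAScheduler(max_t=5, grace_period=1, reduction_factor=2)."""
    t = grace_period
    while len(trials) > 1 and t <= max_t:
        ranked = sorted(trials, key=lambda cfg: score(cfg, t), reverse=True)
        trials = ranked[: max(1, len(ranked) // reduction_factor)]  # promote the best
        t *= reduction_factor                                       # next, longer rung
    return trials

# Hypothetical scoring: configs whose lr is closest to 1e-2 perform best.
configs = [{"lr": lr} for lr in (1e-4, 1e-3, 1e-2, 1e-1)]
print(successive_halving(configs, lambda cfg, t: -abs(cfg["lr"] - 1e-2)))
# -> [{'lr': 0.01}]
```

The three keyword defaults correspond to the `max_t`, `grace_period`, and `reduction_factor` arguments the test passes to `ASHAScheduler`; since trials are only ever stopped, never reconfigured, this scheduler is compatible with `LightningTrainer`.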