diff --git a/README.md b/README.md index 877391c04..c72a59b6e 100644 --- a/README.md +++ b/README.md @@ -49,10 +49,7 @@ optimizer = load_optimizer(optimizer='adamp')(model.parameters()) # if you install `bitsandbytes` optimizer, you can use `8-bit` optimizers from `pytorch-optimizer`. -from pytorch_optimizer import load_optimizer - -opt = load_optimizer(optimizer='bnb_adamw8bit') -optimizer = opt(model.parameters()) +optimizer = load_optimizer(optimizer='bnb_adamw8bit')(model.parameters()) ``` Also, you can load the optimizer via `torch.hub`. @@ -61,6 +58,7 @@ Also, you can load the optimizer via `torch.hub`. import torch model = YourModel() + opt = torch.hub.load('kozistr/pytorch_optimizer', 'adamp') optimizer = opt(model.parameters()) ``` @@ -93,11 +91,13 @@ supported_optimizers = get_supported_optimizers() or you can also search them with the filter(s). ```python ->>> get_supported_optimizers('adam*') -['adamax', 'adamg', 'adammini', 'adamod', 'adamp', 'adams', 'adamw'] +from pytorch_optimizer import get_supported_optimizers ->>> get_supported_optimizers(['adam*', 'ranger*']) -['adamax', 'adamg', 'adammini', 'adamod', 'adamp', 'adams', 'adamw', 'ranger', 'ranger21'] +get_supported_optimizers('adam*') +# ['adamax', 'adamg', 'adammini', 'adamod', 'adamp', 'adams', 'adamw'] + +get_supported_optimizers(['adam*', 'ranger*']) +# ['adamax', 'adamg', 'adammini', 'adamod', 'adamp', 'adams', 'adamw', 'ranger', 'ranger21'] ``` | Optimizer | Description | Official Code | Paper | Citation | @@ -197,11 +197,13 @@ supported_lr_schedulers = get_supported_lr_schedulers() or you can also search them with the filter(s). ```python ->>> get_supported_lr_schedulers('cosine*') -['cosine', 'cosine_annealing', 'cosine_annealing_with_warm_restart', 'cosine_annealing_with_warmup'] +from pytorch_optimizer import get_supported_lr_schedulers + +get_supported_lr_schedulers('cosine*') +# ['cosine', 'cosine_annealing', 'cosine_annealing_with_warm_restart', 'cosine_annealing_with_warmup'] ->>> get_supported_lr_schedulers(['cosine*', '*warm*']) -['cosine', 'cosine_annealing', 'cosine_annealing_with_warm_restart', 'cosine_annealing_with_warmup', 'warmup_stable_decay'] +get_supported_lr_schedulers(['cosine*', '*warm*']) +# ['cosine', 'cosine_annealing', 'cosine_annealing_with_warm_restart', 'cosine_annealing_with_warmup', 'warmup_stable_decay'] ``` | LR Scheduler | Description | Official Code | Paper | Citation | @@ -224,11 +226,13 @@ supported_loss_functions = get_supported_loss_functions() or you can also search them with the filter(s). ```python ->>> get_supported_loss_functions('*focal*') -['bcefocalloss', 'focalcosineloss', 'focalloss', 'focaltverskyloss'] +from pytorch_optimizer import get_supported_loss_functions + +get_supported_loss_functions('*focal*') +# ['bcefocalloss', 'focalcosineloss', 'focalloss', 'focaltverskyloss'] ->>> get_supported_loss_functions(['*focal*', 'bce*']) -['bcefocalloss', 'bceloss', 'focalcosineloss', 'focalloss', 'focaltverskyloss'] +get_supported_loss_functions(['*focal*', 'bce*']) +# ['bcefocalloss', 'bceloss', 'focalcosineloss', 'focalloss', 'focaltverskyloss'] ``` | Loss Functions | Description | Official Code | Paper | Citation | diff --git a/docs/changelogs/v3.3.0.md b/docs/changelogs/v3.3.0.md index 814692061..7f0eb2962 100644 --- a/docs/changelogs/v3.3.0.md +++ b/docs/changelogs/v3.3.0.md @@ -8,3 +8,19 @@ * [Modified Adam Can Converge with Any β2 with the Optimal Rate](https://arxiv.org/abs/2411.02853) * Implement `FTRL` optimizer. 
(#291) * [Follow The Regularized Leader](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf) + +### Refactor + +* Big refactoring, removing some direct imports from `pytorch_optimizer.*`. + * Several helper methods can no longer be imported directly from `pytorch_optimizer.*`; they are rarely used on their own and are utilities for specific optimizers rather than optimizers themselves. + * `pytorch_optimizer.[Shampoo stuff]` -> `pytorch_optimizer.optimizer.shampoo_utils.[Shampoo stuff]`. + * `shampoo_utils` like `Graft`, `BlockPartitioner`, `PreConditioner`, etc. You can check the details [here](https://github.com/kozistr/pytorch_optimizer/blob/main/pytorch_optimizer/optimizer/shampoo_utils.py). + * `pytorch_optimizer.GaLoreProjector` -> `pytorch_optimizer.optimizer.galore.GaLoreProjector`. + * `pytorch_optimizer.gradfilter_ema` -> `pytorch_optimizer.optimizer.grokfast.gradfilter_ema`. + * `pytorch_optimizer.gradfilter_ma` -> `pytorch_optimizer.optimizer.grokfast.gradfilter_ma`. + * `pytorch_optimizer.l2_projection` -> `pytorch_optimizer.optimizer.alig.l2_projection`. + * `pytorch_optimizer.flatten_grad` -> `pytorch_optimizer.optimizer.pcgrad.flatten_grad`. + * `pytorch_optimizer.un_flatten_grad` -> `pytorch_optimizer.optimizer.pcgrad.un_flatten_grad`. + * `pytorch_optimizer.reduce_max_except_dim` -> `pytorch_optimizer.optimizer.sm3.reduce_max_except_dim`. + * `pytorch_optimizer.neuron_norm` -> `pytorch_optimizer.optimizer.nero.neuron_norm`. + * `pytorch_optimizer.neuron_mean` -> `pytorch_optimizer.optimizer.nero.neuron_mean`. diff --git a/docs/index.md b/docs/index.md index 877391c04..c72a59b6e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -49,10 +49,7 @@ optimizer = load_optimizer(optimizer='adamp')(model.parameters()) # if you install `bitsandbytes` optimizer, you can use `8-bit` optimizers from `pytorch-optimizer`. -from pytorch_optimizer import load_optimizer - -opt = load_optimizer(optimizer='bnb_adamw8bit') -optimizer = opt(model.parameters()) +optimizer = load_optimizer(optimizer='bnb_adamw8bit')(model.parameters()) ``` Also, you can load the optimizer via `torch.hub`. @@ -61,6 +58,7 @@ Also, you can load the optimizer via `torch.hub`. import torch model = YourModel() + opt = torch.hub.load('kozistr/pytorch_optimizer', 'adamp') optimizer = opt(model.parameters()) ``` @@ -93,11 +91,13 @@ supported_optimizers = get_supported_optimizers() or you can also search them with the filter(s). ```python ->>> get_supported_optimizers('adam*') -['adamax', 'adamg', 'adammini', 'adamod', 'adamp', 'adams', 'adamw'] +from pytorch_optimizer import get_supported_optimizers ->>> get_supported_optimizers(['adam*', 'ranger*']) -['adamax', 'adamg', 'adammini', 'adamod', 'adamp', 'adams', 'adamw', 'ranger', 'ranger21'] +get_supported_optimizers('adam*') +# ['adamax', 'adamg', 'adammini', 'adamod', 'adamp', 'adams', 'adamw'] + +get_supported_optimizers(['adam*', 'ranger*']) +# ['adamax', 'adamg', 'adammini', 'adamod', 'adamp', 'adams', 'adamw', 'ranger', 'ranger21'] ``` | Optimizer | Description | Official Code | Paper | Citation | @@ -197,11 +197,13 @@ supported_lr_schedulers = get_supported_lr_schedulers() or you can also search them with the filter(s).
```python ->>> get_supported_lr_schedulers('cosine*') -['cosine', 'cosine_annealing', 'cosine_annealing_with_warm_restart', 'cosine_annealing_with_warmup'] +from pytorch_optimizer import get_supported_lr_schedulers + +get_supported_lr_schedulers('cosine*') +# ['cosine', 'cosine_annealing', 'cosine_annealing_with_warm_restart', 'cosine_annealing_with_warmup'] ->>> get_supported_lr_schedulers(['cosine*', '*warm*']) -['cosine', 'cosine_annealing', 'cosine_annealing_with_warm_restart', 'cosine_annealing_with_warmup', 'warmup_stable_decay'] +get_supported_lr_schedulers(['cosine*', '*warm*']) +# ['cosine', 'cosine_annealing', 'cosine_annealing_with_warm_restart', 'cosine_annealing_with_warmup', 'warmup_stable_decay'] ``` | LR Scheduler | Description | Official Code | Paper | Citation | @@ -224,11 +226,13 @@ supported_loss_functions = get_supported_loss_functions() or you can also search them with the filter(s). ```python ->>> get_supported_loss_functions('*focal*') -['bcefocalloss', 'focalcosineloss', 'focalloss', 'focaltverskyloss'] +from pytorch_optimizer import get_supported_loss_functions + +get_supported_loss_functions('*focal*') +# ['bcefocalloss', 'focalcosineloss', 'focalloss', 'focaltverskyloss'] ->>> get_supported_loss_functions(['*focal*', 'bce*']) -['bcefocalloss', 'bceloss', 'focalcosineloss', 'focalloss', 'focaltverskyloss'] +get_supported_loss_functions(['*focal*', 'bce*']) +# ['bcefocalloss', 'bceloss', 'focalcosineloss', 'focalloss', 'focaltverskyloss'] ``` | Loss Functions | Description | Official Code | Paper | Citation | diff --git a/docs/loss.md b/docs/loss.md index 7de22f700..f0615b739 100644 --- a/docs/loss.md +++ b/docs/loss.md @@ -1,6 +1,6 @@ # Loss Function -::: pytorch_optimizer.loss.bi_tempered.bi_tempered_logistic_loss +::: pytorch_optimizer.bi_tempered_logistic_loss :docstring: ::: pytorch_optimizer.BiTemperedLogisticLoss @@ -35,7 +35,7 @@ :docstring: :members: -::: pytorch_optimizer.loss.jaccard.soft_jaccard_score +::: pytorch_optimizer.soft_jaccard_score :docstring: ::: pytorch_optimizer.JaccardLoss diff --git a/docs/optimizer.md b/docs/optimizer.md index f86501006..7d7f80cb1 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -1,5 +1,13 @@ # Optimizers +::: pytorch_optimizer.optimizer.create_optimizer + :docstring: + :members: + +::: pytorch_optimizer.optimizer.get_optimizer_parameters + :docstring: + :members: + ::: pytorch_optimizer.A2Grad :docstring: :members: @@ -168,10 +176,6 @@ :docstring: :members: -::: pytorch_optimizer.GaLoreProjector - :docstring: - :members: - ::: pytorch_optimizer.centralize_gradient :docstring: :members: @@ -180,14 +184,6 @@ :docstring: :members: -::: pytorch_optimizer.gradfilter_ema - :docstring: - :members: - -::: pytorch_optimizer.gradfilter_ma - :docstring: - :members: - ::: pytorch_optimizer.GrokFastAdamW :docstring: :members: diff --git a/docs/util.md b/docs/util.md index ece74f957..01b218bab 100644 --- a/docs/util.md +++ b/docs/util.md @@ -16,10 +16,6 @@ :docstring: :members: -::: pytorch_optimizer.optimizer.utils.get_optimizer_parameters - :docstring: - :members: - ::: pytorch_optimizer.optimizer.utils.is_valid_parameters :docstring: :members: @@ -36,14 +32,6 @@ :docstring: :members: -::: pytorch_optimizer.optimizer.utils.flatten_grad - :docstring: - :members: - -::: pytorch_optimizer.optimizer.utils.un_flatten_grad - :docstring: - :members: - ::: pytorch_optimizer.optimizer.utils.channel_view :docstring: :members: @@ -68,14 +56,6 @@ :docstring: :members: -::: pytorch_optimizer.optimizer.utils.neuron_norm - :docstring: 
- :members: - -::: pytorch_optimizer.optimizer.utils.neuron_mean - :docstring: - :members: - ::: pytorch_optimizer.optimizer.utils.disable_running_stats :docstring: :members: @@ -84,82 +64,10 @@ :docstring: :members: -::: pytorch_optimizer.optimizer.utils.l2_projection - :docstring: - :members: - ::: pytorch_optimizer.optimizer.utils.get_global_gradient_norm :docstring: :members: -::: pytorch_optimizer.optimizer.utils.reduce_max_except_dim - :docstring: - :members: - -::: pytorch_optimizer.optimizer.shampoo_utils.merge_small_dims - :docstring: - :members: - ::: pytorch_optimizer.optimizer.utils.reg_noise :docstring: :members: - -## Newton methods - -::: pytorch_optimizer.optimizer.shampoo_utils.power_iteration - :docstring: - :members: - -::: pytorch_optimizer.optimizer.shampoo_utils.compute_power_schur_newton - :docstring: - :members: - -::: pytorch_optimizer.optimizer.shampoo_utils.compute_power_svd - :docstring: - :members: - -## Grafting - -::: pytorch_optimizer.optimizer.shampoo_utils.Graft - :docstring: - :members: - -::: pytorch_optimizer.optimizer.shampoo_utils.LayerWiseGrafting - :docstring: - :members: - -::: pytorch_optimizer.optimizer.shampoo_utils.SGDGraft - :docstring: - :members: - -::: pytorch_optimizer.optimizer.shampoo_utils.SQRTNGraft - :docstring: - :members: - -::: pytorch_optimizer.optimizer.shampoo_utils.AdaGradGraft - :docstring: - :members: - -::: pytorch_optimizer.optimizer.shampoo_utils.RMSPropGraft - :docstring: - :members: - -::: pytorch_optimizer.optimizer.shampoo_utils.build_graft - :docstring: - :members: - -## Block Partitioner - -::: pytorch_optimizer.optimizer.shampoo_utils.BlockPartitioner - :docstring: - :members: - -## Pre-Conditioner - -::: pytorch_optimizer.optimizer.shampoo_utils.PreConditionerType - :docstring: - :members: - -::: pytorch_optimizer.optimizer.shampoo_utils.PreConditioner - :docstring: - :members: diff --git a/docs/visualizations/rastrigin_AdaBelief.png b/docs/visualizations/rastrigin_AdaBelief.png index 42aadc21c..10cdd18a1 100644 Binary files a/docs/visualizations/rastrigin_AdaBelief.png and b/docs/visualizations/rastrigin_AdaBelief.png differ diff --git a/docs/visualizations/rastrigin_AdaBound.png b/docs/visualizations/rastrigin_AdaBound.png index bddacffe1..b130e370f 100644 Binary files a/docs/visualizations/rastrigin_AdaBound.png and b/docs/visualizations/rastrigin_AdaBound.png differ diff --git a/docs/visualizations/rastrigin_AdaMod.png b/docs/visualizations/rastrigin_AdaMod.png index 57152c424..7ed8126d4 100644 Binary files a/docs/visualizations/rastrigin_AdaMod.png and b/docs/visualizations/rastrigin_AdaMod.png differ diff --git a/docs/visualizations/rastrigin_AdaPNM.png b/docs/visualizations/rastrigin_AdaPNM.png index 85c9d120f..33d82ea1a 100644 Binary files a/docs/visualizations/rastrigin_AdaPNM.png and b/docs/visualizations/rastrigin_AdaPNM.png differ diff --git a/docs/visualizations/rastrigin_Adai.png b/docs/visualizations/rastrigin_Adai.png index fb0f998e6..6b6a91994 100644 Binary files a/docs/visualizations/rastrigin_Adai.png and b/docs/visualizations/rastrigin_Adai.png differ diff --git a/docs/visualizations/rastrigin_AdamP.png b/docs/visualizations/rastrigin_AdamP.png index 6656eb9b7..4f72c0514 100644 Binary files a/docs/visualizations/rastrigin_AdamP.png and b/docs/visualizations/rastrigin_AdamP.png differ diff --git a/docs/visualizations/rastrigin_AdamW.png b/docs/visualizations/rastrigin_AdamW.png index b664f0748..e5dd0f336 100644 Binary files a/docs/visualizations/rastrigin_AdamW.png and 
b/docs/visualizations/rastrigin_AdamW.png differ diff --git a/docs/visualizations/rastrigin_Adan.png b/docs/visualizations/rastrigin_Adan.png index 74e8831e6..25e771978 100644 Binary files a/docs/visualizations/rastrigin_Adan.png and b/docs/visualizations/rastrigin_Adan.png differ diff --git a/docs/visualizations/rastrigin_AggMo.png b/docs/visualizations/rastrigin_AggMo.png index d56701d0c..04db32680 100644 Binary files a/docs/visualizations/rastrigin_AggMo.png and b/docs/visualizations/rastrigin_AggMo.png differ diff --git a/docs/visualizations/rastrigin_DAdaptAdaGrad.png b/docs/visualizations/rastrigin_DAdaptAdaGrad.png index ed2dacdff..c788eedc8 100644 Binary files a/docs/visualizations/rastrigin_DAdaptAdaGrad.png and b/docs/visualizations/rastrigin_DAdaptAdaGrad.png differ diff --git a/docs/visualizations/rastrigin_DAdaptAdam.png b/docs/visualizations/rastrigin_DAdaptAdam.png index ee051df95..bcbf45848 100644 Binary files a/docs/visualizations/rastrigin_DAdaptAdam.png and b/docs/visualizations/rastrigin_DAdaptAdam.png differ diff --git a/docs/visualizations/rastrigin_DAdaptSGD.png b/docs/visualizations/rastrigin_DAdaptSGD.png index f4466d0d2..860e1d852 100644 Binary files a/docs/visualizations/rastrigin_DAdaptSGD.png and b/docs/visualizations/rastrigin_DAdaptSGD.png differ diff --git a/docs/visualizations/rastrigin_DiffGrad.png b/docs/visualizations/rastrigin_DiffGrad.png index a04399270..9efb43ca8 100644 Binary files a/docs/visualizations/rastrigin_DiffGrad.png and b/docs/visualizations/rastrigin_DiffGrad.png differ diff --git a/docs/visualizations/rastrigin_Fromage.png b/docs/visualizations/rastrigin_Fromage.png index 58d2c4020..59f76e2a1 100644 Binary files a/docs/visualizations/rastrigin_Fromage.png and b/docs/visualizations/rastrigin_Fromage.png differ diff --git a/docs/visualizations/rastrigin_LARS.png b/docs/visualizations/rastrigin_LARS.png index af791fee4..5398b86b8 100644 Binary files a/docs/visualizations/rastrigin_LARS.png and b/docs/visualizations/rastrigin_LARS.png differ diff --git a/docs/visualizations/rastrigin_Lamb.png b/docs/visualizations/rastrigin_Lamb.png index 40f213162..c907f0507 100644 Binary files a/docs/visualizations/rastrigin_Lamb.png and b/docs/visualizations/rastrigin_Lamb.png differ diff --git a/docs/visualizations/rastrigin_MADGRAD.png b/docs/visualizations/rastrigin_MADGRAD.png index 14700b1d9..bde9eae15 100644 Binary files a/docs/visualizations/rastrigin_MADGRAD.png and b/docs/visualizations/rastrigin_MADGRAD.png differ diff --git a/docs/visualizations/rastrigin_MSVAG.png b/docs/visualizations/rastrigin_MSVAG.png index c0da0f1a4..ce2fa6721 100644 Binary files a/docs/visualizations/rastrigin_MSVAG.png and b/docs/visualizations/rastrigin_MSVAG.png differ diff --git a/docs/visualizations/rastrigin_Nero.png b/docs/visualizations/rastrigin_Nero.png index 6c3ef6e9b..557e9af4c 100644 Binary files a/docs/visualizations/rastrigin_Nero.png and b/docs/visualizations/rastrigin_Nero.png differ diff --git a/docs/visualizations/rastrigin_PID.png b/docs/visualizations/rastrigin_PID.png index e08de0a21..6f5d79de5 100644 Binary files a/docs/visualizations/rastrigin_PID.png and b/docs/visualizations/rastrigin_PID.png differ diff --git a/docs/visualizations/rastrigin_PNM.png b/docs/visualizations/rastrigin_PNM.png index dfe9ae920..46436025a 100644 Binary files a/docs/visualizations/rastrigin_PNM.png and b/docs/visualizations/rastrigin_PNM.png differ diff --git a/docs/visualizations/rastrigin_QHAdam.png b/docs/visualizations/rastrigin_QHAdam.png index 171e7ab1f..5ef038dac 
100644 Binary files a/docs/visualizations/rastrigin_QHAdam.png and b/docs/visualizations/rastrigin_QHAdam.png differ diff --git a/docs/visualizations/rastrigin_QHM.png b/docs/visualizations/rastrigin_QHM.png index d1cdc4dd1..9a7a06932 100644 Binary files a/docs/visualizations/rastrigin_QHM.png and b/docs/visualizations/rastrigin_QHM.png differ diff --git a/docs/visualizations/rastrigin_RAdam.png b/docs/visualizations/rastrigin_RAdam.png index dc06b3266..214bc17a2 100644 Binary files a/docs/visualizations/rastrigin_RAdam.png and b/docs/visualizations/rastrigin_RAdam.png differ diff --git a/docs/visualizations/rastrigin_Ranger.png b/docs/visualizations/rastrigin_Ranger.png index 68ac151e1..731d8f4bb 100644 Binary files a/docs/visualizations/rastrigin_Ranger.png and b/docs/visualizations/rastrigin_Ranger.png differ diff --git a/docs/visualizations/rastrigin_Ranger21.png b/docs/visualizations/rastrigin_Ranger21.png index 43d799557..61128c2ae 100644 Binary files a/docs/visualizations/rastrigin_Ranger21.png and b/docs/visualizations/rastrigin_Ranger21.png differ diff --git a/docs/visualizations/rastrigin_SGDP.png b/docs/visualizations/rastrigin_SGDP.png index e6dc3a117..4a78a5930 100644 Binary files a/docs/visualizations/rastrigin_SGDP.png and b/docs/visualizations/rastrigin_SGDP.png differ diff --git a/docs/visualizations/rastrigin_ScalableShampoo.png b/docs/visualizations/rastrigin_ScalableShampoo.png index cde1599e8..87e23eebb 100644 Binary files a/docs/visualizations/rastrigin_ScalableShampoo.png and b/docs/visualizations/rastrigin_ScalableShampoo.png differ diff --git a/docs/visualizations/rastrigin_Shampoo.png b/docs/visualizations/rastrigin_Shampoo.png index 21467b61b..b978fb792 100644 Binary files a/docs/visualizations/rastrigin_Shampoo.png and b/docs/visualizations/rastrigin_Shampoo.png differ diff --git a/docs/visualizations/rosenbrock_ASGD.png b/docs/visualizations/rosenbrock_ASGD.png index 89b846fef..5eaee4f47 100644 Binary files a/docs/visualizations/rosenbrock_ASGD.png and b/docs/visualizations/rosenbrock_ASGD.png differ diff --git a/docs/visualizations/rosenbrock_AccSGD.png b/docs/visualizations/rosenbrock_AccSGD.png index 6ec976a75..a93eea8f1 100644 Binary files a/docs/visualizations/rosenbrock_AccSGD.png and b/docs/visualizations/rosenbrock_AccSGD.png differ diff --git a/docs/visualizations/rosenbrock_AdaBelief.png b/docs/visualizations/rosenbrock_AdaBelief.png index 8274285b3..f78d554d7 100644 Binary files a/docs/visualizations/rosenbrock_AdaBelief.png and b/docs/visualizations/rosenbrock_AdaBelief.png differ diff --git a/docs/visualizations/rosenbrock_AdaBound.png b/docs/visualizations/rosenbrock_AdaBound.png index 88d8e9074..feb2fdb77 100644 Binary files a/docs/visualizations/rosenbrock_AdaBound.png and b/docs/visualizations/rosenbrock_AdaBound.png differ diff --git a/docs/visualizations/rosenbrock_AdaDelta.png b/docs/visualizations/rosenbrock_AdaDelta.png index b5131e082..ce82532ce 100644 Binary files a/docs/visualizations/rosenbrock_AdaDelta.png and b/docs/visualizations/rosenbrock_AdaDelta.png differ diff --git a/docs/visualizations/rosenbrock_AdaFactor.png b/docs/visualizations/rosenbrock_AdaFactor.png index cf838b0a1..ded574035 100644 Binary files a/docs/visualizations/rosenbrock_AdaFactor.png and b/docs/visualizations/rosenbrock_AdaFactor.png differ diff --git a/docs/visualizations/rosenbrock_AdaHessian.png b/docs/visualizations/rosenbrock_AdaHessian.png index a892df181..c8c16e038 100644 Binary files a/docs/visualizations/rosenbrock_AdaHessian.png and 
b/docs/visualizations/rosenbrock_AdaHessian.png differ diff --git a/docs/visualizations/rosenbrock_AdaMax.png b/docs/visualizations/rosenbrock_AdaMax.png index 66c42896f..b9d248794 100644 Binary files a/docs/visualizations/rosenbrock_AdaMax.png and b/docs/visualizations/rosenbrock_AdaMax.png differ diff --git a/docs/visualizations/rosenbrock_AdaMod.png b/docs/visualizations/rosenbrock_AdaMod.png index 02755c0e7..c92ab0970 100644 Binary files a/docs/visualizations/rosenbrock_AdaMod.png and b/docs/visualizations/rosenbrock_AdaMod.png differ diff --git a/docs/visualizations/rosenbrock_AdaNorm.png b/docs/visualizations/rosenbrock_AdaNorm.png index fe401e422..09f2fa6ff 100644 Binary files a/docs/visualizations/rosenbrock_AdaNorm.png and b/docs/visualizations/rosenbrock_AdaNorm.png differ diff --git a/docs/visualizations/rosenbrock_AdaPNM.png b/docs/visualizations/rosenbrock_AdaPNM.png index e65621109..f8eaa5017 100644 Binary files a/docs/visualizations/rosenbrock_AdaPNM.png and b/docs/visualizations/rosenbrock_AdaPNM.png differ diff --git a/docs/visualizations/rosenbrock_AdaSmooth.png b/docs/visualizations/rosenbrock_AdaSmooth.png index 434fd5552..d8d056073 100644 Binary files a/docs/visualizations/rosenbrock_AdaSmooth.png and b/docs/visualizations/rosenbrock_AdaSmooth.png differ diff --git a/docs/visualizations/rosenbrock_Adai.png b/docs/visualizations/rosenbrock_Adai.png index 32d3e7c72..ab1066b0a 100644 Binary files a/docs/visualizations/rosenbrock_Adai.png and b/docs/visualizations/rosenbrock_Adai.png differ diff --git a/docs/visualizations/rosenbrock_Adalite.png b/docs/visualizations/rosenbrock_Adalite.png index 87e7ac911..15b4b3eb7 100644 Binary files a/docs/visualizations/rosenbrock_Adalite.png and b/docs/visualizations/rosenbrock_Adalite.png differ diff --git a/docs/visualizations/rosenbrock_Adam.png b/docs/visualizations/rosenbrock_Adam.png index 3a60f0d87..07e89828d 100644 Binary files a/docs/visualizations/rosenbrock_Adam.png and b/docs/visualizations/rosenbrock_Adam.png differ diff --git a/docs/visualizations/rosenbrock_AdamP.png b/docs/visualizations/rosenbrock_AdamP.png index e16f9d850..27916cff4 100644 Binary files a/docs/visualizations/rosenbrock_AdamP.png and b/docs/visualizations/rosenbrock_AdamP.png differ diff --git a/docs/visualizations/rosenbrock_AdamS.png b/docs/visualizations/rosenbrock_AdamS.png index dd2472aa6..b5c25c7a9 100644 Binary files a/docs/visualizations/rosenbrock_AdamS.png and b/docs/visualizations/rosenbrock_AdamS.png differ diff --git a/docs/visualizations/rosenbrock_AdamW.png b/docs/visualizations/rosenbrock_AdamW.png index 41e3fab6e..3e8a25f12 100644 Binary files a/docs/visualizations/rosenbrock_AdamW.png and b/docs/visualizations/rosenbrock_AdamW.png differ diff --git a/docs/visualizations/rosenbrock_Adan.png b/docs/visualizations/rosenbrock_Adan.png index e10dd80d4..71a5e0ed4 100644 Binary files a/docs/visualizations/rosenbrock_Adan.png and b/docs/visualizations/rosenbrock_Adan.png differ diff --git a/docs/visualizations/rosenbrock_AggMo.png b/docs/visualizations/rosenbrock_AggMo.png index 723d233bd..3826e1472 100644 Binary files a/docs/visualizations/rosenbrock_AggMo.png and b/docs/visualizations/rosenbrock_AggMo.png differ diff --git a/docs/visualizations/rosenbrock_Aida.png b/docs/visualizations/rosenbrock_Aida.png index 35c0c0966..e0238f174 100644 Binary files a/docs/visualizations/rosenbrock_Aida.png and b/docs/visualizations/rosenbrock_Aida.png differ diff --git a/docs/visualizations/rosenbrock_Amos.png b/docs/visualizations/rosenbrock_Amos.png 
index 761c75258..5a5afc338 100644 Binary files a/docs/visualizations/rosenbrock_Amos.png and b/docs/visualizations/rosenbrock_Amos.png differ diff --git a/docs/visualizations/rosenbrock_Apollo.png b/docs/visualizations/rosenbrock_Apollo.png index c645a8076..e73b79052 100644 Binary files a/docs/visualizations/rosenbrock_Apollo.png and b/docs/visualizations/rosenbrock_Apollo.png differ diff --git a/docs/visualizations/rosenbrock_AvaGrad.png b/docs/visualizations/rosenbrock_AvaGrad.png index 8a7372f22..b59d634d9 100644 Binary files a/docs/visualizations/rosenbrock_AvaGrad.png and b/docs/visualizations/rosenbrock_AvaGrad.png differ diff --git a/docs/visualizations/rosenbrock_CAME.png b/docs/visualizations/rosenbrock_CAME.png index 9a4373c7d..ccca7bc76 100644 Binary files a/docs/visualizations/rosenbrock_CAME.png and b/docs/visualizations/rosenbrock_CAME.png differ diff --git a/docs/visualizations/rosenbrock_DAdaptAdaGrad.png b/docs/visualizations/rosenbrock_DAdaptAdaGrad.png index c54afc45e..c5cc43be0 100644 Binary files a/docs/visualizations/rosenbrock_DAdaptAdaGrad.png and b/docs/visualizations/rosenbrock_DAdaptAdaGrad.png differ diff --git a/docs/visualizations/rosenbrock_DAdaptAdam.png b/docs/visualizations/rosenbrock_DAdaptAdam.png index 66333fb27..5bb93391b 100644 Binary files a/docs/visualizations/rosenbrock_DAdaptAdam.png and b/docs/visualizations/rosenbrock_DAdaptAdam.png differ diff --git a/docs/visualizations/rosenbrock_DAdaptAdan.png b/docs/visualizations/rosenbrock_DAdaptAdan.png index 5504a19a3..f5533cdda 100644 Binary files a/docs/visualizations/rosenbrock_DAdaptAdan.png and b/docs/visualizations/rosenbrock_DAdaptAdan.png differ diff --git a/docs/visualizations/rosenbrock_DAdaptLion.png b/docs/visualizations/rosenbrock_DAdaptLion.png index a7e20a87f..a861a0c8b 100644 Binary files a/docs/visualizations/rosenbrock_DAdaptLion.png and b/docs/visualizations/rosenbrock_DAdaptLion.png differ diff --git a/docs/visualizations/rosenbrock_DAdaptSGD.png b/docs/visualizations/rosenbrock_DAdaptSGD.png index aac238d3d..cb4c91310 100644 Binary files a/docs/visualizations/rosenbrock_DAdaptSGD.png and b/docs/visualizations/rosenbrock_DAdaptSGD.png differ diff --git a/docs/visualizations/rosenbrock_DiffGrad.png b/docs/visualizations/rosenbrock_DiffGrad.png index e26f1d006..a3022c5e6 100644 Binary files a/docs/visualizations/rosenbrock_DiffGrad.png and b/docs/visualizations/rosenbrock_DiffGrad.png differ diff --git a/docs/visualizations/rosenbrock_FAdam.png b/docs/visualizations/rosenbrock_FAdam.png index 060f69c27..4d271f3fd 100644 Binary files a/docs/visualizations/rosenbrock_FAdam.png and b/docs/visualizations/rosenbrock_FAdam.png differ diff --git a/docs/visualizations/rosenbrock_Fromage.png b/docs/visualizations/rosenbrock_Fromage.png index 0679dd94d..3da7a6652 100644 Binary files a/docs/visualizations/rosenbrock_Fromage.png and b/docs/visualizations/rosenbrock_Fromage.png differ diff --git a/docs/visualizations/rosenbrock_GaLore.png b/docs/visualizations/rosenbrock_GaLore.png index 2c1cc51ca..4c697188d 100644 Binary files a/docs/visualizations/rosenbrock_GaLore.png and b/docs/visualizations/rosenbrock_GaLore.png differ diff --git a/docs/visualizations/rosenbrock_Gravity.png b/docs/visualizations/rosenbrock_Gravity.png index bfecb5601..ee3153529 100644 Binary files a/docs/visualizations/rosenbrock_Gravity.png and b/docs/visualizations/rosenbrock_Gravity.png differ diff --git a/docs/visualizations/rosenbrock_GrokFastAdamW.png b/docs/visualizations/rosenbrock_GrokFastAdamW.png index 
b094e88c7..6b3b84b66 100644 Binary files a/docs/visualizations/rosenbrock_GrokFastAdamW.png and b/docs/visualizations/rosenbrock_GrokFastAdamW.png differ diff --git a/docs/visualizations/rosenbrock_Kate.png b/docs/visualizations/rosenbrock_Kate.png index 0710dbf45..7801108d6 100644 Binary files a/docs/visualizations/rosenbrock_Kate.png and b/docs/visualizations/rosenbrock_Kate.png differ diff --git a/docs/visualizations/rosenbrock_LARS.png b/docs/visualizations/rosenbrock_LARS.png index 226becfcc..57e66a795 100644 Binary files a/docs/visualizations/rosenbrock_LARS.png and b/docs/visualizations/rosenbrock_LARS.png differ diff --git a/docs/visualizations/rosenbrock_Lamb.png b/docs/visualizations/rosenbrock_Lamb.png index 0dd83537e..01224ae90 100644 Binary files a/docs/visualizations/rosenbrock_Lamb.png and b/docs/visualizations/rosenbrock_Lamb.png differ diff --git a/docs/visualizations/rosenbrock_Lion.png b/docs/visualizations/rosenbrock_Lion.png index 2a27e5656..85a53e4a5 100644 Binary files a/docs/visualizations/rosenbrock_Lion.png and b/docs/visualizations/rosenbrock_Lion.png differ diff --git a/docs/visualizations/rosenbrock_MADGRAD.png b/docs/visualizations/rosenbrock_MADGRAD.png index c862402e4..1ebad0494 100644 Binary files a/docs/visualizations/rosenbrock_MADGRAD.png and b/docs/visualizations/rosenbrock_MADGRAD.png differ diff --git a/docs/visualizations/rosenbrock_MSVAG.png b/docs/visualizations/rosenbrock_MSVAG.png index bbbe7561b..70064cd96 100644 Binary files a/docs/visualizations/rosenbrock_MSVAG.png and b/docs/visualizations/rosenbrock_MSVAG.png differ diff --git a/docs/visualizations/rosenbrock_Nero.png b/docs/visualizations/rosenbrock_Nero.png index 72644c68e..f3c738138 100644 Binary files a/docs/visualizations/rosenbrock_Nero.png and b/docs/visualizations/rosenbrock_Nero.png differ diff --git a/docs/visualizations/rosenbrock_NovoGrad.png b/docs/visualizations/rosenbrock_NovoGrad.png index b7186e5dc..d00272c10 100644 Binary files a/docs/visualizations/rosenbrock_NovoGrad.png and b/docs/visualizations/rosenbrock_NovoGrad.png differ diff --git a/docs/visualizations/rosenbrock_PAdam.png b/docs/visualizations/rosenbrock_PAdam.png index 191c5f7d6..584570bf7 100644 Binary files a/docs/visualizations/rosenbrock_PAdam.png and b/docs/visualizations/rosenbrock_PAdam.png differ diff --git a/docs/visualizations/rosenbrock_PID.png b/docs/visualizations/rosenbrock_PID.png index cebfdd45b..9037e7595 100644 Binary files a/docs/visualizations/rosenbrock_PID.png and b/docs/visualizations/rosenbrock_PID.png differ diff --git a/docs/visualizations/rosenbrock_PNM.png b/docs/visualizations/rosenbrock_PNM.png index a5f61f572..a827b7f86 100644 Binary files a/docs/visualizations/rosenbrock_PNM.png and b/docs/visualizations/rosenbrock_PNM.png differ diff --git a/docs/visualizations/rosenbrock_Prodigy.png b/docs/visualizations/rosenbrock_Prodigy.png index bfd6ce531..124ab0b7e 100644 Binary files a/docs/visualizations/rosenbrock_Prodigy.png and b/docs/visualizations/rosenbrock_Prodigy.png differ diff --git a/docs/visualizations/rosenbrock_QHAdam.png b/docs/visualizations/rosenbrock_QHAdam.png index 3e8471195..9429a1f3b 100644 Binary files a/docs/visualizations/rosenbrock_QHAdam.png and b/docs/visualizations/rosenbrock_QHAdam.png differ diff --git a/docs/visualizations/rosenbrock_QHM.png b/docs/visualizations/rosenbrock_QHM.png index 8fe1e131e..b1065751c 100644 Binary files a/docs/visualizations/rosenbrock_QHM.png and b/docs/visualizations/rosenbrock_QHM.png differ diff --git 
a/docs/visualizations/rosenbrock_RAdam.png b/docs/visualizations/rosenbrock_RAdam.png index 7de6b4568..8348dcdd7 100644 Binary files a/docs/visualizations/rosenbrock_RAdam.png and b/docs/visualizations/rosenbrock_RAdam.png differ diff --git a/docs/visualizations/rosenbrock_Ranger.png b/docs/visualizations/rosenbrock_Ranger.png index e21c38d91..83d1cec9a 100644 Binary files a/docs/visualizations/rosenbrock_Ranger.png and b/docs/visualizations/rosenbrock_Ranger.png differ diff --git a/docs/visualizations/rosenbrock_Ranger21.png b/docs/visualizations/rosenbrock_Ranger21.png index f7929696b..e47f9a1ce 100644 Binary files a/docs/visualizations/rosenbrock_Ranger21.png and b/docs/visualizations/rosenbrock_Ranger21.png differ diff --git a/docs/visualizations/rosenbrock_SGD.png b/docs/visualizations/rosenbrock_SGD.png index b72d6a198..ed27a135a 100644 Binary files a/docs/visualizations/rosenbrock_SGD.png and b/docs/visualizations/rosenbrock_SGD.png differ diff --git a/docs/visualizations/rosenbrock_SGDP.png b/docs/visualizations/rosenbrock_SGDP.png index cccfb9e2f..79aeecc12 100644 Binary files a/docs/visualizations/rosenbrock_SGDP.png and b/docs/visualizations/rosenbrock_SGDP.png differ diff --git a/docs/visualizations/rosenbrock_SGDW.png b/docs/visualizations/rosenbrock_SGDW.png index 1ed5d46c7..07c789938 100644 Binary files a/docs/visualizations/rosenbrock_SGDW.png and b/docs/visualizations/rosenbrock_SGDW.png differ diff --git a/docs/visualizations/rosenbrock_SM3.png b/docs/visualizations/rosenbrock_SM3.png index 40c678e0c..abdf20039 100644 Binary files a/docs/visualizations/rosenbrock_SM3.png and b/docs/visualizations/rosenbrock_SM3.png differ diff --git a/docs/visualizations/rosenbrock_SRMM.png b/docs/visualizations/rosenbrock_SRMM.png index a37498810..70eeb30a0 100644 Binary files a/docs/visualizations/rosenbrock_SRMM.png and b/docs/visualizations/rosenbrock_SRMM.png differ diff --git a/docs/visualizations/rosenbrock_SWATS.png b/docs/visualizations/rosenbrock_SWATS.png index 050768b26..39906d87f 100644 Binary files a/docs/visualizations/rosenbrock_SWATS.png and b/docs/visualizations/rosenbrock_SWATS.png differ diff --git a/docs/visualizations/rosenbrock_ScalableShampoo.png b/docs/visualizations/rosenbrock_ScalableShampoo.png index aede877c1..827518468 100644 Binary files a/docs/visualizations/rosenbrock_ScalableShampoo.png and b/docs/visualizations/rosenbrock_ScalableShampoo.png differ diff --git a/docs/visualizations/rosenbrock_ScheduleFreeAdamW.png b/docs/visualizations/rosenbrock_ScheduleFreeAdamW.png index 77fe4c09d..14894bb32 100644 Binary files a/docs/visualizations/rosenbrock_ScheduleFreeAdamW.png and b/docs/visualizations/rosenbrock_ScheduleFreeAdamW.png differ diff --git a/docs/visualizations/rosenbrock_ScheduleFreeSGD.png b/docs/visualizations/rosenbrock_ScheduleFreeSGD.png index 7e5de0d1f..194b530bf 100644 Binary files a/docs/visualizations/rosenbrock_ScheduleFreeSGD.png and b/docs/visualizations/rosenbrock_ScheduleFreeSGD.png differ diff --git a/docs/visualizations/rosenbrock_Shampoo.png b/docs/visualizations/rosenbrock_Shampoo.png index 01fb11213..87775b525 100644 Binary files a/docs/visualizations/rosenbrock_Shampoo.png and b/docs/visualizations/rosenbrock_Shampoo.png differ diff --git a/docs/visualizations/rosenbrock_SignSGD.png b/docs/visualizations/rosenbrock_SignSGD.png index b2ff3a776..7cb500468 100644 Binary files a/docs/visualizations/rosenbrock_SignSGD.png and b/docs/visualizations/rosenbrock_SignSGD.png differ diff --git a/docs/visualizations/rosenbrock_SophiaH.png 
b/docs/visualizations/rosenbrock_SophiaH.png index 3b3e8c1b3..a7f0a1a3c 100644 Binary files a/docs/visualizations/rosenbrock_SophiaH.png and b/docs/visualizations/rosenbrock_SophiaH.png differ diff --git a/docs/visualizations/rosenbrock_StableAdamW.png b/docs/visualizations/rosenbrock_StableAdamW.png index 29ee96a91..560ed6fa5 100644 Binary files a/docs/visualizations/rosenbrock_StableAdamW.png and b/docs/visualizations/rosenbrock_StableAdamW.png differ diff --git a/docs/visualizations/rosenbrock_Tiger.png b/docs/visualizations/rosenbrock_Tiger.png index c04a58175..d830a4600 100644 Binary files a/docs/visualizations/rosenbrock_Tiger.png and b/docs/visualizations/rosenbrock_Tiger.png differ diff --git a/docs/visualizations/rosenbrock_Yogi.png b/docs/visualizations/rosenbrock_Yogi.png index ab59de064..5365c6304 100644 Binary files a/docs/visualizations/rosenbrock_Yogi.png and b/docs/visualizations/rosenbrock_Yogi.png differ diff --git a/examples/visualize_optimizers.py b/examples/visualize_optimizers.py index e13f8e991..44c41925e 100644 --- a/examples/visualize_optimizers.py +++ b/examples/visualize_optimizers.py @@ -1,13 +1,13 @@ import math from pathlib import Path -import hyperopt.exceptions import numpy as np import torch from hyperopt import fmin, hp, tpe +from hyperopt.exceptions import AllTrialsFailed from matplotlib import pyplot as plt -from pytorch_optimizer import OPTIMIZERS +from pytorch_optimizer.optimizer import OPTIMIZERS def rosenbrock(tensors) -> torch.Tensor: @@ -57,17 +57,17 @@ def objective_rastrigin(params, minimum=(0, 0)): return (steps[0][-1] - minimum[0]) ** 2 + (steps[1][-1] - minimum[1]) ** 2 -def objective_rosenbrok(params, minimum=(1.0, 1.0)): +def objective_rosenbrock(params, minimum=(1.0, 1.0)): steps = execute_steps(rastrigin, (-2.0, 2.0), params['optimizer_class'], {'lr': params['lr']}, 100) return (steps[0][-1] - minimum[0]) ** 2 + (steps[1][-1] - minimum[1]) ** 2 -def plot_rastrigin(grad_iter, optimizer_name, lr) -> None: +def plot_rastrigin(grad_iter, optimizer_plot_path, optimizer_name, lr) -> None: x = torch.linspace(-4.5, 4.5, 250) y = torch.linspace(-4.5, 4.5, 250) - x, y = torch.meshgrid(x, y) + x, y = torch.meshgrid(x, y, indexing='ij') z = rastrigin([x, y]) iter_x, iter_y = grad_iter[0, :], grad_iter[1, :] @@ -81,15 +81,15 @@ def plot_rastrigin(grad_iter, optimizer_name, lr) -> None: plt.plot(0, 0, 'gD') plt.plot(iter_x[-1], iter_y[-1], 'rD') - plt.savefig(f'../docs/visualizations/rastrigin_{optimizer_name}.png') + plt.savefig(str(optimizer_plot_path)) plt.close() -def plot_rosenbrok(grad_iter, optimizer_name, lr): +def plot_rosenbrok(grad_iter, optimizer_plot_path, optimizer_name, lr): x = torch.linspace(-2, 2, 250) y = torch.linspace(-1, 3, 250) - x, y = torch.meshgrid(x, y) + x, y = torch.meshgrid(x, y, indexing='ij') z = rosenbrock([x, y]) iter_x, iter_y = grad_iter[0, :], grad_iter[1, :] @@ -103,7 +103,7 @@ def plot_rosenbrok(grad_iter, optimizer_name, lr): ax.set_title(f'Rosenbrock func: {optimizer_name} with {len(iter_x)} iterations, lr={lr:.6f}') plt.plot(1.0, 1.0, 'gD') plt.plot(iter_x[-1], iter_y[-1], 'rD') - plt.savefig(f'../docs/visualizations/rosenbrock_{optimizer_name}.png') + plt.savefig(str(optimizer_plot_path)) plt.close() @@ -113,7 +113,8 @@ def execute_experiments( for item in optimizers: optimizer_class, lr_low, lr_hi = item - if (root_path / f'{exp_name}_{optimizer_class.__name__}.png').exists(): + optimizer_plot_path = root_path / f'{exp_name}_{optimizer_class.__name__}.png' + if optimizer_plot_path.exists(): continue space = { 
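A note on the `torch.meshgrid` change in the plotting helpers above: recent PyTorch versions warn when `indexing` is not passed explicitly, and `indexing='ij'` simply preserves the matrix-indexing behaviour the script already relied on. A minimal, self-contained sketch of that grid-building pattern (the quadratic surface below is only a stand-in for the Rastrigin/Rosenbrock functions used in the example):

```python
import torch

x = torch.linspace(-4.5, 4.5, 250)
y = torch.linspace(-4.5, 4.5, 250)

# 'ij' (matrix) indexing is the historical torch.meshgrid default; passing it
# explicitly keeps the contour grid unchanged while silencing the deprecation
# warning emitted by newer PyTorch releases when the argument is omitted.
grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')

z = grid_x ** 2 + grid_y ** 2  # placeholder surface; the script plots rastrigin/rosenbrock
print(z.shape)  # torch.Size([250, 250])
```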
@@ -129,7 +130,7 @@ def execute_experiments( max_evals=200, rstate=np.random.default_rng(seed), ) - except hyperopt.exceptions.AllTrialsFailed: + except AllTrialsFailed: continue steps = execute_steps( @@ -140,7 +141,7 @@ def execute_experiments( 500, ) - plot_func(steps, optimizer_class.__name__, best['lr']) + plot_func(steps, optimizer_plot_path, optimizer_class.__name__, best['lr']) def main(): @@ -149,19 +150,14 @@ def main(): np.random.seed(42) torch.manual_seed(42) - root_path = Path('..') / 'docs' / 'visualizations' + root_path = Path('.') / 'docs' / 'visualizations' optimizers = [ - (torch.optim.AdamW, -6, 0.5), - (torch.optim.Adam, -6, 0.5), - (torch.optim.SGD, -6, -1.0), + (optimizer, -6, 0.5) + for optimizer_name, optimizer in OPTIMIZERS.items() + if optimizer_name.lower() not in {'alig', 'lomo', 'adalomo', 'bsam', 'adammini'} ] - - for optimizer_name, optimizer in OPTIMIZERS.items(): - if optimizer_name.lower() in {'alig', 'lomo', 'adalomo', 'bsam', 'adammini'}: - continue - - optimizers.append((optimizer, -6, 0.2)) + optimizers.extend([(torch.optim.AdamW, -6, 0.5), (torch.optim.Adam, -6, 0.5), (torch.optim.SGD, -6, -1.0)]) execute_experiments( optimizers, @@ -175,12 +171,12 @@ def main(): execute_experiments( optimizers, - objective_rosenbrok, + objective_rosenbrock, rosenbrock, plot_rosenbrok, (-2.0, 2.0), root_path, - 'rosenbrok', + 'rosenbrock', ) diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 027a74602..f1b4947be 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -1,475 +1,149 @@ # ruff: noqa -import fnmatch -from importlib.util import find_spec -from typing import Dict, List, Optional, Sequence, Set, Union - -import torch.cuda -from torch import nn -from torch.optim import AdamW - -from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER -from pytorch_optimizer.loss.bi_tempered import BinaryBiTemperedLogisticLoss, BiTemperedLogisticLoss -from pytorch_optimizer.loss.cross_entropy import BCELoss -from pytorch_optimizer.loss.dice import DiceLoss, soft_dice_score -from pytorch_optimizer.loss.f1 import SoftF1Loss -from pytorch_optimizer.loss.focal import BCEFocalLoss, FocalCosineLoss, FocalLoss, FocalTverskyLoss -from pytorch_optimizer.loss.jaccard import JaccardLoss, soft_jaccard_score -from pytorch_optimizer.loss.ldam import LDAMLoss -from pytorch_optimizer.loss.lovasz import LovaszHingeLoss -from pytorch_optimizer.loss.tversky import TverskyLoss +from pytorch_optimizer.loss import ( + BCEFocalLoss, + BCELoss, + BinaryBiTemperedLogisticLoss, + BiTemperedLogisticLoss, + DiceLoss, + FocalCosineLoss, + FocalLoss, + FocalTverskyLoss, + JaccardLoss, + LDAMLoss, + LovaszHingeLoss, + SoftF1Loss, + TverskyLoss, + bi_tempered_logistic_loss, + get_supported_loss_functions, + soft_dice_score, + soft_jaccard_score, +) from pytorch_optimizer.lr_scheduler import ( ConstantLR, CosineAnnealingLR, CosineAnnealingWarmRestarts, + CosineAnnealingWarmupRestarts, + CosineScheduler, CyclicLR, + LinearScheduler, MultiplicativeLR, MultiStepLR, OneCycleLR, - SchedulerType, + PolyScheduler, + ProportionScheduler, + REXScheduler, StepLR, + deberta_v3_large_lr_scheduler, + get_chebyshev_perm_steps, + get_chebyshev_schedule, + get_supported_lr_schedulers, + get_wsd_schedule, + load_lr_scheduler, ) -from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_perm_steps, get_chebyshev_schedule -from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts -from 
pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler -from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler -from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler -from pytorch_optimizer.lr_scheduler.rex import REXScheduler -from pytorch_optimizer.lr_scheduler.wsd import get_wsd_schedule -from pytorch_optimizer.optimizer.a2grad import A2Grad -from pytorch_optimizer.optimizer.adabelief import AdaBelief -from pytorch_optimizer.optimizer.adabound import AdaBound -from pytorch_optimizer.optimizer.adadelta import AdaDelta -from pytorch_optimizer.optimizer.adafactor import AdaFactor -from pytorch_optimizer.optimizer.adahessian import AdaHessian -from pytorch_optimizer.optimizer.adai import Adai -from pytorch_optimizer.optimizer.adalite import Adalite -from pytorch_optimizer.optimizer.adam_mini import AdamMini -from pytorch_optimizer.optimizer.adamax import AdaMax -from pytorch_optimizer.optimizer.adamg import AdamG -from pytorch_optimizer.optimizer.adamod import AdaMod -from pytorch_optimizer.optimizer.adamp import AdamP -from pytorch_optimizer.optimizer.adams import AdamS -from pytorch_optimizer.optimizer.adamw import StableAdamW -from pytorch_optimizer.optimizer.adan import Adan -from pytorch_optimizer.optimizer.adanorm import AdaNorm -from pytorch_optimizer.optimizer.adapnm import AdaPNM -from pytorch_optimizer.optimizer.adashift import AdaShift -from pytorch_optimizer.optimizer.adasmooth import AdaSmooth -from pytorch_optimizer.optimizer.ademamix import AdEMAMix -from pytorch_optimizer.optimizer.adopt import ADOPT -from pytorch_optimizer.optimizer.agc import agc -from pytorch_optimizer.optimizer.aggmo import AggMo -from pytorch_optimizer.optimizer.aida import Aida -from pytorch_optimizer.optimizer.alig import AliG -from pytorch_optimizer.optimizer.amos import Amos -from pytorch_optimizer.optimizer.apollo import Apollo -from pytorch_optimizer.optimizer.avagrad import AvaGrad -from pytorch_optimizer.optimizer.came import CAME -from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptLion, DAdaptSGD -from pytorch_optimizer.optimizer.diffgrad import DiffGrad -from pytorch_optimizer.optimizer.fadam import FAdam -from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer -from pytorch_optimizer.optimizer.fromage import Fromage -from pytorch_optimizer.optimizer.ftrl import FTRL -from pytorch_optimizer.optimizer.galore import GaLore, GaLoreProjector -from pytorch_optimizer.optimizer.gc import centralize_gradient -from pytorch_optimizer.optimizer.gravity import Gravity -from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW, gradfilter_ema, gradfilter_ma -from pytorch_optimizer.optimizer.kate import Kate -from pytorch_optimizer.optimizer.lamb import Lamb -from pytorch_optimizer.optimizer.lars import LARS -from pytorch_optimizer.optimizer.lion import Lion -from pytorch_optimizer.optimizer.lomo import LOMO, AdaLOMO -from pytorch_optimizer.optimizer.lookahead import Lookahead -from pytorch_optimizer.optimizer.madgrad import MADGRAD -from pytorch_optimizer.optimizer.msvag import MSVAG -from pytorch_optimizer.optimizer.nero import Nero -from pytorch_optimizer.optimizer.novograd import NovoGrad -from pytorch_optimizer.optimizer.padam import PAdam -from pytorch_optimizer.optimizer.pcgrad import PCGrad -from pytorch_optimizer.optimizer.pid import PID -from pytorch_optimizer.optimizer.pnm import PNM -from 
pytorch_optimizer.optimizer.prodigy import Prodigy -from pytorch_optimizer.optimizer.qhadam import QHAdam -from pytorch_optimizer.optimizer.qhm import QHM -from pytorch_optimizer.optimizer.radam import RAdam -from pytorch_optimizer.optimizer.ranger import Ranger -from pytorch_optimizer.optimizer.ranger21 import Ranger21 -from pytorch_optimizer.optimizer.rotograd import RotoGrad -from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM -from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeSGD -from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD -from pytorch_optimizer.optimizer.sgdp import SGDP -from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo -from pytorch_optimizer.optimizer.shampoo_utils import ( - AdaGradGraft, - BlockPartitioner, - Graft, - LayerWiseGrafting, - PreConditioner, - PreConditionerType, - RMSPropGraft, - SGDGraft, - SQRTNGraft, - compute_power_schur_newton, - compute_power_svd, - merge_small_dims, - power_iteration, -) -from pytorch_optimizer.optimizer.sm3 import SM3 -from pytorch_optimizer.optimizer.soap import SOAP -from pytorch_optimizer.optimizer.sophia import SophiaH -from pytorch_optimizer.optimizer.srmm import SRMM -from pytorch_optimizer.optimizer.swats import SWATS -from pytorch_optimizer.optimizer.tiger import Tiger -from pytorch_optimizer.optimizer.trac import TRAC -from pytorch_optimizer.optimizer.utils import ( - CPUOffloadOptimizer, - clip_grad_norm, - disable_running_stats, - enable_running_stats, - get_global_gradient_norm, - get_optimizer_parameters, - normalize_gradient, - reduce_max_except_dim, - unit_norm, -) -from pytorch_optimizer.optimizer.yogi import Yogi - -HAS_BNB: bool = find_spec('bitsandbytes') is not None -HAS_Q_GALORE: bool = find_spec('q-galore-torch') is not None -HAS_TORCHAO: bool = find_spec('torchao') is not None - -OPTIMIZER_LIST: List[OPTIMIZER] = [ - AdamW, - AdaBelief, - AdaBound, - PID, - AdamP, - Adai, - Adan, - AdaMod, - AdaPNM, - DiffGrad, - Lamb, +from pytorch_optimizer.optimizer import ( + ADOPT, + ASGD, + BSAM, + CAME, + FTRL, + GSAM, LARS, - QHAdam, - QHM, + LOMO, MADGRAD, - Nero, - PNM, MSVAG, - RAdam, - Ranger, - Ranger21, + PID, + PNM, + QHM, + SAM, SGDP, - Shampoo, - ScalableShampoo, - DAdaptAdaGrad, - Fromage, - AggMo, - DAdaptAdam, - DAdaptSGD, - DAdaptAdan, - AdamS, - AdaFactor, - Apollo, - SWATS, - NovoGrad, - Lion, - AliG, + SGDW, SM3, - AdaNorm, + SOAP, + SRMM, + SWATS, + TRAC, + WSAM, A2Grad, AccSGD, - SGDW, - Yogi, - ASGD, + AdaBelief, + AdaBound, + AdaDelta, + AdaFactor, + AdaHessian, + Adai, + Adalite, + AdaLOMO, AdaMax, - Gravity, - AdaSmooth, - SRMM, - AvaGrad, + AdamG, + AdamMini, + AdaMod, + AdamP, + AdamS, + AdamW, + Adan, + AdaNorm, + AdaPNM, AdaShift, - AdaDelta, + AdaSmooth, + AdEMAMix, + AggMo, + Aida, + AliG, Amos, - AdaHessian, - SophiaH, - SignSGD, - Prodigy, - PAdam, - LOMO, - Tiger, - CAME, + Apollo, + AvaGrad, + DAdaptAdaGrad, + DAdaptAdam, + DAdaptAdan, DAdaptLion, - Aida, - GaLore, - Adalite, - BSAM, - ScheduleFreeSGD, - ScheduleFreeAdamW, + DAdaptSGD, + DiffGrad, + DynamicLossScaler, FAdam, + Fromage, + GaLore, + Gravity, GrokFastAdamW, Kate, + Lamb, + Lion, + Lookahead, + Nero, + NovoGrad, + PAdam, + PCGrad, + Prodigy, + QHAdam, + RAdam, + Ranger, + Ranger21, + RotoGrad, + SafeFP16Optimizer, + ScalableShampoo, + ScheduleFreeAdamW, + ScheduleFreeSGD, + Shampoo, + SignSGD, + SophiaH, StableAdamW, - AdamMini, - AdaLOMO, - AdamG, - AdEMAMix, - SOAP, - ADOPT, - FTRL, -] -OPTIMIZERS: Dict[str, OPTIMIZER] = 
{str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} - -LR_SCHEDULER_LIST: Dict = { - SchedulerType.CONSTANT: ConstantLR, - SchedulerType.STEP: StepLR, - SchedulerType.MULTI_STEP: MultiStepLR, - SchedulerType.CYCLIC: CyclicLR, - SchedulerType.MULTIPLICATIVE: MultiplicativeLR, - SchedulerType.ONE_CYCLE: OneCycleLR, - SchedulerType.COSINE: CosineScheduler, - SchedulerType.POLY: PolyScheduler, - SchedulerType.LINEAR: LinearScheduler, - SchedulerType.PROPORTION: ProportionScheduler, - SchedulerType.COSINE_ANNEALING: CosineAnnealingLR, - SchedulerType.COSINE_ANNEALING_WITH_WARMUP: CosineAnnealingWarmupRestarts, - SchedulerType.COSINE_ANNEALING_WITH_WARM_RESTART: CosineAnnealingWarmRestarts, - SchedulerType.CHEBYSHEV: get_chebyshev_schedule, - SchedulerType.REX: REXScheduler, - SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule, -} -LR_SCHEDULERS: Dict[str, SCHEDULER] = { - str(lr_scheduler_name).lower(): lr_scheduler for lr_scheduler_name, lr_scheduler in LR_SCHEDULER_LIST.items() -} - -LOSS_FUNCTION_LIST: List = [ - BCELoss, - BCEFocalLoss, - FocalLoss, - SoftF1Loss, - DiceLoss, - LDAMLoss, - FocalCosineLoss, - JaccardLoss, - BiTemperedLogisticLoss, - BinaryBiTemperedLogisticLoss, - TverskyLoss, - FocalTverskyLoss, - LovaszHingeLoss, -] -LOSS_FUNCTIONS: Dict[str, nn.Module] = { - str(loss_function.__name__).lower(): loss_function for loss_function in LOSS_FUNCTION_LIST -} - - -def load_bnb_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover - r"""load bnb optimizer instance.""" - from bitsandbytes import optim - - if 'sgd8bit' in optimizer: - return optim.SGD8bit - if 'adam8bit' in optimizer: - return optim.Adam8bit - if 'paged_adam8bit' in optimizer: - return optim.PagedAdam8bit - if 'adamw8bit' in optimizer: - return optim.AdamW8bit - if 'paged_adamw8bit' in optimizer: - return optim.PagedAdamW8bit - if 'lamb8bit' in optimizer: - return optim.LAMB8bit - if 'lars8bit' in optimizer: - return optim.LARS8bit - if 'lion8bit' in optimizer: - return optim.Lion8bit - if 'adagrad8bit' in optimizer: - return optim.Adagrad8bit - if 'rmsprop8bit' in optimizer: - return optim.RMSprop8bit - if 'adagrad32bit' in optimizer: - return optim.Adagrad32bit - if 'adam32bit' in optimizer: - return optim.Adam32bit - if 'paged_adam32bit' in optimizer: - return optim.PagedAdam32bit - if 'adamw32bit' in optimizer: - return optim.AdamW32bit - if 'lamb32bit' in optimizer: - return optim.LAMB32bit - if 'lars32bit' in optimizer: - return optim.LARS32bit - if 'lion32bit' in optimizer: - return optim.Lion32bit - if 'paged_lion32bit' in optimizer: - return optim.PagedLion32bit - if 'rmsprop32bit' in optimizer: - return optim.RMSprop32bit - if 'sgd32bit' in optimizer: - return optim.SGD32bit - if 'ademamix8bit' in optimizer: - return optim.AdEMAMix8bit - if 'ademamix32bit' in optimizer: - return optim.AdEMAMix32bit - if 'paged_ademamix8bit' in optimizer: - return optim.PagedAdEMAMix8bit - if 'paged_ademamix32bit' in optimizer: - return optim.PagedAdEMAMix32bit - - raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}') - - -def load_q_galore_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover - r"""load Q-GaLore optimizer instance.""" - import q_galore_torch - - if 'adamw8bit' in optimizer: - return q_galore_torch.QGaLoreAdamW8bit - - raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}') - - -def load_ao_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover - r"""load TorchAO optimizer instance.""" - from torchao.prototype import low_bit_optim - 
- if 'adamw8bit' in optimizer: - return low_bit_optim.AdamW8bit - if 'adamw4bit' in optimizer: - return low_bit_optim.AdamW4bit - if 'adamwfp8' in optimizer: - return low_bit_optim.AdamWFp8 - - raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}') - - -def load_optimizer(optimizer: str) -> OPTIMIZER: - optimizer: str = optimizer.lower() - - if optimizer.startswith('bnb'): - if HAS_BNB and torch.cuda.is_available(): - return load_bnb_optimizer(optimizer) # pragma: no cover - raise ImportError(f'bitsandbytes and CUDA required for the optimizer {optimizer}') - if optimizer.startswith('q_galore'): - if HAS_Q_GALORE and torch.cuda.is_available(): - return load_q_galore_optimizer(optimizer) # pragma: no cover - raise ImportError(f'bitsandbytes, q-galore-torch, and CUDA required for the optimizer {optimizer}') - if optimizer.startswith('torchao'): - if HAS_TORCHAO and torch.cuda.is_available(): - return load_ao_optimizer(optimizer) # pragma: no cover - raise ImportError( - f'torchao required for the optimizer {optimizer}. ' - 'usage: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#usage' - ) - if optimizer not in OPTIMIZERS: - raise NotImplementedError(f'not implemented optimizer : {optimizer}') - - return OPTIMIZERS[optimizer] - - -def create_optimizer( - model: nn.Module, - optimizer_name: str, - lr: float = 1e-3, - weight_decay: float = 0.0, - wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'), - use_lookahead: bool = False, - **kwargs, -): - r"""Build optimizer. - - :param model: nn.Module. model. - :param optimizer_name: str. name of optimizer. - :param lr: float. learning rate. - :param weight_decay: float. weight decay. - :param wd_ban_list: List[str]. weight decay ban list by layer. - :param use_lookahead: bool. use lookahead. - """ - optimizer_name = optimizer_name.lower() - - parameters = ( - get_optimizer_parameters(model, weight_decay, wd_ban_list) if weight_decay > 0.0 else model.parameters() - ) - - optimizer = load_optimizer(optimizer_name) - - if optimizer_name == 'alig': - optimizer = optimizer(parameters, max_lr=lr, **kwargs) - elif optimizer_name in {'lomo', 'adalomo', 'adammini'}: - optimizer = optimizer(model, lr=lr, **kwargs) - else: - optimizer = optimizer(parameters, lr=lr, **kwargs) - - if use_lookahead: - optimizer = Lookahead( - optimizer, - k=kwargs['k'] if 'k' in kwargs else 5, - alpha=kwargs['alpha'] if 'alpha' in kwargs else 0.5, - pullback_momentum=kwargs['pullback_momentum'] if 'pullback_momentum' in kwargs else 'none', - ) - - return optimizer - - -def load_lr_scheduler(lr_scheduler: str) -> SCHEDULER: - lr_scheduler: str = lr_scheduler.lower() - - if lr_scheduler not in LR_SCHEDULERS: - raise NotImplementedError(f'[-] not implemented lr_scheduler : {lr_scheduler}') - - return LR_SCHEDULERS[lr_scheduler] - - -def get_supported_optimizers(filters: Optional[Union[str, List[str]]] = None) -> List[str]: - r"""Return list of available optimizer names, sorted alphabetically. - - :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will - return the whole list. 
- """ - if filters is None: - return sorted(OPTIMIZERS.keys()) - - include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters] - - filtered_list: Set[str] = set() - for include_filter in include_filters: - filtered_list.update(fnmatch.filter(OPTIMIZERS.keys(), include_filter)) - - return sorted(filtered_list) - - -def get_supported_lr_schedulers(filters: Optional[Union[str, List[str]]] = None) -> List[str]: - r"""Return list of available lr scheduler names, sorted alphabetically. - - :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will - return the whole list. - """ - if filters is None: - return sorted(LR_SCHEDULERS.keys()) - - include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters] - - filtered_list: Set[str] = set() - for include_filter in include_filters: - filtered_list.update(fnmatch.filter(LR_SCHEDULERS.keys(), include_filter)) - - return sorted(filtered_list) - - -def get_supported_loss_functions(filters: Optional[Union[str, List[str]]] = None) -> List[str]: - r"""Return list of available loss function names, sorted alphabetically. - - :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will - return the whole list. - """ - if filters is None: - return sorted(LOSS_FUNCTIONS.keys()) - - include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters] - - filtered_list: Set[str] = set() - for include_filter in include_filters: - filtered_list.update(fnmatch.filter(LOSS_FUNCTIONS.keys(), include_filter)) - - return sorted(filtered_list) + Tiger, + Yogi, + agc, + centralize_gradient, + create_optimizer, + get_optimizer_parameters, + get_supported_optimizers, + load_ao_optimizer, + load_bnb_optimizer, + load_optimizer, + load_q_galore_optimizer, +) +from pytorch_optimizer.optimizer.utils import ( + CPUOffloadOptimizer, + clip_grad_norm, + disable_running_stats, + enable_running_stats, + get_global_gradient_norm, + normalize_gradient, + unit_norm, +) diff --git a/pytorch_optimizer/loss/__init__.py b/pytorch_optimizer/loss/__init__.py index e69de29bb..c9049c441 100644 --- a/pytorch_optimizer/loss/__init__.py +++ b/pytorch_optimizer/loss/__init__.py @@ -0,0 +1,55 @@ +import fnmatch +from typing import Dict, List, Optional, Sequence, Set, Union + +from torch import nn + +from pytorch_optimizer.loss.bi_tempered import ( + BinaryBiTemperedLogisticLoss, + BiTemperedLogisticLoss, + bi_tempered_logistic_loss, +) +from pytorch_optimizer.loss.cross_entropy import BCELoss +from pytorch_optimizer.loss.dice import DiceLoss, soft_dice_score +from pytorch_optimizer.loss.f1 import SoftF1Loss +from pytorch_optimizer.loss.focal import BCEFocalLoss, FocalCosineLoss, FocalLoss, FocalTverskyLoss +from pytorch_optimizer.loss.jaccard import JaccardLoss, soft_jaccard_score +from pytorch_optimizer.loss.ldam import LDAMLoss +from pytorch_optimizer.loss.lovasz import LovaszHingeLoss +from pytorch_optimizer.loss.tversky import TverskyLoss + +LOSS_FUNCTION_LIST: List = [ + BCELoss, + BCEFocalLoss, + FocalLoss, + SoftF1Loss, + DiceLoss, + LDAMLoss, + FocalCosineLoss, + JaccardLoss, + BiTemperedLogisticLoss, + BinaryBiTemperedLogisticLoss, + TverskyLoss, + FocalTverskyLoss, + LovaszHingeLoss, +] +LOSS_FUNCTIONS: Dict[str, nn.Module] = { + str(loss_function.__name__).lower(): loss_function for loss_function in LOSS_FUNCTION_LIST +} + + +def get_supported_loss_functions(filters: Optional[Union[str, 
List[str]]] = None) -> List[str]:
+    r"""Return list of available loss function names, sorted alphabetically.
+
+    :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fnmatch. if None, it will
+        return the whole list.
+    """
+    if filters is None:
+        return sorted(LOSS_FUNCTIONS.keys())
+
+    include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters]
+
+    filtered_list: Set[str] = set()
+    for include_filter in include_filters:
+        filtered_list.update(fnmatch.filter(LOSS_FUNCTIONS.keys(), include_filter))
+
+    return sorted(filtered_list)
diff --git a/pytorch_optimizer/lr_scheduler/__init__.py b/pytorch_optimizer/lr_scheduler/__init__.py
index 842f0cc0f..de0bf5a40 100644
--- a/pytorch_optimizer/lr_scheduler/__init__.py
+++ b/pytorch_optimizer/lr_scheduler/__init__.py
@@ -1,5 +1,7 @@
 # ruff: noqa
+import fnmatch
 from enum import Enum
+from typing import Dict, List, Optional, Sequence, Set, Union
 
 from torch.optim.lr_scheduler import (
     ConstantLR,
@@ -12,6 +14,15 @@
     StepLR,
 )
 
+from pytorch_optimizer.base.types import SCHEDULER
+from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_perm_steps, get_chebyshev_schedule
+from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts
+from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
+from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
+from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
+from pytorch_optimizer.lr_scheduler.rex import REXScheduler
+from pytorch_optimizer.lr_scheduler.wsd import get_wsd_schedule
+
 
 class SchedulerType(Enum):
     CONSTANT = 'constant'
@@ -33,3 +44,53 @@ class SchedulerType(Enum):
 
     def __str__(self) -> str:
         return self.value
+
+
+LR_SCHEDULER_LIST: Dict = {
+    SchedulerType.CONSTANT: ConstantLR,
+    SchedulerType.STEP: StepLR,
+    SchedulerType.MULTI_STEP: MultiStepLR,
+    SchedulerType.CYCLIC: CyclicLR,
+    SchedulerType.MULTIPLICATIVE: MultiplicativeLR,
+    SchedulerType.ONE_CYCLE: OneCycleLR,
+    SchedulerType.COSINE: CosineScheduler,
+    SchedulerType.POLY: PolyScheduler,
+    SchedulerType.LINEAR: LinearScheduler,
+    SchedulerType.PROPORTION: ProportionScheduler,
+    SchedulerType.COSINE_ANNEALING: CosineAnnealingLR,
+    SchedulerType.COSINE_ANNEALING_WITH_WARMUP: CosineAnnealingWarmupRestarts,
+    SchedulerType.COSINE_ANNEALING_WITH_WARM_RESTART: CosineAnnealingWarmRestarts,
+    SchedulerType.CHEBYSHEV: get_chebyshev_schedule,
+    SchedulerType.REX: REXScheduler,
+    SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
+}
+LR_SCHEDULERS: Dict[str, SCHEDULER] = {
+    str(lr_scheduler_name).lower(): lr_scheduler for lr_scheduler_name, lr_scheduler in LR_SCHEDULER_LIST.items()
+}
+
+
+def load_lr_scheduler(lr_scheduler: str) -> SCHEDULER:
+    lr_scheduler: str = lr_scheduler.lower()
+
+    if lr_scheduler not in LR_SCHEDULERS:
+        raise NotImplementedError(f'[-] not implemented lr_scheduler : {lr_scheduler}')
+
+    return LR_SCHEDULERS[lr_scheduler]
+
+
+def get_supported_lr_schedulers(filters: Optional[Union[str, List[str]]] = None) -> List[str]:
+    r"""Return list of available lr scheduler names, sorted alphabetically.
+
+    :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fnmatch. if None, it will
+        return the whole list.
+ """ + if filters is None: + return sorted(LR_SCHEDULERS.keys()) + + include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters] + + filtered_list: Set[str] = set() + for include_filter in include_filters: + filtered_list.update(fnmatch.filter(LR_SCHEDULERS.keys(), include_filter)) + + return sorted(filtered_list) diff --git a/pytorch_optimizer/optimizer/__init__.py b/pytorch_optimizer/optimizer/__init__.py index e69de29bb..73c92c1ce 100644 --- a/pytorch_optimizer/optimizer/__init__.py +++ b/pytorch_optimizer/optimizer/__init__.py @@ -0,0 +1,387 @@ +import fnmatch +from importlib.util import find_spec +from typing import Dict, List, Optional, Sequence, Set, Union + +import torch +from torch import nn +from torch.optim import AdamW + +from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS +from pytorch_optimizer.optimizer.a2grad import A2Grad +from pytorch_optimizer.optimizer.adabelief import AdaBelief +from pytorch_optimizer.optimizer.adabound import AdaBound +from pytorch_optimizer.optimizer.adadelta import AdaDelta +from pytorch_optimizer.optimizer.adafactor import AdaFactor +from pytorch_optimizer.optimizer.adahessian import AdaHessian +from pytorch_optimizer.optimizer.adai import Adai +from pytorch_optimizer.optimizer.adalite import Adalite +from pytorch_optimizer.optimizer.adam_mini import AdamMini +from pytorch_optimizer.optimizer.adamax import AdaMax +from pytorch_optimizer.optimizer.adamg import AdamG +from pytorch_optimizer.optimizer.adamod import AdaMod +from pytorch_optimizer.optimizer.adamp import AdamP +from pytorch_optimizer.optimizer.adams import AdamS +from pytorch_optimizer.optimizer.adamw import StableAdamW +from pytorch_optimizer.optimizer.adan import Adan +from pytorch_optimizer.optimizer.adanorm import AdaNorm +from pytorch_optimizer.optimizer.adapnm import AdaPNM +from pytorch_optimizer.optimizer.adashift import AdaShift +from pytorch_optimizer.optimizer.adasmooth import AdaSmooth +from pytorch_optimizer.optimizer.ademamix import AdEMAMix +from pytorch_optimizer.optimizer.adopt import ADOPT +from pytorch_optimizer.optimizer.agc import agc +from pytorch_optimizer.optimizer.aggmo import AggMo +from pytorch_optimizer.optimizer.aida import Aida +from pytorch_optimizer.optimizer.alig import AliG +from pytorch_optimizer.optimizer.amos import Amos +from pytorch_optimizer.optimizer.apollo import Apollo +from pytorch_optimizer.optimizer.avagrad import AvaGrad +from pytorch_optimizer.optimizer.came import CAME +from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptLion, DAdaptSGD +from pytorch_optimizer.optimizer.diffgrad import DiffGrad +from pytorch_optimizer.optimizer.fadam import FAdam +from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer +from pytorch_optimizer.optimizer.fromage import Fromage +from pytorch_optimizer.optimizer.ftrl import FTRL +from pytorch_optimizer.optimizer.galore import GaLore +from pytorch_optimizer.optimizer.gc import centralize_gradient +from pytorch_optimizer.optimizer.gravity import Gravity +from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW +from pytorch_optimizer.optimizer.kate import Kate +from pytorch_optimizer.optimizer.lamb import Lamb +from pytorch_optimizer.optimizer.lars import LARS +from pytorch_optimizer.optimizer.lion import Lion +from pytorch_optimizer.optimizer.lomo import LOMO, AdaLOMO +from pytorch_optimizer.optimizer.lookahead import Lookahead +from pytorch_optimizer.optimizer.madgrad import MADGRAD +from 
pytorch_optimizer.optimizer.msvag import MSVAG +from pytorch_optimizer.optimizer.nero import Nero +from pytorch_optimizer.optimizer.novograd import NovoGrad +from pytorch_optimizer.optimizer.padam import PAdam +from pytorch_optimizer.optimizer.pcgrad import PCGrad +from pytorch_optimizer.optimizer.pid import PID +from pytorch_optimizer.optimizer.pnm import PNM +from pytorch_optimizer.optimizer.prodigy import Prodigy +from pytorch_optimizer.optimizer.qhadam import QHAdam +from pytorch_optimizer.optimizer.qhm import QHM +from pytorch_optimizer.optimizer.radam import RAdam +from pytorch_optimizer.optimizer.ranger import Ranger +from pytorch_optimizer.optimizer.ranger21 import Ranger21 +from pytorch_optimizer.optimizer.rotograd import RotoGrad +from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM +from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeSGD +from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD +from pytorch_optimizer.optimizer.sgdp import SGDP +from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo +from pytorch_optimizer.optimizer.sm3 import SM3 +from pytorch_optimizer.optimizer.soap import SOAP +from pytorch_optimizer.optimizer.sophia import SophiaH +from pytorch_optimizer.optimizer.srmm import SRMM +from pytorch_optimizer.optimizer.swats import SWATS +from pytorch_optimizer.optimizer.tiger import Tiger +from pytorch_optimizer.optimizer.trac import TRAC +from pytorch_optimizer.optimizer.yogi import Yogi + +HAS_BNB: bool = find_spec('bitsandbytes') is not None +HAS_Q_GALORE: bool = find_spec('q-galore-torch') is not None +HAS_TORCHAO: bool = find_spec('torchao') is not None + + +def load_bnb_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover # noqa: PLR0911 + r"""Load bnb optimizer instance.""" + from bitsandbytes import optim + + if 'sgd8bit' in optimizer: + return optim.SGD8bit + if 'adam8bit' in optimizer: + return optim.Adam8bit + if 'paged_adam8bit' in optimizer: + return optim.PagedAdam8bit + if 'adamw8bit' in optimizer: + return optim.AdamW8bit + if 'paged_adamw8bit' in optimizer: + return optim.PagedAdamW8bit + if 'lamb8bit' in optimizer: + return optim.LAMB8bit + if 'lars8bit' in optimizer: + return optim.LARS8bit + if 'lion8bit' in optimizer: + return optim.Lion8bit + if 'adagrad8bit' in optimizer: + return optim.Adagrad8bit + if 'rmsprop8bit' in optimizer: + return optim.RMSprop8bit + if 'adagrad32bit' in optimizer: + return optim.Adagrad32bit + if 'adam32bit' in optimizer: + return optim.Adam32bit + if 'paged_adam32bit' in optimizer: + return optim.PagedAdam32bit + if 'adamw32bit' in optimizer: + return optim.AdamW32bit + if 'lamb32bit' in optimizer: + return optim.LAMB32bit + if 'lars32bit' in optimizer: + return optim.LARS32bit + if 'lion32bit' in optimizer: + return optim.Lion32bit + if 'paged_lion32bit' in optimizer: + return optim.PagedLion32bit + if 'rmsprop32bit' in optimizer: + return optim.RMSprop32bit + if 'sgd32bit' in optimizer: + return optim.SGD32bit + if 'ademamix8bit' in optimizer: + return optim.AdEMAMix8bit + if 'ademamix32bit' in optimizer: + return optim.AdEMAMix32bit + if 'paged_ademamix8bit' in optimizer: + return optim.PagedAdEMAMix8bit + if 'paged_ademamix32bit' in optimizer: + return optim.PagedAdEMAMix32bit + + raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}') + + +def load_q_galore_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover + r"""Load Q-GaLore optimizer instance.""" + import q_galore_torch + + if 'adamw8bit' in 
optimizer: + return q_galore_torch.QGaLoreAdamW8bit + + raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}') + + +def load_ao_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover + r"""Load TorchAO optimizer instance.""" + from torchao.prototype import low_bit_optim + + if 'adamw8bit' in optimizer: + return low_bit_optim.AdamW8bit + if 'adamw4bit' in optimizer: + return low_bit_optim.AdamW4bit + if 'adamwfp8' in optimizer: + return low_bit_optim.AdamWFp8 + + raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}') + + +def load_optimizer(optimizer: str) -> OPTIMIZER: + r"""Load optimizers.""" + optimizer: str = optimizer.lower() + + if optimizer.startswith('bnb'): + if HAS_BNB and torch.cuda.is_available(): + return load_bnb_optimizer(optimizer) # pragma: no cover + raise ImportError(f'bitsandbytes and CUDA required for the optimizer {optimizer}') + if optimizer.startswith('q_galore'): + if HAS_Q_GALORE and torch.cuda.is_available(): + return load_q_galore_optimizer(optimizer) # pragma: no cover + raise ImportError(f'bitsandbytes, q-galore-torch, and CUDA required for the optimizer {optimizer}') + if optimizer.startswith('torchao'): + if HAS_TORCHAO and torch.cuda.is_available(): + return load_ao_optimizer(optimizer) # pragma: no cover + raise ImportError( + f'torchao required for the optimizer {optimizer}. ' + 'usage: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#usage' + ) + if optimizer not in OPTIMIZERS: + raise NotImplementedError(f'not implemented optimizer : {optimizer}') + + return OPTIMIZERS[optimizer] + + +OPTIMIZER_LIST: List[OPTIMIZER] = [ + AdamW, + AdaBelief, + AdaBound, + PID, + AdamP, + Adai, + Adan, + AdaMod, + AdaPNM, + DiffGrad, + Lamb, + LARS, + QHAdam, + QHM, + MADGRAD, + Nero, + PNM, + MSVAG, + RAdam, + Ranger, + Ranger21, + SGDP, + Shampoo, + ScalableShampoo, + DAdaptAdaGrad, + Fromage, + AggMo, + DAdaptAdam, + DAdaptSGD, + DAdaptAdan, + AdamS, + AdaFactor, + Apollo, + SWATS, + NovoGrad, + Lion, + AliG, + SM3, + AdaNorm, + A2Grad, + AccSGD, + SGDW, + Yogi, + ASGD, + AdaMax, + Gravity, + AdaSmooth, + SRMM, + AvaGrad, + AdaShift, + AdaDelta, + Amos, + AdaHessian, + SophiaH, + SignSGD, + Prodigy, + PAdam, + LOMO, + Tiger, + CAME, + DAdaptLion, + Aida, + GaLore, + Adalite, + BSAM, + ScheduleFreeSGD, + ScheduleFreeAdamW, + FAdam, + GrokFastAdamW, + Kate, + StableAdamW, + AdamMini, + AdaLOMO, + AdamG, + AdEMAMix, + SOAP, + ADOPT, + FTRL, +] +OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} + + +def create_optimizer( + model: nn.Module, + optimizer_name: str, + lr: float = 1e-3, + weight_decay: float = 0.0, + wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'), + use_lookahead: bool = False, + **kwargs, +): + r"""Build optimizer. + + :param model: nn.Module. model. + :param optimizer_name: str. name of optimizer. + :param lr: float. learning rate. + :param weight_decay: float. weight decay. + :param wd_ban_list: List[str]. weight decay ban list by layer. + :param use_lookahead: bool. use lookahead. 
+ """ + optimizer_name = optimizer_name.lower() + + parameters = ( + get_optimizer_parameters(model, weight_decay, wd_ban_list) if weight_decay > 0.0 else model.parameters() + ) + + optimizer = load_optimizer(optimizer_name) + + if optimizer_name == 'alig': + optimizer = optimizer(parameters, max_lr=lr, **kwargs) + elif optimizer_name in {'lomo', 'adalomo', 'adammini'}: + optimizer = optimizer(model, lr=lr, **kwargs) + else: + optimizer = optimizer(parameters, lr=lr, **kwargs) + + if use_lookahead: + optimizer = Lookahead( + optimizer, + k=kwargs.get('k', 5), + alpha=kwargs.get('alpha', 0.5), + pullback_momentum=kwargs.get('pullback_momentum', 'none'), + ) + + return optimizer + + +def get_optimizer_parameters( + model_or_parameter: Union[nn.Module, List], + weight_decay: float, + wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'), +) -> PARAMETERS: + r"""Get optimizer parameters while filtering specified modules. + + Notice that, You can also ban by a module name level (e.g. LayerNorm) if you pass nn.Module instance. You just only + need to input `LayerNorm` to exclude weight decay from the layer norm layer(s). + + :param model_or_parameter: Union[nn.Module, List]. model or parameters. + :param weight_decay: float. weight_decay. + :param wd_ban_list: List[str]. ban list not to set weight decay. + :returns: PARAMETERS. new parameter list. + """ + banned_parameter_patterns: Set[str] = set() + + if isinstance(model_or_parameter, nn.Module): + for module_name, module in model_or_parameter.named_modules(): + for param_name, _ in module.named_parameters(recurse=False): + full_param_name: str = f'{module_name}.{param_name}' if module_name else param_name + if any( + banned in pattern for banned in wd_ban_list for pattern in (full_param_name, module._get_name()) + ): + banned_parameter_patterns.add(full_param_name) + + model_or_parameter = list(model_or_parameter.named_parameters()) + else: + banned_parameter_patterns.update(wd_ban_list) + + return [ + { + 'params': [ + p + for n, p in model_or_parameter + if p.requires_grad and not any(nd in n for nd in banned_parameter_patterns) + ], + 'weight_decay': weight_decay, + }, + { + 'params': [ + p + for n, p in model_or_parameter + if p.requires_grad and any(nd in n for nd in banned_parameter_patterns) + ], + 'weight_decay': 0.0, + }, + ] + + +def get_supported_optimizers(filters: Optional[Union[str, List[str]]] = None) -> List[str]: + r"""Return list of available optimizer names, sorted alphabetically. + + :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will + return the whole list. + """ + if filters is None: + return sorted(OPTIMIZERS.keys()) + + include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters] + + filtered_list: Set[str] = set() + for include_filter in include_filters: + filtered_list.update(fnmatch.filter(OPTIMIZERS.keys(), include_filter)) + + return sorted(filtered_list) diff --git a/pytorch_optimizer/optimizer/adopt.py b/pytorch_optimizer/optimizer/adopt.py index f23699ea3..8b7026dbc 100644 --- a/pytorch_optimizer/optimizer/adopt.py +++ b/pytorch_optimizer/optimizer/adopt.py @@ -14,9 +14,6 @@ class ADOPT(BaseOptimizer): :param weight_decay: float. weight decay (L2 penalty). :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. :param fixed_decay: bool. fix weight decay. - :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred. - :param adanorm: bool. 
whether to use the AdaNorm variant. - :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training. :param eps: float. term added to the denominator to improve numerical stability. """ diff --git a/pytorch_optimizer/optimizer/alig.py b/pytorch_optimizer/optimizer/alig.py index 2931f35a2..0240a5299 100644 --- a/pytorch_optimizer/optimizer/alig.py +++ b/pytorch_optimizer/optimizer/alig.py @@ -8,6 +8,16 @@ from pytorch_optimizer.optimizer.utils import get_global_gradient_norm +@torch.no_grad() +def l2_projection(parameters: PARAMETERS, max_norm: float = 1e2): + r"""Get l2 normalized parameter.""" + global_norm = torch.sqrt(sum(p.norm().pow(2) for p in parameters)) + if global_norm > max_norm: + ratio = max_norm / global_norm + for param in parameters: + param.mul_(ratio) + + class AliG(BaseOptimizer): r"""Adaptive Learning Rates for Interpolation with Gradients. diff --git a/pytorch_optimizer/optimizer/grokfast.py b/pytorch_optimizer/optimizer/grokfast.py index df2ed50e5..8e4ee61dc 100644 --- a/pytorch_optimizer/optimizer/grokfast.py +++ b/pytorch_optimizer/optimizer/grokfast.py @@ -106,7 +106,7 @@ class GrokFastAdamW(BaseOptimizer): :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. :param grokfast: bool. whether to use grokfast. :param grokfast_alpha: float. momentum hyperparameter of the EMA. - :param grokfast_lamb: float. amplifying factor hyperparameter of the filter.. + :param grokfast_lamb: float. amplifying factor hyperparameter of the filter. :param grokfast_after_step: int. warmup step for grokfast. :param weight_decay: float. weight decay (L2 penalty). :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. diff --git a/pytorch_optimizer/optimizer/nero.py b/pytorch_optimizer/optimizer/nero.py index bf4c594ca..436227882 100644 --- a/pytorch_optimizer/optimizer/nero.py +++ b/pytorch_optimizer/optimizer/nero.py @@ -1,9 +1,31 @@ +from typing import List + import torch from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -from pytorch_optimizer.optimizer.utils import neuron_mean, neuron_norm +from pytorch_optimizer.optimizer.utils import channel_view + + +def neuron_norm(x: torch.Tensor) -> torch.Tensor: + r"""Get norm of the tensor.""" + if x.dim() <= 1: + return x.abs() + + view_shape: List[int] = [x.shape[0]] + [1] * (x.dim() - 1) + + return channel_view(x).norm(dim=1).view(*view_shape) + + +def neuron_mean(x: torch.Tensor) -> torch.Tensor: + r"""Get mean of the tensor.""" + if x.dim() <= 1: + raise ValueError('[-] neuron_mean not defined on 1D tensors.') + + view_shape: List[int] = [x.shape[0]] + [1] * (x.dim() - 1) + + return channel_view(x).mean(dim=1).view(*view_shape) class Nero(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/pcgrad.py b/pytorch_optimizer/optimizer/pcgrad.py index d2af6afce..f4b428892 100644 --- a/pytorch_optimizer/optimizer/pcgrad.py +++ b/pytorch_optimizer/optimizer/pcgrad.py @@ -2,12 +2,28 @@ from copy import deepcopy from typing import Iterable, List, Tuple +import numpy as np import torch from torch import nn from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import OPTIMIZER -from pytorch_optimizer.optimizer.utils import flatten_grad, un_flatten_grad + + +def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor: + r"""Flatten 
the gradient.""" + return torch.cat([grad.flatten() for grad in grads]) + + +def un_flatten_grad(grads: torch.Tensor, shapes: List[int]) -> List[torch.Tensor]: + r"""Unflatten the gradient.""" + idx: int = 0 + un_flatten_grads: List[torch.Tensor] = [] + for shape in shapes: + length = np.prod(shape) + un_flatten_grads.append(grads[idx:idx + length].view(shape).clone()) # fmt: skip + idx += length + return un_flatten_grads class PCGrad(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/sm3.py b/pytorch_optimizer/optimizer/sm3.py index 984d98103..78718a591 100644 --- a/pytorch_optimizer/optimizer/sm3.py +++ b/pytorch_optimizer/optimizer/sm3.py @@ -2,7 +2,26 @@ from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -from pytorch_optimizer.optimizer.utils import reduce_max_except_dim + + +@torch.no_grad() +def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor: + r"""Perform reduce-max along all dimensions except the given dim. + + :param x: torch.Tensor. tensor to reduce-max. + :param dim: int. dimension to exclude. + """ + rank: int = len(x.shape) + if rank == 0: + return x + + if dim >= rank: + raise ValueError(f'[-] given dim is bigger than rank. {dim} >= {rank}') + + for d in range(rank): + if d != dim: + x = x.max(dim=d, keepdim=True).values + return x class SM3(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index 4bb7b8673..779fe8c74 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -4,9 +4,8 @@ import re import warnings from importlib.util import find_spec -from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Callable, Dict, List, Optional, Tuple, Type, Union -import numpy as np import torch from torch import nn from torch.distributed import all_reduce @@ -16,7 +15,25 @@ from pytorch_optimizer.base.types import CLOSURE, LOSS, PARAMETERS + +def parse_pytorch_version(version_string: str) -> List[int]: + r"""Parse Pytorch version.""" + match = re.match(r'(\d+\.\d+\.\d+)', version_string) + if not match: + raise ValueError(f'invalid version string format: {version_string}') + + return [int(x) for x in match.group(1).split('.')] + + +def compare_versions(v1: str, v2: str) -> bool: + r"""Compare two Pytorch versions.""" + v1_parts: List[int] = parse_pytorch_version(v1) + v2_parts: List[int] = parse_pytorch_version(v2) + return (v1_parts > v2_parts) - (v1_parts < v2_parts) + + HAS_TRANSFORMERS: bool = find_spec('transformers') is not None +TORCH_VERSION_AT_LEAST_2_4: bool = compare_versions(torch.__version__, '2.4.0') if HAS_TRANSFORMERS: # pragma: no cover try: @@ -39,25 +56,6 @@ def is_deepspeed_zero3_enabled() -> bool: return False -def parse_pytorch_version(version_string: str) -> List[int]: - r"""Parse Pytorch version.""" - match = re.match(r'(\d+\.\d+\.\d+)', version_string) - if not match: - raise ValueError(f'invalid version string format: {version_string}') - - return [int(x) for x in match.group(1).split('.')] - - -def compare_versions(v1: str, v2: str) -> bool: - r"""Compare two Pytorch versions.""" - v1_parts: List[int] = parse_pytorch_version(v1) - v2_parts: List[int] = parse_pytorch_version(v2) - return (v1_parts > v2_parts) - (v1_parts < v2_parts) - - -TORCH_VERSION_AT_LEAST_2_4: bool = compare_versions(torch.__version__, '2.4.0') - - class CPUOffloadOptimizer: # pragma: no cover """Offload optimizer to CPU for single-GPU training. 
This will reduce GPU memory by the size of optimizer state. @@ -191,22 +189,6 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo x.div_(s) -def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor: - r"""Flatten the gradient.""" - return torch.cat([grad.flatten() for grad in grads]) - - -def un_flatten_grad(grads: torch.Tensor, shapes: List[int]) -> List[torch.Tensor]: - r"""Unflatten the gradient.""" - idx: int = 0 - un_flatten_grads: List[torch.Tensor] = [] - for shape in shapes: - length = np.prod(shape) - un_flatten_grads.append(grads[idx:idx + length].view(shape).clone()) # fmt: skip - idx += length - return un_flatten_grads - - def channel_view(x: torch.Tensor) -> torch.Tensor: r"""Do channel view.""" return x.view(x.size()[0], -1) @@ -307,9 +289,9 @@ def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor: x_len: int = len(x.shape) if x_len <= 1: keep_dim = False - elif x_len in (2, 3): # linear layers + elif x_len in (2, 3): dim = 1 - elif x_len == 4: # conv kernels + elif x_len == 4: dim = (1, 2, 3) else: dim = tuple(range(1, x_len)) @@ -317,76 +299,6 @@ def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor: return x.norm(p=norm, dim=dim, keepdim=keep_dim) -def get_optimizer_parameters( - model_or_parameter: Union[nn.Module, List], - weight_decay: float, - wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'), -) -> PARAMETERS: - r"""Get optimizer parameters while filtering specified modules. - - Notice that, You can also ban by a module name level (e.g. LayerNorm) if you pass nn.Module instance. You just only - need to input `LayerNorm` to exclude weight decay from the layer norm layer(s). - - :param model_or_parameter: Union[nn.Module, List]. model or parameters. - :param weight_decay: float. weight_decay. - :param wd_ban_list: List[str]. ban list not to set weight decay. - :returns: PARAMETERS. new parameter list. 
- """ - banned_parameter_patterns: Set[str] = set() - - if isinstance(model_or_parameter, nn.Module): - for module_name, module in model_or_parameter.named_modules(): - for param_name, _ in module.named_parameters(recurse=False): - full_param_name: str = f'{module_name}.{param_name}' if module_name else param_name - if any( - banned in pattern for banned in wd_ban_list for pattern in (full_param_name, module._get_name()) - ): - banned_parameter_patterns.add(full_param_name) - - model_or_parameter = list(model_or_parameter.named_parameters()) - else: - banned_parameter_patterns.update(wd_ban_list) - - return [ - { - 'params': [ - p - for n, p in model_or_parameter - if p.requires_grad and not any(nd in n for nd in banned_parameter_patterns) - ], - 'weight_decay': weight_decay, - }, - { - 'params': [ - p - for n, p in model_or_parameter - if p.requires_grad and any(nd in n for nd in banned_parameter_patterns) - ], - 'weight_decay': 0.0, - }, - ] - - -def neuron_norm(x: torch.Tensor) -> torch.Tensor: - r"""Get norm of the tensor.""" - if x.dim() <= 1: - return x.abs() - - view_shape: List[int] = [x.shape[0]] + [1] * (x.dim() - 1) - - return channel_view(x).norm(dim=1).view(*view_shape) - - -def neuron_mean(x: torch.Tensor) -> torch.Tensor: - r"""Get mean of the tensor.""" - if x.dim() <= 1: - raise ValueError('[-] neuron_mean not defined on 1D tensors.') - - view_shape: List[int] = [x.shape[0]] + [1] * (x.dim() - 1) - - return channel_view(x).mean(dim=1).view(*view_shape) - - def disable_running_stats(model): r"""Disable running stats (momentum) of BatchNorm.""" @@ -408,16 +320,6 @@ def _enable(module): model.apply(_enable) -@torch.no_grad() -def l2_projection(parameters: PARAMETERS, max_norm: float = 1e2): - r"""Get l2 normalized parameter.""" - global_norm = torch.sqrt(sum(p.norm().pow(2) for p in parameters)) - if global_norm > max_norm: - ratio = max_norm / global_norm - for param in parameters: - param.mul_(ratio) - - @torch.no_grad() def get_global_gradient_norm(param_groups: List[Dict]) -> torch.Tensor: r"""Get global gradient norm.""" @@ -431,26 +333,6 @@ def get_global_gradient_norm(param_groups: List[Dict]) -> torch.Tensor: return global_grad_norm -@torch.no_grad() -def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor: - r"""Perform reduce-max along all dimensions except the given dim. - - :param x: torch.Tensor. tensor to reduce-max. - :param dim: int. dimension to exclude. - """ - rank: int = len(x.shape) - if rank == 0: - return x - - if dim >= rank: - raise ValueError(f'[-] given dim is bigger than rank. 
{dim} >= {rank}') - - for d in range(rank): - if d != dim: - x = x.max(dim=d, keepdim=True).values - return x - - @torch.no_grad() def reg_noise( network1: nn.Module, network2: nn.Module, num_data: int, lr: float, eta: float = 8e-3, temperature: float = 1e-4 diff --git a/requirements-docs.txt b/requirements-docs.txt index 0be9adb8a..961a0d126 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,12 +1,12 @@ --index-url https://pypi.org/simple --extra-index-url https://download.pytorch.org/whl/cpu numpy<2.0 -torch==2.3.1 -mkdocs==1.6.0 -mkdocs-material==9.5.29 -pymdown-extensions==10.8.1 -mkdocstrings-python==1.10.5 +torch==2.5.1 +mkdocs==1.6.1 +mkdocs-material==9.5.45 +pymdown-extensions==10.12 +mkdocstrings-python==1.12.2 markdown-include==0.8.1 mdx_truly_sane_lists==1.3 -mkdocs-awesome-pages-plugin==2.9.2 -griffe<1.0 \ No newline at end of file +mkdocs-awesome-pages-plugin==2.9.3 +griffe==1.5.1 diff --git a/tests/constants.py b/tests/constants.py index ee9deceb7..8f2abe124 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Tuple, Union -from pytorch_optimizer import ( +from pytorch_optimizer.optimizer import ( ADOPT, ASGD, CAME, diff --git a/tests/test_create_optimizer.py b/tests/test_create_optimizer.py index c5771ed66..bddd0b185 100644 --- a/tests/test_create_optimizer.py +++ b/tests/test_create_optimizer.py @@ -1,6 +1,6 @@ import pytest -from pytorch_optimizer import create_optimizer, load_optimizer +from pytorch_optimizer.optimizer import create_optimizer, load_optimizer from tests.constants import VALID_OPTIMIZER_NAMES from tests.utils import LogisticRegression diff --git a/tests/test_general_optimizer_parameters.py b/tests/test_general_optimizer_parameters.py index f0972e190..bbcc251e0 100644 --- a/tests/test_general_optimizer_parameters.py +++ b/tests/test_general_optimizer_parameters.py @@ -1,7 +1,7 @@ import pytest -from pytorch_optimizer import PCGrad, load_optimizer from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError, ZeroParameterSizeError +from pytorch_optimizer.optimizer import PCGrad, load_optimizer from tests.constants import BETA_OPTIMIZER_NAMES, VALID_OPTIMIZER_NAMES from tests.utils import Example, simple_parameter diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 435e20d08..3bc8202ce 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -1,8 +1,8 @@ import pytest import torch -from pytorch_optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, load_optimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, load_optimizer from tests.constants import NO_SPARSE_OPTIMIZERS, SPARSE_OPTIMIZERS, VALID_OPTIMIZER_NAMES from tests.utils import build_environment, simple_parameter, simple_sparse_parameter, sphere_loss diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index 7bca93d58..f11b3aa3f 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -1,12 +1,8 @@ import pytest -from pytorch_optimizer import ( - get_supported_loss_functions, - get_supported_lr_schedulers, - get_supported_optimizers, - load_lr_scheduler, - load_optimizer, -) +from pytorch_optimizer.loss import get_supported_loss_functions +from pytorch_optimizer.lr_scheduler import get_supported_lr_schedulers, load_lr_scheduler +from pytorch_optimizer.optimizer import get_supported_optimizers, load_optimizer from tests.constants import ( 
INVALID_LR_SCHEDULER_NAMES, INVALID_OPTIMIZER_NAMES, diff --git a/tests/test_loss_functions.py b/tests/test_loss_functions.py index 17a43019b..8597b3285 100644 --- a/tests/test_loss_functions.py +++ b/tests/test_loss_functions.py @@ -1,7 +1,7 @@ import pytest import torch -from pytorch_optimizer import ( +from pytorch_optimizer.loss import ( BCEFocalLoss, BCELoss, BinaryBiTemperedLogisticLoss, diff --git a/tests/test_lr_scheduler_parameters.py b/tests/test_lr_scheduler_parameters.py index 4de8b4252..89d3d93f3 100644 --- a/tests/test_lr_scheduler_parameters.py +++ b/tests/test_lr_scheduler_parameters.py @@ -1,15 +1,16 @@ import numpy as np import pytest -from pytorch_optimizer import AdamP, get_chebyshev_perm_steps from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError +from pytorch_optimizer.lr_scheduler import get_chebyshev_perm_steps from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts from pytorch_optimizer.lr_scheduler.linear_warmup import PolyScheduler +from pytorch_optimizer.optimizer import AdamW from tests.utils import Example def test_cosine_annealing_warmup_restarts_params(): - optimizer = AdamP(Example().parameters()) + optimizer = AdamW(Example().parameters()) with pytest.raises(ValueError) as error_info: CosineAnnealingWarmupRestarts( @@ -37,7 +38,7 @@ def test_cosine_annealing_warmup_restarts_params(): def test_linear_warmup_lr_scheduler_params(): - optimizer = AdamP(Example().parameters()) + optimizer = AdamW(Example().parameters()) with pytest.raises(ValueError) as error_info: PolyScheduler(poly_order=-1, optimizer=optimizer, t_max=1, max_lr=1) diff --git a/tests/test_lr_schedulers.py b/tests/test_lr_schedulers.py index dd705e054..2f6f9c9d5 100644 --- a/tests/test_lr_schedulers.py +++ b/tests/test_lr_schedulers.py @@ -4,14 +4,18 @@ import pytest from torch import nn -from pytorch_optimizer import AdamP, get_chebyshev_perm_steps, get_chebyshev_schedule -from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_permutation +from pytorch_optimizer.lr_scheduler.chebyshev import ( + get_chebyshev_perm_steps, + get_chebyshev_permutation, + get_chebyshev_schedule, +) from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler from pytorch_optimizer.lr_scheduler.rex import REXScheduler from pytorch_optimizer.lr_scheduler.wsd import get_wsd_schedule +from pytorch_optimizer.optimizer import AdamW from tests.utils import Example CAWR_RECIPES = [ @@ -120,7 +124,7 @@ @pytest.mark.parametrize('cosine_annealing_warmup_restart_param', CAWR_RECIPES) def test_cosine_annealing_warmup_restarts(cosine_annealing_warmup_restart_param): model = Example() - optimizer = AdamP(model.parameters()) + optimizer = AdamW(model.parameters()) ( first_cycle_steps, @@ -192,7 +196,7 @@ def test_get_chebyshev_lr(): 0.001335267780289186, ] - optimizer = AdamP(Example().parameters()) + optimizer = AdamW(Example().parameters()) optimizer.step() lr_scheduler = get_chebyshev_schedule(optimizer, num_epochs=16, is_warmup=True) @@ -200,7 +204,7 @@ def test_get_chebyshev_lr(): np.testing.assert_almost_equal(lr_scheduler.get_last_lr(), 1e-3) - optimizer = AdamP(Example().parameters()) + optimizer = AdamW(Example().parameters()) 
optimizer.step() lr_scheduler = get_chebyshev_schedule(optimizer, num_epochs=16, is_warmup=False) @@ -211,7 +215,7 @@ def test_get_chebyshev_lr(): def test_linear_warmup_linear_scheduler(): - optimizer = AdamP(Example().parameters()) + optimizer = AdamW(Example().parameters()) lr_scheduler = LinearScheduler(optimizer, t_max=10, max_lr=1e-2, min_lr=1e-4, init_lr=1e-3, warmup_steps=5) @@ -221,7 +225,7 @@ def test_linear_warmup_linear_scheduler(): def test_linear_warmup_cosine_scheduler(): - optimizer = AdamP(Example().parameters()) + optimizer = AdamW(Example().parameters()) lr_scheduler = CosineScheduler(optimizer, t_max=10, max_lr=1e-2, min_lr=1e-4, init_lr=1e-3, warmup_steps=5) for expected_lr in LWC_RECIPE: @@ -230,7 +234,7 @@ def test_linear_warmup_cosine_scheduler(): def test_linear_warmup_poly_scheduler(): - optimizer = AdamP(Example().parameters()) + optimizer = AdamW(Example().parameters()) lr_scheduler = PolyScheduler(optimizer=optimizer, t_max=10, max_lr=1e-2, min_lr=1e-4, init_lr=1e-3, warmup_steps=5) for expected_lr in LWP_RECIPE: @@ -240,7 +244,7 @@ def test_linear_warmup_poly_scheduler(): @pytest.mark.parametrize('proportion_learning_rate', PROPORTION_LEARNING_RATES) def test_proportion_scheduler(proportion_learning_rate: Tuple[float, float, float]): - base_optimizer = AdamP(Example().parameters()) + base_optimizer = AdamW(Example().parameters()) lr_scheduler = CosineScheduler( base_optimizer, t_max=10, max_lr=proportion_learning_rate[0], min_lr=proportion_learning_rate[1], init_lr=1e-2 ) @@ -258,7 +262,7 @@ def test_proportion_scheduler(proportion_learning_rate: Tuple[float, float, floa def test_proportion_no_last_lr_scheduler(): - base_optimizer = AdamP(Example().parameters()) + base_optimizer = AdamW(Example().parameters()) lr_scheduler = CosineAnnealingWarmupRestarts( base_optimizer, first_cycle_steps=10, @@ -287,7 +291,7 @@ def test_rex_lr_scheduler(): 0.0, ] - base_optimizer = AdamP(Example().parameters()) + base_optimizer = AdamW(Example().parameters()) lr_scheduler = REXScheduler( base_optimizer, @@ -302,7 +306,7 @@ def test_rex_lr_scheduler(): def test_wsd_lr_scheduler(): - optimizer = AdamP(Example().parameters()) + optimizer = AdamW(Example().parameters()) optimizer.step() lr_scheduler = get_wsd_schedule(optimizer, 2, 2, 3, min_lr_ratio=0.1) diff --git a/tests/test_optimizer_parameters.py b/tests/test_optimizer_parameters.py index c83e11cbe..e7610e9eb 100644 --- a/tests/test_optimizer_parameters.py +++ b/tests/test_optimizer_parameters.py @@ -2,16 +2,8 @@ import torch from torch import nn -from pytorch_optimizer import ( - SAM, - WSAM, - GaLoreProjector, - Lookahead, - PCGrad, - Ranger21, - SafeFP16Optimizer, - load_optimizer, -) +from pytorch_optimizer.optimizer import SAM, WSAM, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer +from pytorch_optimizer.optimizer.galore import GaLoreProjector from tests.constants import PULLBACK_MOMENTUM from tests.utils import Example, simple_parameter, simple_zero_rank_parameter diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 0bced2f20..de1cc612f 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -3,23 +3,21 @@ import torch from torch import nn -from pytorch_optimizer import ( +from pytorch_optimizer.base.exception import NoClosureError, ZeroParameterSizeError +from pytorch_optimizer.lr_scheduler import CosineScheduler, ProportionScheduler +from pytorch_optimizer.optimizer import ( BSAM, GSAM, SAM, TRAC, WSAM, - CosineScheduler, DynamicLossScaler, Lookahead, PCGrad, - 
ProportionScheduler, - gradfilter_ema, - gradfilter_ma, load_optimizer, ) -from pytorch_optimizer.base.exception import NoClosureError, ZeroParameterSizeError -from pytorch_optimizer.optimizer.utils import l2_projection +from pytorch_optimizer.optimizer.alig import l2_projection +from pytorch_optimizer.optimizer.grokfast import gradfilter_ema, gradfilter_ma from tests.constants import ( ADAMD_SUPPORTED_OPTIMIZERS, ADANORM_SUPPORTED_OPTIMIZERS, diff --git a/tests/test_utils.py b/tests/test_utils.py index 7717f1302..8b27f7947 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,26 +5,25 @@ import torch from torch import nn +from pytorch_optimizer.optimizer import get_optimizer_parameters +from pytorch_optimizer.optimizer.nero import neuron_mean, neuron_norm from pytorch_optimizer.optimizer.shampoo_utils import ( BlockPartitioner, PreConditioner, compute_power_schur_newton, merge_small_dims, ) +from pytorch_optimizer.optimizer.sm3 import reduce_max_except_dim from pytorch_optimizer.optimizer.utils import ( CPUOffloadOptimizer, clip_grad_norm, compare_versions, disable_running_stats, enable_running_stats, - get_optimizer_parameters, has_overflow, is_valid_parameters, - neuron_mean, - neuron_norm, normalize_gradient, parse_pytorch_version, - reduce_max_except_dim, reg_noise, to_real, unit_norm, diff --git a/tests/utils.py b/tests/utils.py index bbda34d33..b5538874b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,8 +5,8 @@ from torch import nn from torch.nn import functional as f -from pytorch_optimizer import AdamP, Lookahead from pytorch_optimizer.base.types import LOSS +from pytorch_optimizer.optimizer import AdamW, Lookahead class LogisticRegression(nn.Module): @@ -104,7 +104,7 @@ def dummy_closure() -> LOSS: def build_lookahead(*parameters, **kwargs): - return Lookahead(AdamP(*parameters, **kwargs)) + return Lookahead(AdamW(*parameters, **kwargs)) def ids(v) -> str:
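
After this refactoring, the loaders and `get_supported_*` helpers are imported from their subpackages (`pytorch_optimizer.optimizer`, `pytorch_optimizer.lr_scheduler`, `pytorch_optimizer.loss`) instead of from shared utility modules, as the updated tests above do. The sketch below only exercises import paths and signatures that appear in this changeset; the `nn.Linear` model and the hyper-parameter values are illustrative assumptions, not part of the change.

```python
from torch import nn

from pytorch_optimizer.loss import get_supported_loss_functions
from pytorch_optimizer.lr_scheduler import get_supported_lr_schedulers, load_lr_scheduler
from pytorch_optimizer.optimizer import create_optimizer, get_supported_optimizers, load_optimizer

# toy model, used only for illustration
model = nn.Linear(8, 2)

# discover what is available (fnmatch-style wildcards)
print(get_supported_optimizers('adam*'))
print(get_supported_lr_schedulers('cosine*'))
print(get_supported_loss_functions('*focal*'))

# build an optimizer; `bias`/`LayerNorm` parameters are excluded from weight decay by default
optimizer = create_optimizer(model, 'adamp', lr=1e-3, weight_decay=1e-2, use_lookahead=True)

# or load an optimizer / lr scheduler class directly and instantiate it yourself
optimizer = load_optimizer('adamw')(model.parameters(), lr=1e-3)
lr_scheduler = load_lr_scheduler('cosine_annealing')(optimizer, T_max=100)
```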