This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Clean-up and fix RTD enum build issue (#1262)
ethanwharris authored Mar 29, 2022
1 parent a59e10a commit 2a09ce0
Showing 3 changed files with 23 additions and 22 deletions.
3 changes: 2 additions & 1 deletion docs/source/general/finetuning.rst
@@ -165,7 +165,8 @@ Under the hood, the pseudocode looks like:
 unfreeze_milestones
 -------------------
-This strategy allows you to unfreeze part of the backbone at predetermined intervals
+
+This strategy allows you to unfreeze part of the backbone at predetermined intervals.
 
 Here's an example where:
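The docs example itself is collapsed in this view, but as a hedged illustration of the strategy formats this commit's validation code accepts (the `ImageClassifier` task and the `datamodule` variable are illustrative stand-ins, not part of the commit):

    import flash
    from flash.image import ImageClassifier

    model = ImageClassifier(backbone="resnet18", num_classes=2)
    trainer = flash.Trainer(max_epochs=20)

    # Strategies that need no metadata are plain strings:
    #   trainer.finetune(model, datamodule=datamodule, strategy="freeze")

    # `freeze_unfreeze` takes a single int, the epoch at which to unfreeze the backbone:
    #   trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 7))

    # `unfreeze_milestones` takes ((first_milestone, second_milestone), num_layers),
    # unfreezing the last `num_layers` backbone layers at the first milestone epoch
    # and the remaining layers at the second:
    trainer.finetune(model, datamodule=datamodule, strategy=("unfreeze_milestones", ((5, 10), 15)))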
29 changes: 15 additions & 14 deletions flash/core/finetuning.py
@@ -11,18 +11,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from functools import partial
 from typing import Iterable, Optional, Tuple, Union
 
 from pytorch_lightning import LightningModule
 from pytorch_lightning.callbacks import BaseFinetuning
-from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.nn import Module
 from torch.optim import Optimizer
 
 from flash.core.registry import FlashRegistry
 
+if not os.environ.get("READTHEDOCS", False):
+    from pytorch_lightning.utilities.enums import LightningEnum
+else:
+    # ReadTheDocs mocks the `LightningEnum` import to be a regular type, so we replace it with a plain Enum here.
+    from enum import Enum
+
+    LightningEnum = Enum
+
 
 class FinetuningStrategies(LightningEnum):
     """The ``FinetuningStrategies`` enum contains the keys that are used internally by the ``FlashBaseFinetuning``
@@ -63,7 +71,8 @@ def __init__(
 
         if self.strategy == FinetuningStrategies.FREEZE_UNFREEZE and not isinstance(self.strategy_metadata, int):
             raise MisconfigurationException(
-                "`freeze_unfreeze` stratgey only accepts one integer denoting the epoch number to switch."
+                "The `freeze_unfreeze` strategy requires an integer denoting the epoch number to unfreeze at. Example: "
+                "`strategy=('freeze_unfreeze', 7)`"
             )
         if self.strategy == FinetuningStrategies.UNFREEZE_MILESTONES and not (
             isinstance(self.strategy_metadata, Tuple)
@@ -73,8 +82,8 @@
             and isinstance(self.strategy_metadata[0][1], int)
         ):
             raise MisconfigurationException(
-                "`unfreeze_milestones` strategy only accepts the format Tuple[Tuple[int, int], int]. HINT example: "
-                "((5, 10), 15)."
+                "The `unfreeze_milestones` strategy requires the format Tuple[Tuple[int, int], int]. Example: "
+                "`strategy=('unfreeze_milestones', ((5, 10), 15))`"
             )

     def _get_modules_to_freeze(self, pl_module: LightningModule) -> Union[Module, Iterable[Union[Module, Iterable]]]:
@@ -158,19 +167,11 @@ def finetune_function(
             self._unfreeze_milestones_function(pl_module, epoch, optimizer, opt_idx, self.strategy_metadata)


-# Used for properly verifying input and providing neat and helpful error messages for users.
-_DEFAULTS_FINETUNE_STRATEGIES = [
-    FinetuningStrategies.NO_FREEZE.value,
-    FinetuningStrategies.FREEZE.value,
-    FinetuningStrategies.FREEZE_UNFREEZE.value,
-    FinetuningStrategies.UNFREEZE_MILESTONES.value,
-]
-
 _FINETUNING_STRATEGIES_REGISTRY = FlashRegistry("finetuning_strategies")
 
-for strategy in _DEFAULTS_FINETUNE_STRATEGIES:
+for strategy in FinetuningStrategies:
     _FINETUNING_STRATEGIES_REGISTRY(
-        name=strategy,
+        name=strategy.value,
         fn=partial(FlashBaseFinetuning, strategy_key=strategy),
     )
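Iterating the enum directly means a new member added to FinetuningStrategies is registered automatically, with no parallel _DEFAULTS_FINETUNE_STRATEGIES list to keep in sync. A sketch of what the registry then contains (the `.get(key=...)` call and the metadata kwargs are taken from flash/core/model.py below; the epoch value 7 is illustrative):

    from flash.core.finetuning import _FINETUNING_STRATEGIES_REGISTRY, FinetuningStrategies

    # Each member's string value is now a registry key:
    fn = _FINETUNING_STRATEGIES_REGISTRY.get(key=FinetuningStrategies.FREEZE_UNFREEZE.value)

    # The stored entry is partial(FlashBaseFinetuning, strategy_key=<member>), so
    # calling it with the strategy metadata yields the configured finetuning callback:
    callback = fn(strategy_metadata=7, train_bn=True)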

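The import guard added at the top of this file is an instance of a general pattern: ReadTheDocs sets the READTHEDOCS environment variable to "True" during builds and mocks heavyweight dependencies, so `LightningEnum` resolves to a plain mocked type that an enum subclass like `FinetuningStrategies` cannot be built from; swapping in the standard-library `Enum` keeps the docs build importable. A generic sketch of the same pattern, with `heavy_library` standing in for any dependency a docs build mocks out:

    import os

    if not os.environ.get("READTHEDOCS", False):
        from heavy_library import BaseThing  # real dependency, present at runtime
    else:
        # Docs build: the real import would resolve to a mock that cannot be
        # subclassed meaningfully, so substitute a minimal stand-in.
        class BaseThing:
            pass

    class Thing(BaseThing):  # imports cleanly in both environments
        ...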
13 changes: 6 additions & 7 deletions flash/core/model.py
@@ -37,7 +37,7 @@
 from flash.core.data.io.output import Output
 from flash.core.data.io.output_transform import OutputTransform
 from flash.core.data.output import BASE_OUTPUTS
-from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, _FINETUNING_STRATEGIES_REGISTRY
+from flash.core.finetuning import _FINETUNING_STRATEGIES_REGISTRY
 from flash.core.hooks import FineTuningHooks
 from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY
 from flash.core.optimizers.schedulers import _SCHEDULERS_REGISTRY
@@ -531,7 +531,7 @@ def configure_finetune_callback(
         if isinstance(strategy, str):
             if strategy not in self.available_finetuning_strategies():
                 raise MisconfigurationException(
-                    f"Please provide a valid strategy from {_DEFAULTS_FINETUNE_STRATEGIES[:2]}."
+                    f"The `strategy` should be one of: {', '.join(self.available_finetuning_strategies())}."
                     " For more details and advanced finetuning options see our docs:"
                     " https://lightning-flash.readthedocs.io/en/stable/general/finetuning.html"
                 )
@@ -540,16 +540,15 @@
         elif isinstance(strategy, Tuple):
             if not isinstance(strategy[0], str) or strategy[0] not in self.available_finetuning_strategies():
                 raise MisconfigurationException(
-                    f"First input of `strategy` in a tuple configuration should be a string within"
-                    f" {_DEFAULTS_FINETUNE_STRATEGIES[3:]}"
+                    f"The first input of `strategy` in a tuple configuration should be one of:"
+                    f" {', '.join(self.available_finetuning_strategies())}."
                 )
             finetuning_strategy_fn: Callable = self.finetuning_strategies.get(key=strategy[0])
             finetuning_strategy_metadata = {"strategy_metadata": strategy[1], "train_bn": train_bn}
         else:
             raise MisconfigurationException(
-                "`strategy` should be a ``pytorch_lightning.callbacks.BaseFinetuning``"
-                f"callback or a str within {list(_DEFAULTS_FINETUNE_STRATEGIES[:3])}"
-                f"or a tuple configuration with {list(_DEFAULTS_FINETUNE_STRATEGIES[3:])}"
+                "The `strategy` should be a ``pytorch_lightning.callbacks.BaseFinetuning`` callback or one of: "
+                f"{', '.join(self.available_finetuning_strategies())}."
             )
 
         return [finetuning_strategy_fn(**finetuning_strategy_metadata)]
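Because every branch now builds its message from available_finetuning_strategies(), the listed options can no longer drift out of sync with the registry, and the old hard-coded slices ([:2], [3:]) that silently showed only part of the list are gone. A quick sketch of the new failure mode (`model` stands in for any Flash task instance; the strategy names in the comment are the presumed enum values):

    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        model.configure_finetune_callback(strategy="thaw")  # not a registered strategy
    except MisconfigurationException as err:
        # The message enumerates every registered strategy, e.g.:
        # "The `strategy` should be one of: no_freeze, freeze, freeze_unfreeze, unfreeze_milestones. ..."
        print(err)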
