Skip to content

Commit

Permalink
use torch.optim.optimizer.ParamsT in training_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurent2916 committed Oct 15, 2024
1 parent 241abfa commit caa1800
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions src/refiners/training_utils/config.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
from enum import Enum
from logging import warning
from pathlib import Path
from typing import Annotated, Any, Callable, Iterable, Literal, Type, TypeVar
from typing import Annotated, Callable, Literal, Type, TypeVar

import tomli
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
from prodigyopt import Prodigy # type: ignore
from pydantic import BaseModel, BeforeValidator, ConfigDict
from torch import Tensor
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
from torch.optim.optimizer import Optimizer
from torch.optim.optimizer import Optimizer, ParamsT
from torch.optim.sgd import SGD

from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field

# PyTorch optimizer parameters type
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
# See https://github.com/pytorch/pytorch/pull/111114
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]


TimeValueField = Annotated[TimeValue, BeforeValidator(parse_number_unit_field)]
IterationOrEpochField = Annotated[Iteration | Epoch, BeforeValidator(parse_number_unit_field)]
StepField = Annotated[Step, BeforeValidator(parse_number_unit_field)]
Expand Down

0 comments on commit caa1800

Please sign in to comment.