From caa18003abaebecd1c7cafeddb8eaa30341cfec7 Mon Sep 17 00:00:00 2001 From: Laurent Date: Tue, 15 Oct 2024 13:38:31 +0000 Subject: [PATCH] use torch.optim.optimizer.ParamsT in training_utils --- src/refiners/training_utils/config.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 57921e2ea..3df5a1c0c 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -1,27 +1,20 @@ from enum import Enum from logging import warning from pathlib import Path -from typing import Annotated, Any, Callable, Iterable, Literal, Type, TypeVar +from typing import Annotated, Callable, Literal, Type, TypeVar import tomli from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore from prodigyopt import Prodigy # type: ignore from pydantic import BaseModel, BeforeValidator, ConfigDict -from torch import Tensor from torch.optim.adam import Adam from torch.optim.adamw import AdamW -from torch.optim.optimizer import Optimizer +from torch.optim.optimizer import Optimizer, ParamsT from torch.optim.sgd import SGD from refiners.training_utils.clock import ClockConfig from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field -# PyTorch optimizer parameters type -# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced -# See https://github.com/pytorch/pytorch/pull/111114 -ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]] - - TimeValueField = Annotated[TimeValue, BeforeValidator(parse_number_unit_field)] IterationOrEpochField = Annotated[Iteration | Epoch, BeforeValidator(parse_number_unit_field)] StepField = Annotated[Step, BeforeValidator(parse_number_unit_field)]