
Commit 659d60c
Fix incompatibility issue with PyTorch >= 2.2
Hard-coded `params_t`.
HuFY-dev authored Mar 29, 2024
1 parent b6ba6cb commit 659d60c
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion sparse_autoencoder/optimizer/adam_with_reset.py
@@ -8,10 +8,14 @@
 from torch import Tensor
 from torch.nn.parameter import Parameter
 from torch.optim import Adam
-from torch.optim.optimizer import params_t
 
 from sparse_autoencoder.tensor_types import Axis
 
+# params_t was renamed to ParamsT in PyTorch 2.2, which caused import errors
+# Copied from PyTorch 2.1
+from typing import Union, Iterable, Dict, Any
+from typing_extensions import TypeAlias
+params_t: TypeAlias = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
 
 class AdamWithReset(Adam):
     """Adam Optimizer with a reset method.
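The commit pins `params_t` to its pre-2.2 definition. As a minimal sketch of an alternative approach (not part of this commit), the module could instead try the renamed `ParamsT` alias first and fall back, staying compatible with both older and newer PyTorch releases:

```python
# Sketch only (not from the commit): a version-tolerant import shim.
# PyTorch >= 2.2 exports the alias as `ParamsT`; earlier releases expose
# `params_t`. If neither import works, fall back to the hard-coded
# definition used in this commit.
from typing import Any, Dict, Iterable, Union

from torch import Tensor
from typing_extensions import TypeAlias

try:
    from torch.optim.optimizer import ParamsT as params_t  # PyTorch >= 2.2
except ImportError:
    try:
        from torch.optim.optimizer import params_t  # PyTorch < 2.2
    except ImportError:
        params_t: TypeAlias = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
```

Preferring the import keeps the alias in sync with whichever PyTorch version is installed, while the hard-coded fallback preserves the behaviour introduced in this commit.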
