Skip to content

Commit

Permalink
add if/else check to handle differences between torch versions
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorrent committed Jun 7, 2024
1 parent 8801689 commit 78cfaed
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions hivemind/optim/grad_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,17 @@
from typing import Dict, Optional

import torch
from torch.amp import GradScaler as TorchGradScaler
from torch.amp.grad_scaler import OptState, _refresh_per_optimizer_state
from packaging import version

torch_version = torch.__version__.split('+')[0]

if version.parse(torch_version) >= version.parse("2.3.0"):
from torch.amp import GradScaler as TorchGradScaler
from torch.amp.grad_scaler import OptState, _refresh_per_optimizer_state
else:
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state

from torch.optim import Optimizer as TorchOptimizer

import hivemind
Expand Down

0 comments on commit 78cfaed

Please sign in to comment.