From 78cfaed4f030f323105fc5e713f00bcc74df4e40 Mon Sep 17 00:00:00 2001 From: Luciferian Ink Date: Fri, 7 Jun 2024 04:58:12 -0500 Subject: [PATCH] add if/else check to handle differences between torch versions --- hivemind/optim/grad_scaler.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/hivemind/optim/grad_scaler.py b/hivemind/optim/grad_scaler.py index 2ce090136..5f8cdaf90 100644 --- a/hivemind/optim/grad_scaler.py +++ b/hivemind/optim/grad_scaler.py @@ -4,8 +4,17 @@ from typing import Dict, Optional import torch -from torch.amp import GradScaler as TorchGradScaler -from torch.amp.grad_scaler import OptState, _refresh_per_optimizer_state +from packaging import version + +torch_version = torch.__version__.split('+')[0] + +if version.parse(torch_version) >= version.parse("2.3.0"): + from torch.amp import GradScaler as TorchGradScaler + from torch.amp.grad_scaler import OptState, _refresh_per_optimizer_state +else: + from torch.cuda.amp import GradScaler as TorchGradScaler + from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state + from torch.optim import Optimizer as TorchOptimizer import hivemind