diff --git a/Dockerfile b/Dockerfile index c556c1a95..fdd78662b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/cuda:10.2-cudnn8-devel-ubuntu18.04 +FROM nvcr.io/nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04 LABEL maintainer="Learning@home" LABEL repository="hivemind" @@ -21,7 +21,7 @@ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh - bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh ENV PATH="/opt/conda/bin:${PATH}" -RUN conda install python~=3.8 pip && \ +RUN conda install python~=3.11.0 pip && \ pip install --no-cache-dir torch torchvision torchaudio && \ conda clean --all diff --git a/hivemind/optim/grad_scaler.py b/hivemind/optim/grad_scaler.py index 704f859c9..38bcbb740 100644 --- a/hivemind/optim/grad_scaler.py +++ b/hivemind/optim/grad_scaler.py @@ -4,8 +4,17 @@ from typing import Dict, Optional import torch -from torch.cuda.amp import GradScaler as TorchGradScaler -from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state +from packaging import version + +torch_version = torch.__version__.split("+")[0] + +if version.parse(torch_version) >= version.parse("2.3.0"): + from torch.amp import GradScaler as TorchGradScaler + from torch.amp.grad_scaler import OptState, _refresh_per_optimizer_state +else: + from torch.cuda.amp import GradScaler as TorchGradScaler + from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state + from torch.optim import Optimizer as TorchOptimizer import hivemind