Add Ascend NPU Support
MengqingCao committed Oct 26, 2024
1 parent 2501c1a commit 080f4eb
Showing 4 changed files with 78 additions and 15 deletions.
12 changes: 11 additions & 1 deletion src/axolotl/utils/bench.py
@@ -5,6 +5,8 @@
import torch
from pynvml.nvml import NVMLError

from axolotl.utils.distributed import CURRENT_DEVICE


def check_cuda_device(default_value):
"""
@@ -53,6 +55,12 @@ def mps_memory_usage_all():
return usage, reserved - usage, 0


def npu_memory_usage_all(device=0):
usage = torch.npu.memory_allocated(device) / 1024.0**3
reserved = torch.npu.memory_reserved(device) / 1024.0**3
return usage, reserved - usage, 0


@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
@@ -71,6 +79,8 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device):
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
elif "npu" in CURRENT_DEVICE.__str__():
usage, cache, misc = npu_memory_usage_all(device)
else:
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
Expand All @@ -79,6 +89,6 @@ def log_gpu_memory_usage(log, msg, device):
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
f"{CURRENT_DEVICE.__str__()} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
)
return usage, cache, misc
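
For context, here is a minimal standalone sketch of the dispatch pattern these hunks add: query allocated vs. reserved memory on whichever backend is present and report (usage, cache, misc) in GiB. The helper name device_memory_usage_all and the hasattr guard are illustrative assumptions, not part of this commit; torch.npu only exists once the torch_npu plugin is installed.

import torch

def device_memory_usage_all(device=0):
    # Return (usage, cache, misc) in GiB, mirroring the
    # (allocated, reserved - allocated, 0) convention used in bench.py.
    if hasattr(torch, "npu") and torch.npu.is_available():  # needs the torch_npu plugin
        usage = torch.npu.memory_allocated(device) / 1024.0**3
        reserved = torch.npu.memory_reserved(device) / 1024.0**3
    elif torch.cuda.is_available():
        usage = torch.cuda.memory_allocated(device) / 1024.0**3
        reserved = torch.cuda.memory_reserved(device) / 1024.0**3
    elif torch.backends.mps.is_available():
        usage = torch.mps.current_allocated_memory() / 1024.0**3
        reserved = torch.mps.driver_allocated_memory() / 1024.0**3
    else:
        return 0.0, 0.0, 0
    return usage, reserved - usage, 0

print(device_memory_usage_all())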
26 changes: 25 additions & 1 deletion src/axolotl/utils/config/__init__.py
@@ -7,6 +7,7 @@

import torch
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import is_torch_npu_available

from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
@@ -32,7 +33,10 @@ def get_device():
if torch.backends.mps.is_available():
return "mps"

raise SystemError("No CUDA/mps device found")
if is_torch_npu_available():
return f"npu:{cfg.local_rank}"

raise SystemError("No CUDA/mps/npu device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"

@@ -42,6 +46,8 @@ def get_device():
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()}
elif cfg.device.startswith("npu"):
cfg.device_map = {"npu": torch.npu.current_device()}
else:
cfg.device_map = {"": cfg.device}

@@ -94,6 +100,24 @@ def normalize_config(cfg):
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
elif cfg.device.startswith("npu"):
if cfg.load_in_8bit or cfg.load_in_4bit:
LOG.warning("Quantification is currently not supported in Ascend npu, disabling for this configuration.")
cfg.load_in_8bit = False
cfg.load_in_4bit = False

if cfg.tf32:
LOG.warning("tf32 dtype is currently not supported in Ascend npu, disabling for this configuration.")
cfg.tf32 = False

if cfg.flash_attention:
LOG.error("flash_attn is currently not supported in Ascend npu, disabling for this configuration.")
cfg.flash_attention = False

if "bit" in cfg.optimizer:
LOG.error("{} is currently not supported in Ascend npu, choose another one.".format(cfg.optimizer))
raise NotImplementedError

else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
if cfg.bf16:
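A rough standalone sketch of the device-resolution logic this hunk extends. The CUDA branch sits outside the shown hunk and is assumed here; resolve_device and local_rank are illustrative names rather than the exact names in normalize_config.

import torch
from transformers.utils.import_utils import is_torch_npu_available

def resolve_device(local_rank=0):
    # Same priority as get_device() in normalize_config: CUDA, then MPS,
    # then Ascend NPU, with a CPU fallback when nothing is found.
    try:
        if torch.cuda.is_available():
            return f"cuda:{local_rank}"  # assumed; the CUDA branch is outside this hunk
        if torch.backends.mps.is_available():
            return "mps"
        if is_torch_npu_available():
            return f"npu:{local_rank}"
        raise SystemError("No CUDA/mps/npu device found")
    except Exception:  # pylint: disable=broad-exception-caught
        return "cpu"

print(resolve_device())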
34 changes: 26 additions & 8 deletions src/axolotl/utils/distributed.py
@@ -10,9 +10,27 @@
import torch.distributed as dist
from accelerate import PartialState

from transformers.utils.import_utils import (
is_torch_npu_available,
is_torch_cuda_available,
is_torch_mps_available
)

distributed_state = None # pylint: disable=invalid-name


def get_device():
device = torch.device("cpu")
if is_torch_cuda_available():
device = torch.device("cuda")
elif is_torch_mps_available():
device = torch.device("mps")
elif is_torch_npu_available():
device = torch.device("npu")
return device

CURRENT_DEVICE = get_device()

def is_distributed():
"""
Check if distributed training is initialized.
@@ -91,7 +109,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
if not is_distributed():
return [value_scalar]
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
value_scalar, device=CURRENT_DEVICE
).float()

if not is_main_process():
@@ -117,11 +135,11 @@ def broadcast_dict(vals: dict):

if is_main_process():
data_byte = pickle.dumps(vals)
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
data_tensor = torch.ByteTensor(list(data_byte)).to(CURRENT_DEVICE)
data_size = torch.IntTensor([len(data_byte)]).to(CURRENT_DEVICE)
else:
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
data_size = torch.IntTensor([0]).to("cuda")
data_tensor = torch.empty([1024], dtype=torch.uint8, device=CURRENT_DEVICE)
data_size = torch.IntTensor([0]).to(CURRENT_DEVICE)

dist.broadcast(data_size, 0)
if not is_main_process():
@@ -153,11 +171,11 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
value_scalar, device=CURRENT_DEVICE, dtype=torch.float32
)
else:
value_tensor = torch.tensor(
0.0, device=torch.cuda.current_device(), dtype=torch.float32
0.0, device=CURRENT_DEVICE, dtype=torch.float32
) # Placeholder tensor

# Broadcast the tensor to all processes.
@@ -184,7 +202,7 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
value_scalar, device=CURRENT_DEVICE
).float()

# Placeholder tensor for gathering results
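A self-contained sketch of how the new module-level CURRENT_DEVICE replaces the hard-coded "cuda" strings in these distributed helpers. The tensor values below are placeholders for illustration only.

import torch
from transformers.utils.import_utils import (
    is_torch_cuda_available,
    is_torch_mps_available,
    is_torch_npu_available,
)

def get_device():
    # Same priority order as the new distributed.get_device():
    # CUDA, then MPS, then NPU, defaulting to CPU.
    if is_torch_cuda_available():
        return torch.device("cuda")
    if is_torch_mps_available():
        return torch.device("mps")
    if is_torch_npu_available():
        return torch.device("npu")
    return torch.device("cpu")

CURRENT_DEVICE = get_device()

# Building scalars and scratch buffers on CURRENT_DEVICE instead of "cuda"
# is what lets the gather/broadcast helpers run unchanged on Ascend NPU.
value_tensor = torch.tensor(3.14, device=CURRENT_DEVICE).float()
data_tensor = torch.empty([1024], dtype=torch.uint8, device=CURRENT_DEVICE)
print(CURRENT_DEVICE, value_tensor.dtype, data_tensor.shape)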
21 changes: 16 additions & 5 deletions src/axolotl/utils/models.py
@@ -52,7 +52,7 @@
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.distributed import zero_only, CURRENT_DEVICE
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -324,6 +324,13 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
return processor


def get_device_count():
if "cuda" in CURRENT_DEVICE.__str__():
return torch.cuda.device_count()
elif "npu" in CURRENT_DEVICE.__str__():
return torch.npu.device_count()
return 1

class ModelLoader:
"""
ModelLoader: managing all the config and monkey patches while loading model
@@ -556,7 +563,8 @@ def set_device_map_config(self) -> None:
)

max_memory = {}
for i in range(torch.cuda.device_count()):
num_device = get_device_count()
for i in range(num_device):
max_memory[i] = gpu_memory_limit
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything

@@ -583,6 +591,8 @@

if torch.backends.mps.is_available():
self.model_kwargs["device_map"] = "mps:0"
elif "npu" in CURRENT_DEVICE.__str__():
self.model_kwargs["device_map"] = "npu:0"

# TODO can we put the reference model on its own GPU? I think we have to move logits around to calculate loss
# if cfg.rl:
@@ -1010,7 +1020,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
self.ajust_model_config()

# log device memory usage
if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps", "npu"):
log_gpu_memory_usage(LOG, "after model load", self.model.device)

# make sure these are fp32 per Ramesh et al. (2021)
@@ -1084,9 +1094,10 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
and not skip_move_to_device
):
# TODO revaldate this conditional
self.model.to(f"cuda:{self.cfg.local_rank}")
print(120*"*", f"{CURRENT_DEVICE.__str__()}:{self.cfg.local_rank}")
self.model.to(f"{CURRENT_DEVICE.__str__()}:{self.cfg.local_rank}")

if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
setattr(self.model, "is_parallelizable", True)
setattr(self.model, "model_parallel", True)

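A rough sketch of how the new get_device_count helper feeds the max_memory map built in set_device_map_config. The gpu_memory_limit value, the str(device) checks, and the hasattr guard are illustrative assumptions, not the exact code in ModelLoader.

import torch

def get_device_count(device):
    # Count accelerators for the active backend so the max_memory map and the
    # model-parallel check work on CUDA and Ascend NPU alike; 1 otherwise.
    if "cuda" in str(device):
        return torch.cuda.device_count()
    if "npu" in str(device) and hasattr(torch, "npu"):  # needs torch_npu
        return torch.npu.device_count()
    return 1

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
gpu_memory_limit = "16GiB"  # illustrative per-device cap

# One entry per visible device plus a generous CPU allowance for offloading,
# matching the shape of the map built in set_device_map_config.
max_memory = {i: gpu_memory_limit for i in range(get_device_count(device))}
max_memory["cpu"] = "256GiB"
print(max_memory)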
