diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py
index 185ba0e04c4..aebe6e1c77a 100644
--- a/src/accelerate/checkpointing.py
+++ b/src/accelerate/checkpointing.py
@@ -32,6 +32,7 @@
     SCHEDULER_NAME,
     WEIGHTS_NAME,
     get_pretty_name,
+    is_mlu_available,
     is_torch_xla_available,
     is_xpu_available,
     save,
@@ -143,6 +144,8 @@ def save_accelerator_state(
     states["torch_manual_seed"] = torch.get_rng_state()
     if is_xpu_available():
         states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
+    if is_mlu_available():
+        states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
     else:
         states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
     if is_torch_xla_available():
@@ -255,6 +258,8 @@ def load_accelerator_state(
     torch.set_rng_state(states["torch_manual_seed"])
     if is_xpu_available():
         torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
+    if is_mlu_available():
+        torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
     else:
         torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
     if is_torch_xla_available():
diff --git a/src/accelerate/commands/env.py b/src/accelerate/commands/env.py
index 7078c6c0adc..7dd5995f6b4 100644
--- a/src/accelerate/commands/env.py
+++ b/src/accelerate/commands/env.py
@@ -81,6 +81,8 @@ def env_command(args):
     }
     if pt_cuda_available:
         info["GPU type"] = torch.cuda.get_device_name()
+    if pt_mlu_available:
+        info["MLU type"] = torch.mlu.get_device_name()
     if pt_npu_available:
         info["CANN version"] = torch.version.cann
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index b12fde45bcf..cf41bc76b62 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -32,7 +32,7 @@
 from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
-from .imports import is_cuda_available, is_npu_available, is_xpu_available
+from .imports import is_cuda_available, is_mlu_available, is_npu_available, is_xpu_available
 from .versions import compare_versions
@@ -1341,6 +1341,8 @@ def __post_init__(self):
         if self.sync_module_states:
             if is_npu_available():
                 device = torch.npu.current_device()
+            elif is_mlu_available():
+                device = torch.mlu.current_device()
             elif is_cuda_available():
                 device = torch.cuda.current_device()
             elif is_xpu_available():
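
The new `is_mlu_available` import is the hinge for all four hunks. As a rough sketch of what such a probe typically looks like (hypothetical code, not the helper's actual body, which lives in `src/accelerate/utils/imports.py`): Cambricon support ships as the `torch_mlu` extension package, which registers a `torch.mlu` namespace on import.

```python
import importlib.util


def mlu_probe() -> bool:
    """Hypothetical stand-in for accelerate's is_mlu_available()."""
    # Cambricon MLUs are driven through the torch_mlu extension package.
    if importlib.util.find_spec("torch_mlu") is None:
        return False
    import torch
    import torch_mlu  # noqa: F401  (side effect: registers torch.mlu)

    return hasattr(torch, "mlu") and torch.mlu.is_available()
```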
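
The two checkpointing hunks have to stay symmetric: every per-backend key that `save_accelerator_state` writes (here `torch_mlu_manual_seed`) is read back verbatim by `load_accelerator_state`. A minimal CPU-only sketch of that round-trip contract; the MLU path is identical except that it captures and restores all device generators via `torch.mlu.get_rng_state_all` / `set_rng_state_all`:

```python
import torch

# Capture the generator state, as save_accelerator_state does for
# states["torch_manual_seed"].
state = torch.get_rng_state()
before = torch.rand(3)

# Restore it, as load_accelerator_state does; subsequent draws repeat exactly.
torch.set_rng_state(state)
after = torch.rand(3)
assert torch.equal(before, after)
```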
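
The `env.py` hunk mirrors the existing CUDA branch so that `accelerate env` reports the detected card. The underlying call can be exercised directly on an MLU machine (assuming `torch_mlu` is installed so that `torch.mlu` exists):

```python
import torch  # assumes torch_mlu is installed, which adds torch.mlu

if torch.mlu.is_available():
    # The same call env_command now uses to fill in the "MLU type" row.
    print("MLU type:", torch.mlu.get_device_name())
```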
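
Finally, the `dataclasses.py` hunk slots MLU into the device-selection chain that `FullyShardedDataParallelPlugin.__post_init__` walks when `sync_module_states=True`, where a concrete device is needed to materialize and broadcast rank-0 weights. From the user's side nothing changes beyond the plugin now resolving on MLU hosts:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

# sync_module_states=True makes __post_init__ pick a staging device; with
# this patch the selection order is NPU -> MLU -> CUDA -> XPU.
fsdp_plugin = FullyShardedDataParallelPlugin(sync_module_states=True)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```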