add MLU devices for rng state saving and loading. #2940

Merged Jul 31, 2024 (36 commits; showing changes from all commits)

Commits
bc5ccfb  Add Cambricon MLU accelerator support (huismiling, Mar 13, 2024)
3ad38dc  up mlu support for test (huismiling, Mar 13, 2024)
be32c91  fix mlu device MULTI_MLU (huismiling, Mar 13, 2024)
421c142  Update src/accelerate/utils/imports.py (huismiling, Mar 14, 2024)
78cd1cb  up mlu for quality check (huismiling, Mar 14, 2024)
3abd038  fix mlu device longTensor error (huismiling, Mar 15, 2024)
0542987  fix mlu device tensor dtype check (huismiling, Mar 19, 2024)
e024276  fix mlu device send_to_device with torch dynamo error (huismiling, Mar 19, 2024)
a50b0d9  Refactor AcceleratorState (muellerzr, Mar 21, 2024)
ff628be  Should be near complete now (muellerzr, Mar 21, 2024)
a1aac83  Last missing piece (muellerzr, Mar 21, 2024)
31ea8cc  Make my way to the acceleratorstate (muellerzr, Mar 21, 2024)
47b60ca  Include update to global var (muellerzr, Mar 21, 2024)
2082a9a  Don't use global (muellerzr, Mar 21, 2024)
26c484e  gpu -> cuda (muellerzr, Mar 21, 2024)
5ac5d56  Don't use update for dict, easier to read (muellerzr, Mar 21, 2024)
2baa5c3  Fix tests (muellerzr, Mar 21, 2024)
d709f66  stash (muellerzr, Mar 21, 2024)
ac24315  Getting closer... (muellerzr, Mar 21, 2024)
1628898  Needed to spawn at the very end after env was setup (muellerzr, Mar 21, 2024)
6958e1b  Explain set_device before deepspeed (muellerzr, Mar 22, 2024)
2b9d339  Make docstring more accurate (muellerzr, Mar 22, 2024)
194db93  Early return insteaD (muellerzr, Mar 22, 2024)
31201d3  Delineat blocks (muellerzr, Mar 22, 2024)
eef1aa0  Make prepare_backend return state + backend for clarity/less magic (muellerzr, Mar 22, 2024)
37d0edc  Merge branch 'huggingface:main' into main (huismiling, Mar 25, 2024)
0fc1df3  Merge remote-tracking branch 'hf-acc/refactor-state' (huismiling, Mar 25, 2024)
b09003c  Merge branch 'huggingface:main' into main (huismiling, May 8, 2024)
92dc4bc  merge from hf (huismiling, May 8, 2024)
124331a  fix mlu longtensor.to() bugs. (huismiling, May 8, 2024)
36f35e8  Merge branch 'huggingface:main' into main (huismiling, May 20, 2024)
48d2c0c  Merge branch 'huggingface:main' into main (huismiling, May 23, 2024)
900efd0  Merge branch 'huggingface:main' into main (huismiling, May 29, 2024)
b3a1aed  Merge branch 'huggingface:main' into main (huismiling, Jun 25, 2024)
ef86bf2  Merge branch 'huggingface:main' into main (huismiling, Jul 18, 2024)
012f7a3  fix MLU devices rng state save and load. (huismiling, Jul 18, 2024)
5 changes: 5 additions & 0 deletions src/accelerate/checkpointing.py

@@ -32,6 +32,7 @@
     SCHEDULER_NAME,
     WEIGHTS_NAME,
     get_pretty_name,
+    is_mlu_available,
     is_torch_xla_available,
     is_xpu_available,
     save,

@@ -143,6 +144,8 @@ def save_accelerator_state(
    states["torch_manual_seed"] = torch.get_rng_state()
    if is_xpu_available():
        states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
+   if is_mlu_available():
+       states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
    else:
        states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
    if is_torch_xla_available():

@@ -255,6 +258,8 @@ def load_accelerator_state(
    torch.set_rng_state(states["torch_manual_seed"])
    if is_xpu_available():
        torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
+   if is_mlu_available():
+       torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
    else:
        torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
    if is_torch_xla_available():
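Taken together, the two hunks make the MLU generators part of the checkpoint round trip, symmetric with the existing CUDA and XPU handling. Below is a minimal sketch of the user-facing effect, assuming a Cambricon MLU build of PyTorch (the torch_mlu extension, which provides the torch.mlu namespace); the "ckpt" directory name is illustrative:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(8, 8))

# save_state() calls save_accelerator_state(), which now pickles
# states["torch_mlu_manual_seed"] alongside the CPU state into
# random_states_<process_index>.pkl in the checkpoint directory.
accelerator.save_state("ckpt")

# load_state() calls load_accelerator_state(), which feeds that entry
# back through torch.mlu.set_rng_state_all(), so sampling done on the
# MLU (dropout masks, shuffles, torch.randn) resumes deterministically.
accelerator.load_state("ckpt")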
2 changes: 2 additions & 0 deletions src/accelerate/commands/env.py

@@ -81,6 +81,8 @@ def env_command(args):
    }
    if pt_cuda_available:
        info["GPU type"] = torch.cuda.get_device_name()
+   if pt_mlu_available:
+       info["MLU type"] = torch.mlu.get_device_name()
    if pt_npu_available:
        info["CANN version"] = torch.version.cann
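Here pt_mlu_available gates the lookup the same way pt_cuda_available does for GPUs (upstream it is presumably the cached result of is_mlu_available()). A small sketch of the same guarded query outside the CLI:

import torch
from accelerate.utils import is_mlu_available

# Only touch torch.mlu when an MLU build of PyTorch is actually present;
# on a stock build the torch.mlu attribute does not exist at all.
if is_mlu_available():
    print("MLU type:", torch.mlu.get_device_name())
else:
    print("No MLU detected")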
4 changes: 3 additions & 1 deletion src/accelerate/utils/dataclasses.py

@@ -32,7 +32,7 @@
from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
from .environment import str_to_bool
-from .imports import is_cuda_available, is_npu_available, is_xpu_available
+from .imports import is_cuda_available, is_mlu_available, is_npu_available, is_xpu_available
from .versions import compare_versions

@@ -1341,6 +1341,8 @@ def __post_init__(self):
        if self.sync_module_states:
            if is_npu_available():
                device = torch.npu.current_device()
+           elif is_mlu_available():
+               device = torch.mlu.current_device()
            elif is_cuda_available():
                device = torch.cuda.current_device()
            elif is_xpu_available():
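In __post_init__, sync_module_states needs a concrete device to broadcast module states from, and the new elif slots MLU into the existing priority order ahead of CUDA and XPU. A standalone sketch of that cascade, assuming accelerate's availability helpers and vendor PyTorch extensions (torch_npu, torch_mlu) that register the npu/mlu device types; the helper name and error message are illustrative:

import torch
from accelerate.utils import is_cuda_available, is_mlu_available, is_npu_available, is_xpu_available

def current_accelerator_device() -> torch.device:
    # Order matters: the first backend that reports available wins, so an
    # MLU machine resolves to "mlu" before the CUDA/XPU branches are tried.
    if is_npu_available():
        return torch.device("npu", torch.npu.current_device())
    if is_mlu_available():
        return torch.device("mlu", torch.mlu.current_device())
    if is_cuda_available():
        return torch.device("cuda", torch.cuda.current_device())
    if is_xpu_available():
        return torch.device("xpu", torch.xpu.current_device())
    raise RuntimeError("sync_module_states needs an accelerator (NPU, MLU, CUDA, or XPU)")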