diff --git a/torch_xla/utils/checkpoint.py b/torch_xla/utils/checkpoint.py index 4f16ba18eb2..66922f1fff1 100644 --- a/torch_xla/utils/checkpoint.py +++ b/torch_xla/utils/checkpoint.py @@ -8,21 +8,21 @@ from typing import Iterable, List, Tuple, Union -def get_device_states(*args) -> Tuple[List[torch.device], List[Union[torch.Tensor, int]]]: - fwd_device = list( - { +def get_device_states( + *args) -> Tuple[List[torch.device], List[Union[torch.Tensor, int]]]: + fwd_device = list({ arg.device for arg in args if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu" - } - ) + }) fwd_device_states = [] device_type = _infer_device_type(*args) assert device_type is not None, "multiple non-CPU devices not supported" - device_module = xm if device_type == "xla" else _get_device_module(device_type) + device_module = xm if device_type == "xla" else _get_device_module( + device_type) for device in fwd_device: fwd_device_states.append(device_module.get_rng_state(device)) @@ -30,7 +30,8 @@ def get_device_states(*args) -> Tuple[List[torch.device], List[Union[torch.Tenso return fwd_device, fwd_device_states -def set_device_states(devices: List[torch.device], states: List[Union[torch.Tensor, int]]) -> None: +def set_device_states(devices: List[torch.device], + states: List[Union[torch.Tensor, int]]) -> None: if len(states) == 0: return @@ -38,7 +39,8 @@ def set_device_states(devices: List[torch.device], states: List[Union[torch.Tens # Therefore, their states should also be of either Tensor (cuda) or int (xla) type. state_0_type = type(states[0]) - assert all(isinstance(v, state_0_type) for v in states), f"all device states should have the same type" + assert all(isinstance(v, state_0_type) + for v in states), f"all device states should have the same type" device_module = xm if state_0_type == int else get_device_module(*states) for device, state in zip(devices, states):