Skip to content

Commit

Permalink
Fix lint issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed Jan 8, 2024
1 parent e436719 commit 3d0d771
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions torch_xla/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,39 @@
from typing import Iterable, List, Tuple, Union


def get_device_states(*args) -> Tuple[List[torch.device], List[Union[torch.Tensor, int]]]:
fwd_device = list(
{
def get_device_states(
*args) -> Tuple[List[torch.device], List[Union[torch.Tensor, int]]]:
fwd_device = list({
arg.device
for arg in args
if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
}
)
})

fwd_device_states = []

device_type = _infer_device_type(*args)
assert device_type is not None, "multiple non-CPU devices not supported"

device_module = xm if device_type == "xla" else _get_device_module(device_type)
device_module = xm if device_type == "xla" else _get_device_module(
device_type)

for device in fwd_device:
fwd_device_states.append(device_module.get_rng_state(device))

return fwd_device, fwd_device_states


def set_device_states(devices: List[torch.device], states: List[Union[torch.Tensor, int]]) -> None:
def set_device_states(devices: List[torch.device],
states: List[Union[torch.Tensor, int]]) -> None:
if len(states) == 0:
return

# get_device_states guarantees that there's only one device-type.
# Therefore, their states should also be of either Tensor (cuda) or int (xla) type.

state_0_type = type(states[0])
assert all(isinstance(v, state_0_type) for v in states), f"all device states should have the same type"
assert all(isinstance(v, state_0_type)
for v in states), f"all device states should have the same type"

device_module = xm if state_0_type == int else get_device_module(*states)
for device, state in zip(devices, states):
Expand Down

0 comments on commit 3d0d771

Please sign in to comment.