Skip to content

Commit

Permalink
fix(nyz): fix ttorch prev_state to device bug (#561)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Dec 21, 2022
1 parent bc08a37 commit c30818a
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion ding/torch_utils/data_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@ def to_device(item: Any, device: str, ignore_keys: list = []) -> Any:
if isinstance(item, torch.nn.Module):
return item.to(device)
elif isinstance(item, ttorch.Tensor):
return item.to(device)
if 'prev_state' in item:
prev_state = to_device(item.prev_state, device)
del item.prev_state
item = item.to(device)
item.prev_state = prev_state
return item
else:
return item.to(device)
elif isinstance(item, torch.Tensor):
return item.to(device)
elif isinstance(item, Sequence):
Expand Down

0 comments on commit c30818a

Please sign in to comment.