Update dedup logic to use dedup_save_plans #7929

Open · wants to merge 5 commits into master
Changes from 3 commits
32 changes: 0 additions & 32 deletions torch_xla/experimental/distributed_checkpoint/_helpers.py
@@ -151,38 +151,6 @@ def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
  return flattened, mappings


# TODO(jonbolin): Take a dependency on the upstream implementation when the APIs
# are stable.
# https://github.com/pytorch/pytorch/blob/d1cecd9c32ba700c27f2b0716bf2cbef41469495/torch/distributed/checkpoint/_dedup_tensors.py#L29
def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
  all_plans = list(all_plans)
  key_to_plan: Dict[MetadataIndex, List[int]] = {}
  for plan_idx, plan in enumerate(all_plans):
    for write_item in plan.items:
      key_to_plan.setdefault(write_item.index, []).append(plan_idx)

  replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1}

  # Remove duplicates by always keeping the first entry.
  # Compute the per-rank remove set.
  plan_to_keys: Dict[int, List[MetadataIndex]] = {}
  for key, plans in replicated_items.items():
    for plan_idx in plans[1:]:
      plan_to_keys.setdefault(plan_idx, []).append(key)

  for plan_idx, keys in plan_to_keys.items():
    key_set = set(keys)
    # rewrite items and remove elements
    new_items = [
        write_item for write_item in all_plans[plan_idx].items
        if write_item.index not in key_set
    ]
    all_plans[plan_idx] = dataclasses.replace(
        all_plans[plan_idx], items=new_items)

  return all_plans


# TODO(jonbolin): Take a dependency on the upstream implementation when the APIs
# are stable
# https://github.com/pytorch/pytorch/blob/d1cecd9c32ba700c27f2b0716bf2cbef41469495/torch/distributed/_shard/_utils.py#L7
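A note on the behavioral difference: the deleted dedup_tensors always keeps the first plan that lists a duplicated write item, whereas the upstream dedup_save_plans adopted below (as I understand the current PyTorch implementation) assigns each duplicated item to whichever candidate plan has accumulated the least write volume, spreading storage work across ranks. Below is a minimal sketch of that balancing idea over a toy {plan_idx: {item_name: size_in_bytes}} dict; the structure and names are hypothetical simplifications, not the real SavePlan/WriteItem API.

from collections import defaultdict
from typing import Dict


def dedup_balance_by_size(
    plans: Dict[int, Dict[str, int]]) -> Dict[int, Dict[str, int]]:
  """Toy illustration: give each duplicated item to the plan that has
  accumulated the least data so far (roughly what dedup_save_plans does)."""
  # Record which plans contain each item.
  item_to_plans = defaultdict(list)
  for plan_idx, items in plans.items():
    for name in items:
      item_to_plans[name].append(plan_idx)

  plan_size = {idx: 0 for idx in plans}
  deduped = {idx: {} for idx in plans}
  for name, candidates in item_to_plans.items():
    # Pick the candidate plan with the smallest accumulated write size.
    winner = min(candidates, key=lambda idx: plan_size[idx])
    size = plans[winner][name]
    plan_size[winner] += size
    deduped[winner][name] = size
  return deduped


# Example: "weight" is replicated on both plans but is written only once.
plans = {0: {"weight": 100, "bias0": 10}, 1: {"weight": 100, "bias1": 5}}
print(dedup_balance_by_size(plans))
# {0: {'weight': 100, 'bias0': 10}, 1: {'bias1': 5}}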
7 changes: 4 additions & 3 deletions torch_xla/experimental/distributed_checkpoint/planners.py
@@ -36,9 +36,10 @@
from torch.utils._pytree import tree_map
from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
from torch_xla.experimental.distributed_checkpoint._helpers import (
-    FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, _is_sharded_tensor,
-    set_element, narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards)
+    FLATTEN_MAPPING, flatten_state_dict, _is_sharded_tensor, set_element,
+    narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards)
from typing import Any, Dict, List, Tuple, Union
+from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans


class SPMDSavePlanner(SavePlanner):
@@ -107,7 +108,7 @@ def create_local_plan(self) -> SavePlan:

  def create_global_plan(
      self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
    # Deduplicate write items across plans
-    all_plans = dedup_tensors(all_plans)
+    all_plans = dedup_save_plans(all_plans)

    global_plan, metadata = create_default_global_save_plan(
        all_plans, rewrite_index_hints=False)
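Callers are unaffected by the swap: the planner is still consumed through the standard torch.distributed.checkpoint entry points, and deduplication still happens once on the coordinator inside create_global_plan. A minimal usage sketch, assuming a recent PyTorch where dist_cp.save is available, an already-constructed SPMD-sharded model, and a hypothetical CHECKPOINT_DIR:

import torch.distributed.checkpoint as dist_cp
from torch_xla.experimental.distributed_checkpoint import SPMDSavePlanner

CHECKPOINT_DIR = "/tmp/spmd_ckpt"  # hypothetical path
state_dict = {"model": model.state_dict()}  # assumes `model` already exists

dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    # create_global_plan on the coordinator now dedups via dedup_save_plans.
    planner=SPMDSavePlanner(),
)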