[reland][Elastic] Skip store barrier and store get in host assign (#136865)

As the title says, this relands #136579, which was reverted because it broke some OSS CI jobs.

Differential Revision: [D63542918](https://our.internmc.facebook.com/intern/diff/D63542918/)

Pull Request resolved: #136865
Approved by: https://github.com/atalman
fduwjj authored and pytorchmergebot committed Sep 27, 2024
1 parent ef3142d commit f42e88f
Showing 2 changed files with 143 additions and 71 deletions.
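
Before the diffs, a quick orientation: the core of this change is a fast path in `_assign_worker_ranks` that, when every agent runs the same role with the same local world size, derives all ranks arithmetically instead of going through the TCP store. The snippet below is an illustrative sketch only (the function name and structure are mine, not the agent's code); it shows the arithmetic the fast path relies on.

def fast_path_ranks(group_rank, group_world_size, local_world_size):
    # Assumes every agent has the same role and the same local_world_size.
    global_world_size = group_world_size * local_world_size
    base_global_rank = group_rank * local_world_size
    # Tuple layout mirrors the tests below:
    # (local_rank, role_rank, global_rank, world_size, role_world_size);
    # for identical workers role_rank == global_rank and
    # role_world_size == global_world_size.
    return [
        (i, base_global_rank + i, base_global_rank + i,
         global_world_size, global_world_size)
        for i in range(local_world_size)
    ]

# Example: agent 2 of 5 agents, 4 workers per agent ->
# [(0, 8, 8, 20, 20), (1, 9, 9, 20, 20), (2, 10, 10, 20, 20), (3, 11, 11, 20, 20)]
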
101 changes: 79 additions & 22 deletions test/distributed/elastic/agent/server/test/api_test.py
@@ -8,6 +8,8 @@
# LICENSE file in the root directory of this source tree.


import functools
import os
import signal
import unittest
import uuid
@@ -475,6 +477,29 @@ def test_run_unknown_state(self, mock_monitor_workers):
self.assertEqual(1, mock_monitor_workers.call_count)
self.assertEqual(spec.max_restarts, agent._remaining_restarts)

def get_worker_assigned(self, store, role_infos_len, info) -> List[Worker]:
i, role_info = info
spec = self._get_worker_spec(
max_restarts=3,
monitor_interval=0.1,
role=role_info.role,
local_world_size=role_info.local_world_size,
)
agent = TestAgent(spec)
workers = agent._assign_worker_ranks(
store, role_info.rank, role_infos_len, spec
)
return [
(
w.local_rank,
w.role_rank,
w.global_rank,
w.world_size,
w.role_world_size,
)
for w in workers
]

def test_assign_worker_ranks(self):
role_infos = [
_RoleInstanceInfo("parameter_server", 0, 4),
@@ -485,28 +510,7 @@ def test_assign_worker_ranks(self):
]
store = dist.HashStore()

def f(info) -> List[Worker]:
i, role_info = info
spec = self._get_worker_spec(
max_restarts=3,
monitor_interval=0.1,
role=role_info.role,
local_world_size=role_info.local_world_size,
)
agent = TestAgent(spec)
workers = agent._assign_worker_ranks(
store, role_info.rank, len(role_infos), spec
)
return [
(
w.local_rank,
w.role_rank,
w.global_rank,
w.world_size,
w.role_world_size,
)
for w in workers
]
f = functools.partial(self.get_worker_assigned, store, len(role_infos))

with ThreadPool(len(role_infos)) as pool:
out = pool.map(f, enumerate(role_infos))
@@ -542,6 +546,59 @@ def f(info) -> List[Worker]:
],
)

def test_assign_worker_ranks_identical(self):
os.environ["TORCH_SKIP_STORE_BARRIER"] = "1"
role_infos = [
_RoleInstanceInfo("trainer", 0, 4),
_RoleInstanceInfo("trainer", 1, 4),
_RoleInstanceInfo("trainer", 2, 4),
_RoleInstanceInfo("trainer", 3, 4),
_RoleInstanceInfo("trainer", 4, 4),
]
store = dist.HashStore()

f = functools.partial(self.get_worker_assigned, store, len(role_infos))

with ThreadPool(len(role_infos)) as pool:
out = pool.map(f, enumerate(role_infos))

self.assertListEqual(
out,
[
[
(0, 0, 0, 20, 20),
(1, 1, 1, 20, 20),
(2, 2, 2, 20, 20),
(3, 3, 3, 20, 20),
],
[
(0, 4, 4, 20, 20),
(1, 5, 5, 20, 20),
(2, 6, 6, 20, 20),
(3, 7, 7, 20, 20),
],
[
(0, 8, 8, 20, 20),
(1, 9, 9, 20, 20),
(2, 10, 10, 20, 20),
(3, 11, 11, 20, 20),
],
[
(0, 12, 12, 20, 20),
(1, 13, 13, 20, 20),
(2, 14, 14, 20, 20),
(3, 15, 15, 20, 20),
],
[
(0, 16, 16, 20, 20),
(1, 17, 17, 20, 20),
(2, 18, 18, 20, 20),
(3, 19, 19, 20, 20),
],
],
)
os.environ["TORCH_SKIP_STORE_BARRIER"] = "0"

def test_get_event(self):
spec = self._get_worker_spec(max_restarts=1)
agent = TestAgent(spec)
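
As a sanity check on the expected output in the new identical-workers test above: five trainer groups with four local workers each give a global world size of 5 * 4 = 20, and group g contributes global (and role) ranks g*4 through g*4 + 3. A small illustrative snippet (not part of the commit) that rebuilds the expected list from that arithmetic:

# Rebuild the expected test output from the identical-workers rank math.
expected = [
    [(i, g * 4 + i, g * 4 + i, 20, 20) for i in range(4)]
    for g in range(5)
]
# e.g. expected[4] == [(0, 16, 16, 20, 20), (1, 17, 17, 20, 20),
#                      (2, 18, 18, 20, 20), (3, 19, 19, 20, 20)]
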
113 changes: 64 additions & 49 deletions torch/distributed/elastic/agent/server/api.py
@@ -557,7 +557,16 @@ def _assign_worker_ranks(
) -> List[Worker]:
"""Determine proper ranks for worker processes.
The rank assignment is done according to the following algorithm:
Fast Path: used when all workers have the same role and the same local world
size. The global rank is computed as group_rank * local_world_size + local_rank,
and `role_world_size` is the same as `global_world_size`. No TCP store is used
in this case. This path is only enabled when users set the environment variable
`TORCH_ELASTIC_WORKER_IDENTICAL` to 1.
Time complexity: each worker O(1), overall O(1)
Slow Path: used when workers may have different roles or world sizes. We use
the following algorithm:
1. Each agent writes its configuration (group_rank, group_world_size,
   num_workers) to the common store.
Expand All @@ -577,60 +586,66 @@ def _assign_worker_ranks(
Time complexity: each worker O(1), rank0 O(n), overall O(n)
"""

ROLE_INFO_PREFIX = "torchelastic/role_info/"
ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/"

agent_role_info = _RoleInstanceInfo(
spec.role, group_rank, spec.local_world_size
)
store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize())
if os.environ.get("TORCH_ELASTIC_WORKER_IDENTICAL", "0") == "1":
global_world_size = group_world_size * spec.local_world_size
base_global_rank = group_rank * spec.local_world_size
base_role_rank = base_global_rank
role_world_size = global_world_size
else:
ROLE_INFO_PREFIX = "torchelastic/role_info/"
ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/"

# tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations.
if group_rank == 0:
role_infos_bytes = store.multi_get(
[f"torchelastic/role_info/{i}" for i in range(group_world_size)]
agent_role_info = _RoleInstanceInfo(
spec.role, group_rank, spec.local_world_size
)
role_infos = [
_RoleInstanceInfo.deserialize(info_bytes)
for info_bytes in role_infos_bytes
]

role_sizes = defaultdict(lambda: 0)
global_size = 0
for role_info in role_infos:
role_sizes[role_info.role] += role_info.local_world_size
global_size += role_info.local_world_size

base_global_rank = 0
role_ranks = defaultdict(lambda: 0)

keys = []
values = []
for i, role_info in enumerate(role_infos):
keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}")
values.append(
json.dumps(
[
base_global_rank,
global_size,
role_ranks[role_info.role],
role_sizes[role_info.role],
]
)
store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize())

# tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations.
if group_rank == 0:
role_infos_bytes = store.multi_get(
[f"torchelastic/role_info/{i}" for i in range(group_world_size)]
)
role_infos = [
_RoleInstanceInfo.deserialize(info_bytes)
for info_bytes in role_infos_bytes
]

role_sizes = defaultdict(lambda: 0)
global_size = 0
for role_info in role_infos:
role_sizes[role_info.role] += role_info.local_world_size
global_size += role_info.local_world_size

base_global_rank = 0
role_ranks = defaultdict(lambda: 0)

keys = []
values = []
for i, role_info in enumerate(role_infos):
keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}")
values.append(
json.dumps(
[
base_global_rank,
global_size,
role_ranks[role_info.role],
role_sizes[role_info.role],
]
)
)

base_global_rank += role_info.local_world_size
role_ranks[role_info.role] += role_info.local_world_size
base_global_rank += role_info.local_world_size
role_ranks[role_info.role] += role_info.local_world_size

store.multi_set(keys, values)
store.multi_set(keys, values)

# get will block until the data is available in the store.
(
base_global_rank,
global_world_size,
base_role_rank,
role_world_size,
) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}"))
# get will block until the data is available in the store.
(
base_global_rank,
global_world_size,
base_role_rank,
role_world_size,
) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}"))

workers = []
for local_rank in range(spec.local_world_size):
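
For contrast with the fast path, here is a condensed, sequential illustration of the slow path that the else-branch keeps: every agent publishes its role info to the store, rank 0 aggregates role sizes and writes one assignment entry per agent, and every agent then reads back its own entry. A plain dict stands in for the TCP store, and the helper below is a sketch of the algorithm, not the production code:

import json
from collections import defaultdict

def slow_path_assignments(role_infos):
    # role_infos: list of (role, local_world_size), indexed by group_rank.
    store = {}
    # 1. each agent writes its role info under its group rank
    for rank, info in enumerate(role_infos):
        store[f"torchelastic/role_info/{rank}"] = json.dumps(info)
    # 2. "rank 0" reads every entry, totals sizes per role, and writes one
    #    [base_global_rank, global_size, base_role_rank, role_size] per agent
    infos = [json.loads(store[f"torchelastic/role_info/{r}"])
             for r in range(len(role_infos))]
    role_sizes = defaultdict(int)
    global_size = 0
    for role, n in infos:
        role_sizes[role] += n
        global_size += n
    base_global, role_ranks = 0, defaultdict(int)
    for rank, (role, n) in enumerate(infos):
        store[f"torchelastic/assigned_ranks/{rank}"] = json.dumps(
            [base_global, global_size, role_ranks[role], role_sizes[role]]
        )
        base_global += n
        role_ranks[role] += n
    # 3. each agent reads its own assignment back
    return [json.loads(store[f"torchelastic/assigned_ranks/{r}"])
            for r in range(len(role_infos))]

# Example: a 4-worker parameter_server group followed by two 8-worker trainer groups.
# -> [[0, 20, 0, 4], [4, 20, 0, 16], [12, 20, 8, 16]]
print(slow_path_assignments([("parameter_server", 4), ("trainer", 8), ("trainer", 8)]))
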
