Small fixes for balanced device maps (#583)
sgugger authored Jul 28, 2022
1 parent 7f5c60c commit d5a0fc2
Showing 2 changed files with 10 additions and 11 deletions.
13 changes: 4 additions & 9 deletions src/accelerate/utils/modeling.py
@@ -405,17 +405,12 @@ def get_balanced_memory(
     leaves = [n for n in module_sizes if len([p for p in module_sizes if p.startswith(n) and len(p) > len(n)]) == 0]
     mean_leaves = int(sum([module_sizes[n] for n in leaves]) / len(leaves))
     buffer = int(1.25 * max(buffer, mean_leaves))
-    if low_zero:
-        per_gpu += buffer
-        gpu_zero = 0
-    else:
-        gpu_zero = per_gpu
-        per_gpu += buffer
+    per_gpu += buffer

     max_memory = get_max_memory(max_memory)
-    for i in range(num_devices):
-        # We still leave slightly more space on GPU 0 and only apply the buffer on the other devices.
-        max_memory[i] = min(gpu_zero if i == 0 else per_gpu, max_memory[i])
+    # The last device is left with max_memory just in case the buffer is not enough.
+    for i in range(num_devices - 1):
+        max_memory[i] = min(0 if low_zero and i == 0 else per_gpu, max_memory[i])

     if low_zero:
         min_zero = max(0, module_sizes[""] - sum([max_memory[i] for i in range(1, num_devices)]))
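
For context, here is a minimal standalone sketch of the balancing arithmetic as it reads after this change, using plain dicts instead of a real model. The leaf sizes mirror the ModelForTest comment in the test file below (linear1 64, batchnorm 72, linear2 100); sketch_balanced_memory and leaf_sizes are illustrative names rather than accelerate internals, and the per_gpu and buffer formulas are paraphrased from the surrounding lines of get_balanced_memory, not copied verbatim.

def sketch_balanced_memory(leaf_sizes, max_memory, low_zero=False):
    # leaf_sizes: size of each leaf module; max_memory: per-device budget.
    total = sum(leaf_sizes.values())
    num_devices = len(max_memory)
    per_gpu = total // (num_devices - 1 if low_zero else num_devices)

    # Buffer of 1.25x the mean leaf size (the no_split_module_classes branch
    # of the real function is omitted in this sketch).
    per_gpu += int(1.25 * int(total / len(leaf_sizes)))

    # New behavior: cap every device except the last, which keeps its full
    # budget in case the buffer estimate is too small.
    balanced = dict(max_memory)
    for i in range(num_devices - 1):
        balanced[i] = min(0 if low_zero and i == 0 else per_gpu, balanced[i])
    return balanced

print(sketch_balanced_memory({"linear1": 64, "batchnorm": 72, "linear2": 100}, {0: 300, 1: 300}))
# -> {0: 215, 1: 300}, matching the updated expectation in the test below
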
8 changes: 6 additions & 2 deletions tests/test_modeling_utils.py
@@ -365,7 +365,11 @@ def test_get_balanced_memory(self):
         model = ModelForTest()
         # model has size 236: linear1 64, batchnorm 72, linear2 100
         max_memory = get_balanced_memory(model, max_memory={0: 200, 1: 200})
-        self.assertDictEqual({0: 118, 1: 200}, max_memory)
+        self.assertDictEqual({0: 200, 1: 200}, max_memory)

         max_memory = get_balanced_memory(model, max_memory={0: 300, 1: 300})
-        self.assertDictEqual({0: 118, 1: 215}, max_memory)
+        self.assertDictEqual({0: 215, 1: 300}, max_memory)
+
+        # Last device always gets max memory to give more buffer and avoid accidental CPU offload
+        max_memory = get_balanced_memory(model, max_memory={0: 300, 1: 500})
+        self.assertDictEqual({0: 215, 1: 500}, max_memory)
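
As a quick hand check of the updated expectations, assuming the buffer arithmetic sketched after the modeling.py hunk above (the 236 total and the three leaf sizes come from the test's own comment):

per_gpu = 236 // 2                 # 118: model size split across the 2 GPUs
buffer = int(1.25 * int(236 / 3))  # 97: 1.25x the mean of the 3 leaf sizes
cap = per_gpu + buffer             # 215: applied to every device but the last

assert {0: min(cap, 200), 1: 200} == {0: 200, 1: 200}
assert {0: min(cap, 300), 1: 300} == {0: 215, 1: 300}
assert {0: min(cap, 300), 1: 500} == {0: 215, 1: 500}
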
