py1e randomized (#442)
* py1e randomized

snarayan21 authored Sep 21, 2023
1 parent 5bfdbc8 commit d5ff35f
Showing 3 changed files with 13 additions and 8 deletions.
4 changes: 2 additions & 2 deletions docs/source/fundamentals/shuffling.md
@@ -50,9 +50,9 @@ In order to improve shuffle quality, this algorithm requires more shards to be d

### py1e

-Globally shuffle shards, divide that sample space over canonical nodes, shuffle the samples in each shard, then randomly distribute the samples from each shard over an expanded range (given by `shuffle_block_size`). So named because it shuffles samples by extending the range of each shard, in python.
+Globally shuffle shards, divide that sample space over canonical nodes, shuffle the samples in each shard, then randomly distribute the samples from each shard over an expanded range (determined using `shuffle_block_size`). So named because it shuffles samples by extending the range of each shard, in python.

-Shuffle block size should be set larger or much larger than the number of samples in a single shard. This algorithm provides guaranteed bounds on the range that samples from a shard can appear, allowing for a lower cache limit without decreasing throughput compared to py1b.
+Shuffle block size should be set larger or much larger than the number of samples in a single shard. This algorithm provides bounds on the range that samples from a shard can appear, allowing for a lower cache limit without decreasing throughput compared to py1b.

This algorithm requires more shards to be downloaded and stay resident to make progress than py1s or py2s, noticed as longer start/resume latency, as a multiple of shuffle block size divided by samples per shard, similar to py1b. However, these shards will be downloaded in a more balanced fashion, reducing network bandwidth bottlenecks.
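
The range-expansion idea described above can be illustrated with a toy sketch (this is not the library's implementation; the shard sizes, `block_size`, and variable names here are hypothetical, and the sketch ignores canonical-node partitioning):

```python
import numpy as np

rng = np.random.default_rng(1234)
shard_sizes = np.array([4, 4, 4])  # hypothetical tiny shards
block_size = 8                     # shuffle_block_size >> samples per shard
num_samples = int(shard_sizes.sum())

# Each sample starts at its in-order position, then is jittered within an
# expanded window of width ~block_size centered on its shard's span.
positions = np.arange(num_samples, dtype=np.float64)
offset = 0
for size in shard_sizes:
    cutoff = (block_size - size) / 2
    lower = max(-cutoff, -offset)                     # don't cross the start
    upper = min(cutoff, num_samples - offset - size)  # don't cross the end
    positions[offset:offset + size] += rng.uniform(lower, upper, size)
    offset += size

# Sorting the jittered positions yields the shuffled sample order; every
# sample stays within a bounded distance of its original shard span.
order = np.argsort(positions)
```

Because each sample can drift at most `cutoff` positions from its span, only a bounded set of shards must be resident at once, which is what permits the lower cache limit.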

2 changes: 1 addition & 1 deletion setup.py
@@ -44,7 +44,7 @@
 install_requires = [
     'boto3>=1.21.45,<2',
     'Brotli>=1.0.9',
-    'google-cloud-storage>=2.9.0',
+    'google-cloud-storage>=2.9.0,<2.11.0',
     'matplotlib>=3.5.2,<4',
     'paramiko>=2.11.0,<4',
     'python-snappy>=0.6.1,<1',
15 changes: 10 additions & 5 deletions streaming/base/shuffle/py1e.py
@@ -4,8 +4,8 @@
 """Shuffling algorithm that shuffles by randomly placing shard samples in expanded ranges.
 This algorithm has more balanced downloading and a lower minimum cache limit than ``py1b`` and
-``py1br``, but also slightly lower shuffle quality. The maximum range the samples from each shard
-can cover is determined by ``shuffle_block_size``.
+``py1br``, but also slightly lower shuffle quality. The range the samples from each shard can cover
+is determined by ``shuffle_block_size``.
 """

import numpy as np
@@ -85,9 +85,14 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64],
     sample_positions = np.arange(num_cn_samples).astype(np.float64)
     for span_size in cn_span_sizes:

-        # The maximum range on each side of the span is (block_size - span_size) / 2.
-        # This ensures that the span samples are only found in a range of maximum block_size.
-        cutoff = (block_size - span_size) / 2
+        # Sample the block size uniformly in a fixed range centered around the block_size.
+        # This helps to ensure that when training across a large number of nodes, downloads
+        # are more balanced.
+        rand_block_size = epoch_rng.integers(int(0.75 * block_size), int(1.25 * block_size))
+
+        # The maximum range on each side of the span is (rand_block_size - span_size) / 2.
+        # This ensures that the span samples are only found in a max range of rand_block_size.
+        cutoff = (rand_block_size - span_size) / 2

         # Make sure the lower bound of the range doesn't cross the start of the canonical node.
         lower_bound = max(-cutoff, -cn_sample_offset)
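The core change in this hunk, drawing a per-span block size instead of using a fixed one, can be exercised in isolation (a hedged sketch of the randomization step only; the `block_size` and `span_size` values are hypothetical):

```python
import numpy as np

epoch_rng = np.random.default_rng(42)
block_size = 1024  # hypothetical shuffle_block_size
span_size = 128    # hypothetical number of samples in one shard span

# Draw a block size uniformly from [0.75 * block_size, 1.25 * block_size),
# as in the commit; per-span variation staggers shard download boundaries
# across nodes rather than aligning them all on the same fixed block edges.
rand_block_size = int(epoch_rng.integers(int(0.75 * block_size), int(1.25 * block_size)))

# Samples in this span may land up to `cutoff` positions on either side of it,
# so the span's samples still occupy a range of at most rand_block_size.
cutoff = (rand_block_size - span_size) / 2
```

Note that `numpy.random.Generator.integers` uses a half-open interval by default, so `rand_block_size` stays strictly below `int(1.25 * block_size)`.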
