Commit a412c54
Merge pull request #19 from aljoscha/hacking-job-server
Fix graph traversal and Create()
axelmagn authored Mar 19, 2018
2 parents ccf9614 + 6b248ab commit a412c54
Showing 4 changed files with 84 additions and 77 deletions.
@@ -32,6 +32,8 @@
 import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
 import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
 
+import static com.google.common.base.Preconditions.checkState;
+
 /** A {@link Pipeline} which has been separated into collections of executable components. */
 @AutoValue
 public abstract class FusedPipeline {
@@ -78,7 +80,15 @@ public RunnerApi.Pipeline toPipeline(Components initialComponents) {
   private Map<String, PTransform> getTopLevelTransforms(Components base) {
     Map<String, PTransform> topLevelTransforms = new HashMap<>();
     for (PTransformNode runnerExecuted : getRunnerExecutedTransforms()) {
-      topLevelTransforms.put(runnerExecuted.getId(), runnerExecuted.getTransform());
+      PTransform extant = topLevelTransforms.put(
+          runnerExecuted.getId(),
+          runnerExecuted.getTransform());
+      checkState(
+          extant == null,
+          "Transforms with ID %s collided: %s and %s",
+          runnerExecuted.getId(),
+          extant,
+          runnerExecuted.getTransform());
     }
     for (ExecutableStage stage : getFusedStages()) {
       topLevelTransforms.put(
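The new check exploits the fact that Map.put returns the value previously mapped to the key (or null if the key was absent), so a duplicate transform ID now fails fast instead of being silently overwritten. A minimal standalone sketch of the pattern; the class name and sample IDs are illustrative, not from this commit:

import static com.google.common.base.Preconditions.checkState;

import java.util.HashMap;
import java.util.Map;

public class PutCollisionDemo {
  public static void main(String[] args) {
    Map<String, String> transforms = new HashMap<>();

    // put() returns the previous mapping for the key, or null if the key
    // was absent, so a non-null result signals an ID collision.
    String extant = transforms.put("read", "Read.Bounded");
    checkState(extant == null, "Transforms with ID %s collided: %s", "read", extant);

    // Re-using the ID makes the following check throw IllegalStateException
    // with both colliding values in the message, rather than losing one.
    extant = transforms.put("read", "Read.Unbounded");
    checkState(extant == null, "Transforms with ID %s collided: %s", "read", extant);
  }
}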
@@ -24,7 +24,9 @@
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
+import com.google.common.collect.ComparisonChain;
 import com.google.common.collect.Iterables;
+import com.google.common.collect.Ordering;
 import com.google.common.graph.MutableNetwork;
 import com.google.common.graph.Network;
 import com.google.common.graph.NetworkBuilder;
@@ -203,14 +205,18 @@ public Long load(@Nonnull PTransformNode transformNode) {
       });
 
   public Iterable<PTransformNode> getTopologicallyOrderedTransforms() {
+    Comparator<PTransformNode> cmp = (left, right) -> ComparisonChain.start()
+        .compare(nodeWeights.getUnchecked(left), nodeWeights.getUnchecked(right))
+        .compare(left.getId(), right.getId())
+        .result();
     return pipelineNetwork
         .nodes()
         .stream()
         .filter(node -> node instanceof PTransformNode)
         .map(PTransformNode.class::cast)
         .collect(
             Collectors.toCollection(
-                () -> new TreeSet<>(Comparator.comparingLong(nodeWeights::getUnchecked))));
+                () -> new TreeSet<>(cmp)));
   }
 
   /**
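Context for the comparator change: TreeSet treats compare() == 0 as equality, so the old weight-only comparator silently dropped any transform whose weight tied with another's, which is the traversal bug this commit fixes. Tie-breaking on the transform ID keeps every node while preserving the weight order. A self-contained sketch of the failure mode and the fix; the node names and weights are made up:

import com.google.common.collect.ComparisonChain;

import java.util.Comparator;
import java.util.Map;
import java.util.TreeSet;

public class TreeSetTieBreakDemo {
  public static void main(String[] args) {
    // Hypothetical node weights; two distinct nodes share weight 1.
    Map<String, Long> weights = Map.of("parDo1", 1L, "parDo2", 1L, "sink", 2L);

    // Weight-only comparator: TreeSet considers parDo1 and parDo2 equal,
    // so one of them is silently dropped.
    TreeSet<String> broken = new TreeSet<>(Comparator.comparingLong(weights::get));
    broken.addAll(weights.keySet());
    System.out.println(broken.size()); // 2, not 3

    // Tie-breaking on the ID keeps both weight-1 nodes.
    Comparator<String> byWeightThenId = (left, right) -> ComparisonChain.start()
        .compare(weights.get(left), weights.get(right))
        .compare(left, right)
        .result();
    TreeSet<String> fixed = new TreeSet<>(byWeightThenId);
    fixed.addAll(weights.keySet());
    System.out.println(fixed); // [parDo1, parDo2, sink]
  }
}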
76 changes: 1 addition & 75 deletions sdks/python/apache_beam/transforms/core.py
@@ -1713,81 +1713,7 @@ def _create_source_from_iterable(values, coder):

   @staticmethod
   def _create_source(serialized_values, coder):
-    from apache_beam.io import iobase
-
-    class _CreateSource(iobase.BoundedSource):
-      def __init__(self, serialized_values, coder):
-        self._coder = coder
-        self._serialized_values = []
-        self._total_size = 0
-        self._serialized_values = serialized_values
-        self._total_size = sum(map(len, self._serialized_values))
-
-      def read(self, range_tracker):
-        start_position = range_tracker.start_position()
-        current_position = start_position
-
-        def split_points_unclaimed(stop_position):
-          if current_position >= stop_position:
-            return 0
-          return stop_position - current_position - 1
-
-        range_tracker.set_split_points_unclaimed_callback(
-            split_points_unclaimed)
-        element_iter = iter(self._serialized_values[start_position:])
-        for i in range(start_position, range_tracker.stop_position()):
-          if not range_tracker.try_claim(i):
-            return
-          current_position = i
-          yield self._coder.decode(next(element_iter))
-
-      def split(self, desired_bundle_size, start_position=None,
-                stop_position=None):
-        from apache_beam.io import iobase
-
-        if len(self._serialized_values) < 2:
-          yield iobase.SourceBundle(
-              weight=0, source=self, start_position=0,
-              stop_position=len(self._serialized_values))
-        else:
-          if start_position is None:
-            start_position = 0
-          if stop_position is None:
-            stop_position = len(self._serialized_values)
-
-          avg_size_per_value = self._total_size / len(self._serialized_values)
-          num_values_per_split = max(
-              int(desired_bundle_size / avg_size_per_value), 1)
-
-          start = start_position
-          while start < stop_position:
-            end = min(start + num_values_per_split, stop_position)
-            remaining = stop_position - end
-            # Avoid having a too small bundle at the end.
-            if remaining < (num_values_per_split / 4):
-              end = stop_position
-
-            sub_source = Create._create_source(
-                self._serialized_values[start:end], self._coder)
-
-            yield iobase.SourceBundle(weight=(end - start),
-                                      source=sub_source,
-                                      start_position=0,
-                                      stop_position=(end - start))
-
-            start = end
-
-      def get_range_tracker(self, start_position, stop_position):
-        if start_position is None:
-          start_position = 0
-        if stop_position is None:
-          stop_position = len(self._serialized_values)
-
-        from apache_beam import io
-        return io.OffsetRangeTracker(start_position, stop_position)
-
-      def estimate_size(self):
-        return self._total_size
+    from apache_beam.transforms.create_source import _CreateSource
 
     return _CreateSource(serialized_values, coder)
 
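Note that the one-line replacement keeps the import inside _create_source rather than at module scope: the new create_source.py imports Create from apache_beam.transforms.core at the top of the file, so a top-level import of _CreateSource in core.py would create a circular import. Deferring the import until the method runs breaks the cycle, because core is fully initialized by then.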
65 changes: 65 additions & 0 deletions sdks/python/apache_beam/transforms/create_source.py
@@ -0,0 +1,65 @@
+from apache_beam.io import iobase
+from apache_beam.transforms.core import Create
+
+class _CreateSource(iobase.BoundedSource):
+  """Internal source that is used by Create()"""
+
+  def __init__(self, serialized_values, coder):
+    self._coder = coder
+    self._serialized_values = []
+    self._total_size = 0
+    self._serialized_values = serialized_values
+    self._total_size = sum(map(len, self._serialized_values))
+  def read(self, range_tracker):
+    start_position = range_tracker.start_position()
+    current_position = start_position
+    def split_points_unclaimed(stop_position):
+      if current_position >= stop_position:
+        return 0
+      return stop_position - current_position - 1
+    range_tracker.set_split_points_unclaimed_callback(
+        split_points_unclaimed)
+    element_iter = iter(self._serialized_values[start_position:])
+    for i in range(start_position, range_tracker.stop_position()):
+      if not range_tracker.try_claim(i):
+        return
+      current_position = i
+      yield self._coder.decode(next(element_iter))
+  def split(self, desired_bundle_size, start_position=None,
+            stop_position=None):
+    from apache_beam.io import iobase
+    if len(self._serialized_values) < 2:
+      yield iobase.SourceBundle(
+          weight=0, source=self, start_position=0,
+          stop_position=len(self._serialized_values))
+    else:
+      if start_position is None:
+        start_position = 0
+      if stop_position is None:
+        stop_position = len(self._serialized_values)
+      avg_size_per_value = self._total_size / len(self._serialized_values)
+      num_values_per_split = max(
+          int(desired_bundle_size / avg_size_per_value), 1)
+      start = start_position
+      while start < stop_position:
+        end = min(start + num_values_per_split, stop_position)
+        remaining = stop_position - end
+        # Avoid having a too small bundle at the end.
+        if remaining < (num_values_per_split / 4):
+          end = stop_position
+        sub_source = Create._create_source(
+            self._serialized_values[start:end], self._coder)
+        yield iobase.SourceBundle(weight=(end - start),
+                                  source=sub_source,
+                                  start_position=0,
+                                  stop_position=(end - start))
+        start = end
+  def get_range_tracker(self, start_position, stop_position):
+    if start_position is None:
+      start_position = 0
+    if stop_position is None:
+      stop_position = len(self._serialized_values)
+    from apache_beam import io
+    return io.OffsetRangeTracker(start_position, stop_position)
+  def estimate_size(self):
+    return self._total_size
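As a worked example of the split() math (illustrative numbers, not from the commit): with 17 serialized values totaling 1,700 bytes and a desired_bundle_size of 800, avg_size_per_value is 100 and num_values_per_split is max(int(800 / 100), 1) = 8. The loop first yields the bundle [0, 8); the next candidate bundle [8, 16) would leave only 1 value remaining, and since 1 < 8 / 4 the end is extended to the stop position, so the final bundle is [8, 17) rather than an 8-value bundle followed by a tiny [16, 17).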
