diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
index 3d4afdf938f09..5966c64fd725d 100644
--- a/runners/spark/pom.xml
+++ b/runners/spark/pom.xml
@@ -32,6 +32,9 @@ License.
         <configuration>
          <source>1.7</source>
          <target>1.7</target>
+          <compilerArgs>
+            <arg>-Xlint:all,-serial</arg>
+          </compilerArgs>
        </configuration>
      </plugin>
    </plugins>
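
Note: -Xlint:all,-serial turns on every javac lint category and then switches off only "serial", which flags Serializable classes lacking a serialVersionUID; the Spark functions in this runner must be Serializable, so that one category would be mostly noise. A standalone illustration of the distinction (LintDemo and castList are invented for this note, not part of the patch):

    import java.io.Serializable;
    import java.util.ArrayList;
    import java.util.List;

    public class LintDemo implements Serializable {
      // With -Xlint:all,-serial the missing serialVersionUID above is not reported,
      // but the unchecked cast below still would be, unless explicitly suppressed.
      static <T> List<T> castList(List<?> raw) {
        @SuppressWarnings("unchecked")
        List<T> typed = (List<T>) raw;  // the idiom this patch applies throughout
        return typed;
      }

      public static void main(String[] args) {
        List<String> strings = castList(new ArrayList<String>());
        System.out.println(strings.size());
      }
    }
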
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
index 8cd593979ae87..673237e574e84 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
@@ -64,7 +64,7 @@ public DoFnFunction(
 
   @Override
   public Iterable<O> call(Iterator<I> iter) throws Exception {
-    ProcCtxt<I, O> ctxt = new ProcCtxt<>(mFunction);
+    ProcCtxt ctxt = new ProcCtxt(mFunction);
     //setup
     mFunction.startBundle(ctxt);
     //operation
@@ -77,7 +77,7 @@ public Iterable<O> call(Iterator<I> iter) throws Exception {
     return ctxt.outputs;
   }
 
-  private class ProcCtxt<I, O> extends DoFn<I, O>.ProcessContext {
+  private class ProcCtxt extends DoFn<I, O>.ProcessContext {
 
     private final List<O> outputs = new LinkedList<>();
     private I element;
@@ -93,7 +93,9 @@ public PipelineOptions getPipelineOptions() {
 
     @Override
    public <T> T sideInput(PCollectionView<T, ?> view) {
-      return (T) mSideInputs.get(view.getTagInternal()).getValue();
+      @SuppressWarnings("unchecked")
+      T value = (T) mSideInputs.get(view.getTagInternal()).getValue();
+      return value;
     }
 
     @Override
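
Note: the recurring pattern above -- assigning the unchecked cast to a local variable and annotating the declaration -- exists because @SuppressWarnings cannot be attached to a return statement; hoisting the cast into a local keeps the suppression scoped to a single statement instead of the whole method. A minimal sketch (NarrowSuppression and its map are hypothetical):

    import java.util.HashMap;
    import java.util.Map;

    public class NarrowSuppression {
      private final Map<String, Object> values = new HashMap<>();

      void put(String key, Object value) {
        values.put(key, value);
      }

      // Too broad: hides every unchecked operation in the method body.
      @SuppressWarnings("unchecked")
      <T> T getWide(String key) {
        return (T) values.get(key);
      }

      // The patch's idiom: the suppression covers exactly one audited cast.
      <T> T getNarrow(String key) {
        @SuppressWarnings("unchecked")
        T value = (T) values.get(key);
        return value;
      }

      public static void main(String[] args) {
        NarrowSuppression ns = new NarrowSuppression();
        ns.put("answer", 42);
        Integer answer = ns.<Integer>getNarrow("answer");
        System.out.println(answer);
      }
    }
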
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
index f992cf59a3e8f..4105ab61eb1f3 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
@@ -45,9 +45,9 @@ public class EvaluationContext implements EvaluationResult {
   private final JavaSparkContext jsc;
   private final Pipeline pipeline;
   private final SparkRuntimeContext runtime;
-  private final Map<PValue, JavaRDDLike> rdds = Maps.newHashMap();
+  private final Map<PValue, JavaRDDLike<?, ?>> rdds = Maps.newHashMap();
   private final Set<PValue> multireads = Sets.newHashSet();
-  private final Map<PObject, Object> pobjects = Maps.newHashMap();
+  private final Map<PObject<?>, Object> pobjects = Maps.newHashMap();
 
   public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) {
     this.jsc = jsc;
@@ -68,23 +68,27 @@ SparkRuntimeContext getRuntimeContext() {
   }
 
   <I extends PInput> I getInput(PTransform<I, ?> transform) {
-    return (I) pipeline.getInput(transform);
+    @SuppressWarnings("unchecked")
+    I input = (I) pipeline.getInput(transform);
+    return input;
   }
 
   <O extends POutput> O getOutput(PTransform<?, O> transform) {
-    return (O) pipeline.getOutput(transform);
+    @SuppressWarnings("unchecked")
+    O output = (O) pipeline.getOutput(transform);
+    return output;
   }
 
-  void setOutputRDD(PTransform transform, JavaRDDLike rdd) {
+  void setOutputRDD(PTransform<?, ?> transform, JavaRDDLike<?, ?> rdd) {
     rdds.put((PValue) getOutput(transform), rdd);
   }
 
-  void setPObjectValue(PObject pobject, Object value) {
+  void setPObjectValue(PObject<?> pobject, Object value) {
     pobjects.put(pobject, value);
   }
 
-  JavaRDDLike getRDD(PValue pvalue) {
-    JavaRDDLike rdd = rdds.get(pvalue);
+  JavaRDDLike<?, ?> getRDD(PValue pvalue) {
+    JavaRDDLike<?, ?> rdd = rdds.get(pvalue);
     if (multireads.contains(pvalue)) {
       // Ensure the RDD is marked as cached
       rdd.rdd().cache();
@@ -94,11 +98,11 @@ JavaRDDLike getRDD(PValue pvalue) {
     return rdd;
   }
 
-  void setRDD(PValue pvalue, JavaRDDLike rdd) {
+  void setRDD(PValue pvalue, JavaRDDLike<?, ?> rdd) {
     rdds.put(pvalue, rdd);
   }
 
-  JavaRDDLike getInputRDD(PTransform transform) {
+  JavaRDDLike<?, ?> getInputRDD(PTransform<?, ?> transform) {
     return getRDD((PValue) pipeline.getInput(transform));
   }
 
@@ -111,11 +115,14 @@ <T> BroadcastHelper<T> getBroadcastHelper(PObject<T> value) {
   @Override
   public <T> T get(PObject<T> value) {
     if (pobjects.containsKey(value)) {
-      return (T) pobjects.get(value);
+      @SuppressWarnings("unchecked")
+      T result = (T) pobjects.get(value);
+      return result;
     }
     if (rdds.containsKey(value)) {
-      JavaRDDLike rdd = rdds.get(value);
+      JavaRDDLike<?, ?> rdd = rdds.get(value);
       //TODO: need same logic from get() method below here for serialization of bytes
+      @SuppressWarnings("unchecked")
       T res = (T) Iterables.getOnlyElement(rdd.collect());
       pobjects.put(value, res);
       return res;
@@ -130,10 +137,11 @@ public <T> T getAggregatorValue(String named, Class<T> resultType) {
 
   @Override
   public <T> Iterable<T> get(PCollection<T> pcollection) {
-    JavaRDDLike rdd = getRDD(pcollection);
-    final Coder<T> coder = pcollection.getCoder();
-    JavaRDDLike bytes = rdd.map(CoderHelpers.toByteFunction(coder));
-    List<byte[]> clientBytes = bytes.collect();
+    @SuppressWarnings("unchecked")
+    JavaRDDLike<T, ?> rdd = (JavaRDDLike<T, ?>) getRDD(pcollection);
+    final Coder<T> coder = pcollection.getCoder();
+    JavaRDDLike<byte[], ?> bytesRDD = rdd.map(CoderHelpers.toByteFunction(coder));
+    List<byte[]> clientBytes = bytesRDD.collect();
     return Iterables.transform(clientBytes, new Function<byte[], T>() {
       @Override
       public T apply(byte[] bytes) {
@@ -142,16 +150,24 @@ public T apply(byte[] bytes) {
     });
   }
 
-  PObjectValueTuple getPObjectTuple(PTransform transform) {
+  PObjectValueTuple getPObjectTuple(PTransform<?, ?> transform) {
     PObjectTuple pot = (PObjectTuple) pipeline.getInput(transform);
     PObjectValueTuple povt = PObjectValueTuple.empty();
     for (Map.Entry<TupleTag<?>, PObject<?>> e : pot.getAll().entrySet()) {
-      povt = povt.and((TupleTag) e.getKey(), get(e.getValue()));
+      povt = and(povt, e);
     }
     return povt;
   }
 
-  void setPObjectTuple(PTransform transform, PObjectValueTuple outputValues) {
+  private <T> PObjectValueTuple and(PObjectValueTuple povt, Map.Entry<TupleTag<?>, PObject<?>> e) {
+    @SuppressWarnings("unchecked")
+    TupleTag<T> ttKey = (TupleTag<T>) e.getKey();
+    @SuppressWarnings("unchecked")
+    PObject<T> potValue = (PObject<T>) e.getValue();
+    return povt.and(ttKey, get(potValue));
+  }
+
+  void setPObjectTuple(PTransform<?, ?> transform, PObjectValueTuple outputValues) {
     PObjectTuple pot = (PObjectTuple) pipeline.getOutput(transform);
     for (Map.Entry<TupleTag<?>, PObject<?>> e : pot.getAll().entrySet()) {
       pobjects.put(e.getValue(), outputValues.get(e.getKey()));
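
Note: the Map<PValue, JavaRDDLike<?, ?>> and Map<PObject<?>, Object> fields above are typesafe-heterogeneous-container style registries: writes are checked against a type-carrying key, and the one unchecked cast per typed read is localized and justified by the write-site invariant. A reduced sketch, with Key<T> standing in for PCollection<T>/PObject<T> and List<T> for the RDD (all names here are invented):

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class TypedRegistry {
      // Key<T> carries the element type, as PValue subtypes do in the runner.
      static final class Key<T> {
        final String name;
        Key(String name) { this.name = name; }
      }

      private final Map<Key<?>, List<?>> entries = new HashMap<>();

      <T> void put(Key<T> key, List<T> value) {  // fully checked write
        entries.put(key, value);
      }

      <T> List<T> get(Key<T> key) {
        @SuppressWarnings("unchecked")  // safe: put() is the only writer
        List<T> value = (List<T>) entries.get(key);
        return value;
      }

      public static void main(String[] args) {
        TypedRegistry registry = new TypedRegistry();
        Key<String> words = new Key<>("words");
        List<String> data = new ArrayList<>();
        data.add("hello");
        registry.put(words, data);
        System.out.println(registry.get(words).get(0).toUpperCase());
      }
    }
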
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/MultiDoFnFunction.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/MultiDoFnFunction.java
index 922ae04aed57b..70facf82bb4d4 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/MultiDoFnFunction.java
@@ -50,7 +50,7 @@ class MultiDoFnFunction<I, O> implements PairFlatMapFunction<Iterator<I>, TupleTag<?>, Object> {
 
   private final DoFn<I, O> mFunction;
   private final SparkRuntimeContext mRuntimeContext;
-  private final TupleTag mMainOutputTag;
+  private final TupleTag<O> mMainOutputTag;
   private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;
 
   MultiDoFnFunction(
@@ -61,12 +61,12 @@ class MultiDoFnFunction<I, O> implements PairFlatMapFunction<Iterator<I>, TupleTag<?>, Object> {
     this.mFunction = fn;
     this.mRuntimeContext = runtimeContext;
     this.mMainOutputTag = mainOutputTag;
-    this. mSideInputs = sideInputs;
+    this.mSideInputs = sideInputs;
   }
 
   @Override
   public Iterable<Tuple2<TupleTag<?>, Object>> call(Iterator<I> iter) throws Exception {
-    ProcCtxt<I, O> ctxt = new ProcCtxt<I, O>(mFunction);
+    ProcCtxt ctxt = new ProcCtxt(mFunction);
     mFunction.startBundle(ctxt);
     while (iter.hasNext()) {
       ctxt.element = iter.next();
@@ -82,7 +82,7 @@ public Tuple2<TupleTag<?>, Object> apply(Map.Entry<TupleTag<?>, Object> input) {
     });
   }
 
-  private class ProcCtxt<I, O> extends DoFn<I, O>.ProcessContext {
+  private class ProcCtxt extends DoFn<I, O>.ProcessContext {
 
     private final Multimap<TupleTag<?>, Object> outputs = LinkedListMultimap.create();
     private I element;
@@ -98,7 +98,9 @@ public PipelineOptions getPipelineOptions() {
 
     @Override
     public <T> T sideInput(PCollectionView<T, ?> view) {
-      return (T) mSideInputs.get(view.getTagInternal()).getValue();
+      @SuppressWarnings("unchecked")
+      T value = (T) mSideInputs.get(view.getTagInternal()).getValue();
+      return value;
     }
 
     @Override
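
Note: MultiDoFnFunction implements multi-output ParDo by collecting every output in a Multimap keyed by TupleTag and filtering per tag afterwards; Object is the common value type until the tag re-establishes T. A plain-java.util sketch of that demultiplexing (Tag and TaggedOutputs are invented here):

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class TaggedOutputs {
      static final class Tag<T> {
        final String name;
        Tag(String name) { this.name = name; }
      }

      private final Map<Tag<?>, List<Object>> outputs = new HashMap<>();

      <T> void output(Tag<T> tag, T value) {  // checked at the write site
        List<Object> values = outputs.get(tag);
        if (values == null) {
          values = new ArrayList<>();
          outputs.put(tag, values);
        }
        values.add(value);
      }

      <T> List<T> get(Tag<T> tag) {
        List<Object> values = outputs.get(tag);
        if (values == null) {
          return new ArrayList<>();
        }
        @SuppressWarnings("unchecked")  // safe: output() is the only writer
        List<T> typed = (List<T>) values;
        return typed;
      }

      public static void main(String[] args) {
        TaggedOutputs out = new TaggedOutputs();
        Tag<String> words = new Tag<>("words");
        Tag<Integer> lengths = new Tag<>("lengths");
        for (String w : new String[] {"spark", "dataflow"}) {
          out.output(words, w);
          out.output(lengths, w.length());
        }
        System.out.println(out.get(words) + " " + out.get(lengths));
      }
    }
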
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkPipelineRunner.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkPipelineRunner.java
index 0cdfa34072de3..f8b261a125ee3 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkPipelineRunner.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkPipelineRunner.java
@@ -107,8 +107,13 @@ public void leaveCompositeTransform(TransformTreeNode node) {
 
       @Override
       public void visitTransform(TransformTreeNode node) {
-        PTransform transform = node.getTransform();
-        TransformEvaluator evaluator = TransformTranslator.getTransformEvaluator(transform.getClass());
+        doVisitTransform(node.getTransform());
+      }
+
+      private <PT extends PTransform<?, ?>> void doVisitTransform(PT transform) {
+        @SuppressWarnings("unchecked")
+        TransformEvaluator<PT> evaluator = (TransformEvaluator<PT>)
+            TransformTranslator.getTransformEvaluator(transform.getClass());
         LOG.info("Evaluating " + transform);
         evaluator.evaluate(transform, ctxt);
       }
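
Note: doVisitTransform() above is the standard wildcard-capture trick: visitTransform() only sees PTransform<?, ?>, and the compiler will not let an evaluator and a transform with independent wildcards be used together; delegating to a method with a named type variable PT captures the wildcards so both sides provably share one type, leaving a single unchecked cast at the registry lookup. A stripped-down sketch with invented Transform/Evaluator interfaces:

    public class CaptureDemo {
      interface Transform<I, O> { O apply(I input); }

      interface Evaluator<PT extends Transform<?, ?>> { void evaluate(PT transform); }

      static final class Upper implements Transform<String, String> {
        public String apply(String s) { return s.toUpperCase(); }
      }

      // Stand-in for the evaluator registry lookup.
      static Evaluator<?> lookup(Class<?> clazz) {
        return new Evaluator<Upper>() {
          public void evaluate(Upper t) { System.out.println(t.apply("evaluated")); }
        };
      }

      static void visit(Transform<?, ?> transform) {
        doVisit(transform);  // the wildcards are captured as PT here
      }

      private static <PT extends Transform<?, ?>> void doVisit(PT transform) {
        @SuppressWarnings("unchecked")  // registry invariant: lookup(c) handles instances of c
        Evaluator<PT> evaluator = (Evaluator<PT>) lookup(transform.getClass());
        evaluator.evaluate(transform);
      }

      public static void main(String[] args) {
        visit(new Upper());
      }
    }
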
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java
index 82bf4491e9fea..5b1e085d9a673 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java
@@ -42,7 +42,7 @@ class SparkRuntimeContext implements Serializable {
   /**
    * Map fo names to dataflow aggregators.
    */
-  private final Map<String, Aggregator> aggregators = new HashMap<>();
+  private final Map<String, Aggregator<?>> aggregators = new HashMap<>();
 
   public SparkRuntimeContext(JavaSparkContext jsc, Pipeline pipeline) {
     this.accum = jsc.accumulator(new NamedAggregators(), new AggAccumParam());
@@ -77,12 +77,13 @@ public synchronized PipelineOptions getPipelineOptions() {
   public synchronized <In, Out> Aggregator<In> createAggregator(
       String named,
       SerializableFunction<Iterable<In>, Out> sfunc) {
-    Aggregator<In> aggregator = aggregators.get(named);
+    @SuppressWarnings("unchecked")
+    Aggregator<In> aggregator = (Aggregator<In>) aggregators.get(named);
     if (aggregator == null) {
       NamedAggregators.SerFunctionState<In, Out> state = new NamedAggregators
           .SerFunctionState<>(sfunc);
       accum.add(new NamedAggregators(named, state));
-      aggregator = new SparkAggregator(state);
+      aggregator = new SparkAggregator<>(state);
       aggregators.put(named, aggregator);
     }
     return aggregator;
@@ -100,12 +101,14 @@ public synchronized <In, Out> Aggregator<In> createAggregator(
   public synchronized <In, Inter, Out> Aggregator<In> createAggregator(
       String named,
       Combine.CombineFn<? super In, Inter, Out> combineFn) {
-    Aggregator<In> aggregator = aggregators.get(named);
+    @SuppressWarnings("unchecked")
+    Aggregator<In> aggregator = (Aggregator<In>) aggregators.get(named);
     if (aggregator == null) {
-      NamedAggregators.CombineFunctionState<In, Inter, Out> state = new NamedAggregators
-          .CombineFunctionState<>(combineFn);
+      @SuppressWarnings("unchecked")
+      NamedAggregators.CombineFunctionState<In, Inter, Out> state = new NamedAggregators
+          .CombineFunctionState<>((Combine.CombineFn<In, Inter, Out>) combineFn);
       accum.add(new NamedAggregators(named, state));
-      aggregator = new SparkAggregator(state);
+      aggregator = new SparkAggregator<>(state);
       aggregators.put(named, aggregator);
     }
     return aggregator;
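
Note: the second factory above receives a combine function typed for "? super In" but CombineFunctionState wants an exact In; the unchecked narrowing is safe because a combiner declared for a supertype of In can always consume Ins. The same shape in miniature (Combiner, SUM and forInputs are all invented for this note):

    import java.util.Arrays;
    import java.util.List;

    public class ContravarianceDemo {
      interface Combiner<In, Out> { Out combine(List<In> inputs); }

      static final Combiner<Number, Double> SUM = new Combiner<Number, Double>() {
        public Double combine(List<Number> inputs) {
          double total = 0;
          for (Number n : inputs) { total += n.doubleValue(); }
          return total;
        }
      };

      static <In> Combiner<In, Double> forInputs(Combiner<? super In, Double> general) {
        @SuppressWarnings("unchecked")  // safe: 'general' accepts any In by its own bound
        Combiner<In, Double> narrowed = (Combiner<In, Double>) general;
        return narrowed;
      }

      public static void main(String[] args) {
        // A combiner written against Number is reused at the Integer type.
        Combiner<Integer, Double> ints = ContravarianceDemo.<Integer>forInputs(SUM);
        System.out.println(ints.combine(Arrays.asList(1, 2, 3)));
      }
    }
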
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
index a23d509e95a45..477a8441bebeb 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
@@ -19,9 +19,23 @@
 import com.google.cloud.dataflow.sdk.coders.Coder;
 import com.google.cloud.dataflow.sdk.io.AvroIO;
 import com.google.cloud.dataflow.sdk.io.TextIO;
-import com.google.cloud.dataflow.sdk.transforms.*;
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import com.google.cloud.dataflow.sdk.transforms.Convert;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.CreatePObject;
+import com.google.cloud.dataflow.sdk.transforms.Flatten;
+import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.transforms.SeqDo;
 import com.google.cloud.dataflow.sdk.util.WindowedValue;
-import com.google.cloud.dataflow.sdk.values.*;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionList;
+import com.google.cloud.dataflow.sdk.values.PCollectionTuple;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.PObjectValueTuple;
+import com.google.cloud.dataflow.sdk.values.TupleTag;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import org.apache.spark.api.java.JavaPairRDD;
@@ -44,7 +58,7 @@ private TransformTranslator() {
   private static class FieldGetter {
     private final Map<String, Field> fields;
 
-    public FieldGetter(Class clazz) {
+    FieldGetter(Class<?> clazz) {
       this.fields = Maps.newHashMap();
       for (Field f : clazz.getDeclaredFields()) {
         f.setAccessible(true);
@@ -54,202 +68,250 @@ public FieldGetter(Class clazz) {
 
     public <T> T get(String fieldname, Object value) {
       try {
-        return (T) fields.get(fieldname).get(value);
+        @SuppressWarnings("unchecked")
+        T fieldValue = (T) fields.get(fieldname).get(value);
+        return fieldValue;
       } catch (IllegalAccessException e) {
         throw new IllegalStateException(e);
       }
     }
   }
 
-  private static final TransformEvaluator<Flatten> FLATTEN = new TransformEvaluator<Flatten>() {
-    @Override
-    public void evaluate(Flatten transform, EvaluationContext context) {
-      PCollectionList pcs = (PCollectionList) context.getPipeline().getInput(transform);
-      JavaRDD[] rdds = new JavaRDD[pcs.size()];
-      for (int i = 0; i < rdds.length; i++) {
-        rdds[i] = (JavaRDD) context.getRDD(pcs.get(i));
+  private static <T> TransformEvaluator<Flatten<T>> flatten() {
+    return new TransformEvaluator<Flatten<T>>() {
+      @SuppressWarnings("unchecked")
+      @Override
+      public void evaluate(Flatten<T> transform, EvaluationContext context) {
+        PCollectionList<T> pcs = (PCollectionList<T>) context.getPipeline().getInput(transform);
+        JavaRDD<T>[] rdds = new JavaRDD[pcs.size()];
+        for (int i = 0; i < rdds.length; i++) {
+          rdds[i] = (JavaRDD<T>) context.getRDD(pcs.get(i));
+        }
+        JavaRDD<T> rdd = context.getSparkContext().union(rdds);
+        context.setOutputRDD(transform, rdd);
       }
-      JavaRDD rdd = context.getSparkContext().union(rdds);
-      context.setOutputRDD(transform, rdd);
-    }
-  };
+    };
+  }
 
-  private static final TransformEvaluator<GroupByKey.GroupByKeyOnly> GBK = new TransformEvaluator<GroupByKey.GroupByKeyOnly>() {
-    @Override
-    public void evaluate(GroupByKey.GroupByKeyOnly transform, EvaluationContext context) {
-      context.setOutputRDD(transform, fromPair(toPair(context.getInputRDD(transform)).groupByKey()));
-    }
-  };
+  private static <K, V> TransformEvaluator<GroupByKey.GroupByKeyOnly<K, V>> gbk() {
+    return new TransformEvaluator<GroupByKey.GroupByKeyOnly<K, V>>() {
+      @Override
+      public void evaluate(GroupByKey.GroupByKeyOnly<K, V> transform, EvaluationContext context) {
+        @SuppressWarnings("unchecked")
+        JavaRDDLike<KV<K, V>,?> inRDD =
+            (JavaRDDLike<KV<K, V>,?>) context.getInputRDD(transform);
+        context.setOutputRDD(transform, fromPair(toPair(inRDD).groupByKey()));
+      }
+    };
+  }
 
   private static final FieldGetter GROUPED_FG = new FieldGetter(Combine.GroupedValues.class);
-  private static final TransformEvaluator<Combine.GroupedValues> GROUPED = new TransformEvaluator<Combine.GroupedValues>() {
+
+  private static <K, VI, VO> TransformEvaluator<Combine.GroupedValues<K, VI, VO>> grouped() {
+    return new TransformEvaluator<Combine.GroupedValues<K, VI, VO>>() {
+      @Override
+      public void evaluate(Combine.GroupedValues<K, VI, VO> transform, EvaluationContext context) {
+        Combine.KeyedCombineFn<K, VI, ?, VO> keyed = GROUPED_FG.get("fn", transform);
+        @SuppressWarnings("unchecked")
+        JavaRDDLike<KV<K, Iterable<VI>>,?> inRDD =
+            (JavaRDDLike<KV<K, Iterable<VI>>,?>) context.getInputRDD(transform);
+        context.setOutputRDD(transform, inRDD.map(new KVFunction<>(keyed)));
+      }
+    };
+  }
+
+  private static final class KVFunction<K, VI, VO> implements Function<KV<K, Iterable<VI>>, KV<K, VO>> {
+    private final Combine.KeyedCombineFn<K, VI, ?, VO> keyed;
+
+    KVFunction(Combine.KeyedCombineFn<K, VI, ?, VO> keyed) {
+      this.keyed = keyed;
+    }
+
     @Override
-    public void evaluate(Combine.GroupedValues transform, EvaluationContext context) {
-      final Combine.KeyedCombineFn keyed = GROUPED_FG.get("fn", transform);
-      context.setOutputRDD(transform, context.getInputRDD(transform).map(new Function() {
-        @Override
-        public Object call(Object input) throws Exception {
-          KV kv = (KV) input;
-          return KV.of(kv.getKey(), keyed.apply(kv.getKey(), kv.getValue()));
-        }
-      }));
+    public KV<K, VO> call(KV<K, Iterable<VI>> kv) throws Exception {
+      return KV.of(kv.getKey(), keyed.apply(kv.getKey(), kv.getValue()));
     }
-  };
+  }
 
-  private static JavaPairRDD toPair(JavaRDDLike rdd) {
-    return rdd.mapToPair(new PairFunction() {
+  private static <K, V> JavaPairRDD<K, V> toPair(JavaRDDLike<KV<K, V>,?> rdd) {
+    return rdd.mapToPair(new PairFunction<KV<K, V>,K,V>() {
       @Override
-      public Tuple2 call(Object o) throws Exception {
-        KV kv = (KV) o;
-        return new Tuple2(kv.getKey(), kv.getValue());
+      public Tuple2<K, V> call(KV<K, V> kv) {
+        return new Tuple2<>(kv.getKey(), kv.getValue());
       }
     });
   }
 
-  private static JavaRDDLike fromPair(JavaPairRDD rdd) {
-    return rdd.map(new Function() {
+  private static <K, V> JavaRDDLike<KV<K, V>,?> fromPair(JavaPairRDD<K, V> rdd) {
+    return rdd.map(new Function<Tuple2<K, V>,KV<K, V>>() {
       @Override
-      public Object call(Object o) throws Exception {
-        Tuple2 t2 = (Tuple2) o;
+      public KV<K, V> call(Tuple2<K, V> t2) {
         return KV.of(t2._1(), t2._2());
       }
     });
   }
 
-  private static final TransformEvaluator<ParDo.Bound> PARDO = new TransformEvaluator<ParDo.Bound>() {
-    @Override
-    public void evaluate(ParDo.Bound transform, EvaluationContext context) {
-      DoFnFunction dofn = new DoFnFunction(transform.getFn(),
-          context.getRuntimeContext(),
-          getSideInputs(transform.getSideInputs(), context));
-      context.setOutputRDD(transform, context.getInputRDD(transform).mapPartitions(dofn));
-    }
-  };
+  private static <I, O> TransformEvaluator<ParDo.Bound<I, O>> parDo() {
+    return new TransformEvaluator<ParDo.Bound<I, O>>() {
+      @Override
+      public void evaluate(ParDo.Bound<I, O> transform, EvaluationContext context) {
+        DoFnFunction<I, O> dofn =
+            new DoFnFunction<>(transform.getFn(),
+                context.getRuntimeContext(),
+                getSideInputs(transform.getSideInputs(), context));
+        @SuppressWarnings("unchecked")
+        JavaRDDLike<I, ?> inRDD = (JavaRDDLike<I, ?>) context.getInputRDD(transform);
+        context.setOutputRDD(transform, inRDD.mapPartitions(dofn));
+      }
+    };
+  }
 
   private static final FieldGetter MULTIDO_FG = new FieldGetter(ParDo.BoundMulti.class);
-  private static final TransformEvaluator<ParDo.BoundMulti> MULTIDO = new TransformEvaluator<ParDo.BoundMulti>() {
-    @Override
-    public void evaluate(ParDo.BoundMulti transform, EvaluationContext context) {
-      MultiDoFnFunction multifn = new MultiDoFnFunction(
+
+  private static <I, O> TransformEvaluator<ParDo.BoundMulti<I, O>> multiDo() {
+    return new TransformEvaluator<ParDo.BoundMulti<I, O>>() {
+      @Override
+      public void evaluate(ParDo.BoundMulti<I, O> transform, EvaluationContext context) {
+        TupleTag<O> mainOutputTag = MULTIDO_FG.get("mainOutputTag", transform);
+        MultiDoFnFunction<I, O> multifn = new MultiDoFnFunction<>(
           transform.getFn(),
           context.getRuntimeContext(),
-          (TupleTag) MULTIDO_FG.get("mainOutputTag", transform),
+          mainOutputTag,
           getSideInputs(transform.getSideInputs(), context));
 
-      JavaPairRDD all = context.getInputRDD(transform)
+        @SuppressWarnings("unchecked")
+        JavaRDDLike<I, ?> inRDD = (JavaRDDLike<I, ?>) context.getInputRDD(transform);
+        JavaPairRDD<TupleTag<?>, Object> all = inRDD
           .mapPartitionsToPair(multifn)
           .cache();
 
-      PCollectionTuple pct = (PCollectionTuple) context.getOutput(transform);
-      for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
-        TupleTagFilter filter = new TupleTagFilter(e.getKey());
-        JavaPairRDD filtered = all.filter(filter);
-        context.setRDD(e.getValue(), filtered.values());
+        PCollectionTuple pct = context.getOutput(transform);
+        for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
+          @SuppressWarnings("unchecked")
+          JavaPairRDD<TupleTag<?>, Object> filtered =
+              all.filter(new TupleTagFilter(e.getKey()));
+          context.setRDD(e.getValue(), filtered.values());
+        }
       }
-    }
-  };
+    };
+  }
 
-  private static final TransformEvaluator<TextIO.Read.Bound> READ_TEXT = new TransformEvaluator<TextIO.Read.Bound>() {
-    @Override
-    public void evaluate(TextIO.Read.Bound transform, EvaluationContext context) {
-      String pattern = transform.getFilepattern();
-      JavaRDD rdd = context.getSparkContext().textFile(pattern);
-      context.setOutputRDD(transform, rdd);
-    }
-  };
+  private static <T> TransformEvaluator<TextIO.Read.Bound<T>> readText() {
+    return new TransformEvaluator<TextIO.Read.Bound<T>>() {
+      @Override
+      public void evaluate(TextIO.Read.Bound<T> transform, EvaluationContext context) {
+        String pattern = transform.getFilepattern();
+        JavaRDD<String> rdd = context.getSparkContext().textFile(pattern);
+        context.setOutputRDD(transform, rdd);
+      }
+    };
+  }
 
-  private static final TransformEvaluator<TextIO.Write.Bound> WRITE_TEXT = new TransformEvaluator<TextIO.Write.Bound>() {
-    @Override
-    public void evaluate(TextIO.Write.Bound transform, EvaluationContext context) {
-      JavaRDDLike last = context.getInputRDD(transform);
-      String pattern = transform.getFilenamePrefix();
-      last.saveAsTextFile(pattern);
-    }
-  };
+  private static <T> TransformEvaluator<TextIO.Write.Bound<T>> writeText() {
+    return new TransformEvaluator<TextIO.Write.Bound<T>>() {
+      @Override
+      public void evaluate(TextIO.Write.Bound<T> transform, EvaluationContext context) {
+        @SuppressWarnings("unchecked")
+        JavaRDDLike<T, ?> last = (JavaRDDLike<T, ?>) context.getInputRDD(transform);
+        String pattern = transform.getFilenamePrefix();
+        last.saveAsTextFile(pattern);
+      }
    };
+  }
 
-  private static final TransformEvaluator<AvroIO.Read.Bound> READ_AVRO = new TransformEvaluator<AvroIO.Read.Bound>() {
-    @Override
-    public void evaluate(AvroIO.Read.Bound transform, EvaluationContext context) {
-      String pattern = transform.getFilepattern();
-      JavaRDD rdd = context.getSparkContext().textFile(pattern);
-      context.setOutputRDD(transform, rdd);
-    }
-  };
+  private static <T> TransformEvaluator<AvroIO.Read.Bound<T>> readAvro() {
+    return new TransformEvaluator<AvroIO.Read.Bound<T>>() {
+      @Override
+      public void evaluate(AvroIO.Read.Bound<T> transform, EvaluationContext context) {
+        String pattern = transform.getFilepattern();
+        JavaRDD<String> rdd = context.getSparkContext().textFile(pattern);
+        context.setOutputRDD(transform, rdd);
+      }
    };
+  }
 
-  private static final TransformEvaluator<AvroIO.Write.Bound> WRITE_AVRO = new TransformEvaluator<AvroIO.Write.Bound>() {
-    @Override
-    public void evaluate(AvroIO.Write.Bound transform, EvaluationContext context) {
-      JavaRDDLike last = context.getInputRDD(transform);
-      String pattern = transform.getFilenamePrefix();
-      last.saveAsTextFile(pattern);
-    }
-  };
+  private static <T> TransformEvaluator<AvroIO.Write.Bound<T>> writeAvro() {
+    return new TransformEvaluator<AvroIO.Write.Bound<T>>() {
+      @Override
+      public void evaluate(AvroIO.Write.Bound<T> transform, EvaluationContext context) {
+        @SuppressWarnings("unchecked")
+        JavaRDDLike<T, ?> last = (JavaRDDLike<T, ?>) context.getInputRDD(transform);
+        String pattern = transform.getFilenamePrefix();
+        last.saveAsTextFile(pattern);
+      }
    };
+  }
 
-  private static final TransformEvaluator<Create> CREATE = new TransformEvaluator<Create>() {
-    @Override
-    public void evaluate(Create transform, EvaluationContext context) {
-      Iterable elems = transform.getElements();
-      Coder coder = ((PCollection) context.getOutput(transform)).getCoder();
-      JavaRDD rdd = context.getSparkContext().parallelize(
-          CoderHelpers.toByteArrays(elems, coder));
-      context.setOutputRDD(transform, rdd.map(CoderHelpers.fromByteFunction(coder)));
-    }
-  };
+  private static <T> TransformEvaluator<Create<T>> create() {
+    return new TransformEvaluator<Create<T>>() {
+      @Override
+      public void evaluate(Create<T> transform, EvaluationContext context) {
+        Iterable<T> elems = transform.getElements();
+        Coder<T> coder = context.getOutput(transform).getCoder();
+        JavaRDD<byte[]> rdd = context.getSparkContext().parallelize(
+            CoderHelpers.toByteArrays(elems, coder));
+        context.setOutputRDD(transform, rdd.map(CoderHelpers.fromByteFunction(coder)));
+      }
    };
+  }
 
-  private static final TransformEvaluator<CreatePObject> CREATE_POBJ = new TransformEvaluator<CreatePObject>() {
-    @Override
-    public void evaluate(CreatePObject transform, EvaluationContext context) {
-      context.setPObjectValue((PObject) context.getOutput(transform), transform.getElement());
-    }
-  };
+  private static <T> TransformEvaluator<CreatePObject<T>> createPObj() {
+    return new TransformEvaluator<CreatePObject<T>>() {
+      @Override
+      public void evaluate(CreatePObject<T> transform, EvaluationContext context) {
+        context.setPObjectValue(context.getOutput(transform), transform.getElement());
+      }
    };
+  }
 
-  private static final TransformEvaluator<Convert.ToIterable> TO_ITER = new TransformEvaluator<Convert.ToIterable>() {
-    @Override
-    public void evaluate(Convert.ToIterable transform, EvaluationContext context) {
-      PCollection in = (PCollection) context.getInput(transform);
-      PObject out = (PObject) context.getOutput(transform);
-      context.setPObjectValue(out, context.get(in));
-    }
-  };
-
-  private static final TransformEvaluator<Convert.ToIterableWindowedValue> TO_ITER_WIN =
-      new TransformEvaluator<Convert.ToIterableWindowedValue>() {
-    @Override
-    public void evaluate(Convert.ToIterableWindowedValue transform, EvaluationContext context) {
-      PCollection in = (PCollection) context.getInput(transform);
-      PObject out = (PObject) context.getOutput(transform);
-      context.setPObjectValue(out, Iterables.transform(context.get(in),
-          new com.google.common.base.Function() {
-            @Override
-            public WindowedValue apply(Object o) {
-              return WindowedValue.valueInGlobalWindow(o);
-            }
-          }));
-    }
-  };
+  private static <T> TransformEvaluator<Convert.ToIterable<T>> toIter() {
+    return new TransformEvaluator<Convert.ToIterable<T>>() {
+      @Override
+      public void evaluate(Convert.ToIterable<T> transform, EvaluationContext context) {
+        context.setPObjectValue(context.getOutput(transform),
+            context.get(context.getInput(transform)));
+      }
    };
+  }
 
-  private static class TupleTagFilter implements Function<Tuple2<TupleTag, Object>, Boolean> {
-    private final TupleTag tag;
+  private static <T> TransformEvaluator<Convert.ToIterableWindowedValue<T>> toIterWin() {
+    return new TransformEvaluator<Convert.ToIterableWindowedValue<T>>() {
+      @Override
+      public void evaluate(Convert.ToIterableWindowedValue<T> transform, EvaluationContext context) {
+        context.setPObjectValue(context.getOutput(transform),
+            Iterables.transform(context.get(context.getInput(transform)),
+                new com.google.common.base.Function<T, WindowedValue<T>>() {
+                  @Override
+                  public WindowedValue<T> apply(T t) {
+                    return WindowedValue.valueInGlobalWindow(t);
+                  }
+                }));
+      }
    };
+  }
 
-    public TupleTagFilter(TupleTag tag) {
+  private static class TupleTagFilter implements Function<Tuple2<TupleTag<?>, Object>, Boolean> {
+
+    private final TupleTag<?> tag;
+
+    private TupleTagFilter(TupleTag<?> tag) {
       this.tag = tag;
     }
 
     @Override
-    public Boolean call(Tuple2 input) throws Exception {
+    public Boolean call(Tuple2<TupleTag<?>, Object> input) {
       return tag.equals(input._1());
     }
   }
 
-  private static final TransformEvaluator<SeqDo.BoundMulti> SEQDO = new TransformEvaluator<SeqDo.BoundMulti>() {
-    @Override
-    public void evaluate(SeqDo.BoundMulti transform, EvaluationContext context) {
-      PObjectValueTuple inputValues = context.getPObjectTuple(transform);
-      PObjectValueTuple outputValues = transform.getFn().process(inputValues);
-      context.setPObjectTuple(transform, outputValues);
-    }
-  };
+  private static TransformEvaluator<SeqDo.BoundMulti> seqDo() {
+    return new TransformEvaluator<SeqDo.BoundMulti>() {
+      @Override
+      public void evaluate(SeqDo.BoundMulti transform, EvaluationContext context) {
+        PObjectValueTuple inputValues = context.getPObjectTuple(transform);
+        PObjectValueTuple outputValues = transform.getFn().process(inputValues);
+        context.setPObjectTuple(transform, outputValues);
      }
    };
+  }
 
   private static Map<TupleTag<?>, BroadcastHelper<?>> getSideInputs(
       Iterable<PCollectionView<?, ?>> views,
@@ -265,26 +327,27 @@ private static Map<TupleTag<?>, BroadcastHelper<?>> getSideInputs(
     }
   }
 
-  private static final Map<Class<? extends PTransform>, TransformEvaluator> mEvaluators = Maps.newHashMap();
+  private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> mEvaluators = Maps.newHashMap();
 
   static {
-    mEvaluators.put(TextIO.Read.Bound.class, READ_TEXT);
-    mEvaluators.put(TextIO.Write.Bound.class, WRITE_TEXT);
-    mEvaluators.put(AvroIO.Read.Bound.class, READ_AVRO);
-    mEvaluators.put(AvroIO.Write.Bound.class, WRITE_AVRO);
-    mEvaluators.put(ParDo.Bound.class, PARDO);
-    mEvaluators.put(ParDo.BoundMulti.class, MULTIDO);
-    mEvaluators.put(SeqDo.BoundMulti.class, SEQDO);
-    mEvaluators.put(GroupByKey.GroupByKeyOnly.class, GBK);
-    mEvaluators.put(Combine.GroupedValues.class, GROUPED);
-    mEvaluators.put(Flatten.class, FLATTEN);
-    mEvaluators.put(Create.class, CREATE);
-    mEvaluators.put(CreatePObject.class, CREATE_POBJ);
-    mEvaluators.put(Convert.ToIterable.class, TO_ITER);
-    mEvaluators.put(Convert.ToIterableWindowedValue.class, TO_ITER_WIN);
+    mEvaluators.put(TextIO.Read.Bound.class, readText());
+    mEvaluators.put(TextIO.Write.Bound.class, writeText());
+    mEvaluators.put(AvroIO.Read.Bound.class, readAvro());
+    mEvaluators.put(AvroIO.Write.Bound.class, writeAvro());
+    mEvaluators.put(ParDo.Bound.class, parDo());
+    mEvaluators.put(ParDo.BoundMulti.class, multiDo());
+    mEvaluators.put(SeqDo.BoundMulti.class, seqDo());
+    mEvaluators.put(GroupByKey.GroupByKeyOnly.class, gbk());
+    mEvaluators.put(Combine.GroupedValues.class, grouped());
+    mEvaluators.put(Flatten.class, flatten());
+    mEvaluators.put(Create.class, create());
+    mEvaluators.put(CreatePObject.class, createPObj());
+    mEvaluators.put(Convert.ToIterable.class, toIter());
+    mEvaluators.put(Convert.ToIterableWindowedValue.class, toIterWin());
   }
 
-  public static TransformEvaluator getTransformEvaluator(Class<? extends PTransform> clazz) {
-    TransformEvaluator transform = mEvaluators.get(clazz);
+  public static <PT extends PTransform<?, ?>> TransformEvaluator<PT> getTransformEvaluator(Class<PT> clazz) {
+    @SuppressWarnings("unchecked")
+    TransformEvaluator<PT> transform = (TransformEvaluator<PT>) mEvaluators.get(clazz);
     if (transform == null) {
       throw new IllegalStateException("No TransformEvaluator registered for " + clazz);
     }
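
Note: the rewrite above from raw static constants (READ_TEXT, PARDO, ...) to generic factory methods (readText(), parDo(), ...) is forced by the language: a static field cannot declare its own type parameter, so the constants could only exist as raw TransformEvaluator; a static generic method can mint an instance typed for each call site, and the single remaining unchecked cast moves into getTransformEvaluator(). The shape of the change in isolation (Printer and listPrinter are invented):

    import java.util.Arrays;
    import java.util.List;

    public class FactoryDemo {
      interface Printer<T> { void print(T value); }

      // Before: a raw constant -- usable only through raw types.
      //   static final Printer LIST_PRINTER = new Printer() { ... };

      // After: a generic factory, typed per call site.
      static <T> Printer<List<T>> listPrinter() {
        return new Printer<List<T>>() {
          public void print(List<T> value) {
            for (T item : value) {
              System.out.println(item);
            }
          }
        };
      }

      public static void main(String[] args) {
        Printer<List<String>> printer = FactoryDemo.<String>listPrinter();
        printer.print(Arrays.asList("a", "b"));
      }
    }
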
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/aggregators/NamedAggregators.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/aggregators/NamedAggregators.java
index e372b951a7c1d..06ee4d37c06e0 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/aggregators/NamedAggregators.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/aggregators/NamedAggregators.java
@@ -171,13 +171,19 @@ public SerFunctionState(SerializableFunction<Iterable<In>, Out> sfunc) {
 
     @Override
     public void update(In element) {
-      this.state = sfunc.apply(ImmutableList.of(element, (In) state));
+      @SuppressWarnings("unchecked")
+      In thisState = (In) state;
+      this.state = sfunc.apply(ImmutableList.of(element, thisState));
     }
 
     @Override
-    public State merge(State other) {
+    public State<In, Out> merge(State<In, Out> other) {
       // Add exception catching and logging here.
-      this.state = sfunc.apply(ImmutableList.of((In) state, (In) other.current()));
+      @SuppressWarnings("unchecked")
+      In thisState = (In) state;
+      @SuppressWarnings("unchecked")
+      In otherCurrent = (In) other.current();
+      this.state = sfunc.apply(ImmutableList.of(thisState, otherCurrent));
       return this;
     }
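
Note: parameterizing the raw 'State merge(State other)' as State<In, Out> above matters beyond the warning: any raw reference silently opts that expression out of generic checking, and a single raw type is enough to let a wrong-typed value into a checked collection. A classic self-contained demonstration:

    import java.util.ArrayList;
    import java.util.List;

    public class RawTypeDemo {
      public static void main(String[] args) {
        List<String> strings = new ArrayList<>();
        List raw = strings;   // legal, but generic checking stops here
        raw.add(42);          // only an 'unchecked' warning at compile time
        try {
          String s = strings.get(0);  // ClassCastException at run time
          System.out.println(s);
        } catch (ClassCastException e) {
          System.out.println("heap pollution detected: " + e);
        }
      }
    }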