diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
index 3d4afdf938f09..5966c64fd725d 100644
--- a/runners/spark/pom.xml
+++ b/runners/spark/pom.xml
@@ -32,6 +32,9 @@ License.
1.7
+
+ -Xlint:all,-serial
+
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
index 8cd593979ae87..673237e574e84 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
@@ -64,7 +64,7 @@ public DoFnFunction(
@Override
public Iterable call(Iterator iter) throws Exception {
- ProcCtxt ctxt = new ProcCtxt<>(mFunction);
+ ProcCtxt ctxt = new ProcCtxt(mFunction);
//setup
mFunction.startBundle(ctxt);
//operation
@@ -77,7 +77,7 @@ public Iterable call(Iterator iter) throws Exception {
return ctxt.outputs;
}
- private class ProcCtxt extends DoFn.ProcessContext {
+ private class ProcCtxt extends DoFn.ProcessContext {
private final List outputs = new LinkedList<>();
private I element;
@@ -93,7 +93,9 @@ public PipelineOptions getPipelineOptions() {
@Override
public T sideInput(PCollectionView view) {
- return (T) mSideInputs.get(view.getTagInternal()).getValue();
+ @SuppressWarnings("unchecked")
+ T value = (T) mSideInputs.get(view.getTagInternal()).getValue();
+ return value;
}
@Override
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
index f992cf59a3e8f..4105ab61eb1f3 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
@@ -45,9 +45,9 @@ public class EvaluationContext implements EvaluationResult {
private final JavaSparkContext jsc;
private final Pipeline pipeline;
private final SparkRuntimeContext runtime;
- private final Map rdds = Maps.newHashMap();
+ private final Map> rdds = Maps.newHashMap();
private final Set multireads = Sets.newHashSet();
- private final Map pobjects = Maps.newHashMap();
+ private final Map, Object> pobjects = Maps.newHashMap();
public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) {
this.jsc = jsc;
@@ -68,23 +68,27 @@ SparkRuntimeContext getRuntimeContext() {
}
I getInput(PTransform transform) {
- return (I) pipeline.getInput(transform);
+ @SuppressWarnings("unchecked")
+ I input = (I) pipeline.getInput(transform);
+ return input;
}
O getOutput(PTransform, O> transform) {
- return (O) pipeline.getOutput(transform);
+ @SuppressWarnings("unchecked")
+ O output = (O) pipeline.getOutput(transform);
+ return output;
}
- void setOutputRDD(PTransform transform, JavaRDDLike rdd) {
+ void setOutputRDD(PTransform,?> transform, JavaRDDLike,?> rdd) {
rdds.put((PValue) getOutput(transform), rdd);
}
- void setPObjectValue(PObject pobject, Object value) {
+ void setPObjectValue(PObject> pobject, Object value) {
pobjects.put(pobject, value);
}
- JavaRDDLike getRDD(PValue pvalue) {
- JavaRDDLike rdd = rdds.get(pvalue);
+ JavaRDDLike,?> getRDD(PValue pvalue) {
+ JavaRDDLike,?> rdd = rdds.get(pvalue);
if (multireads.contains(pvalue)) {
// Ensure the RDD is marked as cached
rdd.rdd().cache();
@@ -94,11 +98,11 @@ JavaRDDLike getRDD(PValue pvalue) {
return rdd;
}
- void setRDD(PValue pvalue, JavaRDDLike rdd) {
+ void setRDD(PValue pvalue, JavaRDDLike,?> rdd) {
rdds.put(pvalue, rdd);
}
- JavaRDDLike getInputRDD(PTransform transform) {
+ JavaRDDLike,?> getInputRDD(PTransform transform) {
return getRDD((PValue) pipeline.getInput(transform));
}
@@ -111,11 +115,14 @@ BroadcastHelper getBroadcastHelper(PObject value) {
@Override
public T get(PObject value) {
if (pobjects.containsKey(value)) {
- return (T) pobjects.get(value);
+ @SuppressWarnings("unchecked")
+ T result = (T) pobjects.get(value);
+ return result;
}
if (rdds.containsKey(value)) {
- JavaRDDLike rdd = rdds.get(value);
+ JavaRDDLike,?> rdd = rdds.get(value);
//TODO: need same logic from get() method below here for serialization of bytes
+ @SuppressWarnings("unchecked")
T res = (T) Iterables.getOnlyElement(rdd.collect());
pobjects.put(value, res);
return res;
@@ -130,10 +137,11 @@ public T getAggregatorValue(String named, Class resultType) {
@Override
public Iterable get(PCollection pcollection) {
- JavaRDDLike rdd = getRDD(pcollection);
- final Coder coder = pcollection.getCoder();
- JavaRDDLike bytes = rdd.map(CoderHelpers.toByteFunction(coder));
- List clientBytes = bytes.collect();
+ @SuppressWarnings("unchecked")
+ JavaRDDLike rdd = (JavaRDDLike) getRDD(pcollection);
+ final Coder coder = pcollection.getCoder();
+ JavaRDDLike bytesRDD = rdd.map(CoderHelpers.toByteFunction(coder));
+ List clientBytes = bytesRDD.collect();
return Iterables.transform(clientBytes, new Function() {
@Override
public T apply(byte[] bytes) {
@@ -142,16 +150,24 @@ public T apply(byte[] bytes) {
});
}
- PObjectValueTuple getPObjectTuple(PTransform transform) {
+ PObjectValueTuple getPObjectTuple(PTransform,?> transform) {
PObjectTuple pot = (PObjectTuple) pipeline.getInput(transform);
PObjectValueTuple povt = PObjectValueTuple.empty();
for (Map.Entry, PObject>> e : pot.getAll().entrySet()) {
- povt = povt.and((TupleTag) e.getKey(), get(e.getValue()));
+ povt = and(povt, e);
}
return povt;
}
- void setPObjectTuple(PTransform transform, PObjectValueTuple outputValues) {
+ private PObjectValueTuple and(PObjectValueTuple povt, Map.Entry, PObject>> e) {
+ @SuppressWarnings("unchecked")
+ TupleTag ttKey = (TupleTag) e.getKey();
+ @SuppressWarnings("unchecked")
+ PObject potValue = (PObject) e.getValue();
+ return povt.and(ttKey, get(potValue));
+ }
+
+ void setPObjectTuple(PTransform,?> transform, PObjectValueTuple outputValues) {
PObjectTuple pot = (PObjectTuple) pipeline.getOutput(transform);
for (Map.Entry, PObject>> e : pot.getAll().entrySet()) {
pobjects.put(e.getValue(), outputValues.get(e.getKey()));
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/MultiDoFnFunction.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/MultiDoFnFunction.java
index 922ae04aed57b..70facf82bb4d4 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/MultiDoFnFunction.java
@@ -50,7 +50,7 @@ class MultiDoFnFunction implements PairFlatMapFunction, TupleT
private final DoFn mFunction;
private final SparkRuntimeContext mRuntimeContext;
- private final TupleTag> mMainOutputTag;
+ private final TupleTag mMainOutputTag;
private final Map, BroadcastHelper>> mSideInputs;
MultiDoFnFunction(
@@ -61,12 +61,12 @@ class MultiDoFnFunction implements PairFlatMapFunction, TupleT
this.mFunction = fn;
this.mRuntimeContext = runtimeContext;
this.mMainOutputTag = mainOutputTag;
- this. mSideInputs = sideInputs;
+ this.mSideInputs = sideInputs;
}
@Override
public Iterable, Object>> call(Iterator iter) throws Exception {
- ProcCtxt ctxt = new ProcCtxt(mFunction);
+ ProcCtxt ctxt = new ProcCtxt(mFunction);
mFunction.startBundle(ctxt);
while (iter.hasNext()) {
ctxt.element = iter.next();
@@ -82,7 +82,7 @@ public Tuple2, Object> apply(Map.Entry, Object> input) {
});
}
- private class ProcCtxt extends DoFn.ProcessContext {
+ private class ProcCtxt extends DoFn.ProcessContext {
private final Multimap, Object> outputs = LinkedListMultimap.create();
private I element;
@@ -98,7 +98,9 @@ public PipelineOptions getPipelineOptions() {
@Override
public T sideInput(PCollectionView view) {
- return (T) mSideInputs.get(view.getTagInternal()).getValue();
+ @SuppressWarnings("unchecked")
+ T value = (T) mSideInputs.get(view.getTagInternal()).getValue();
+ return value;
}
@Override
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkPipelineRunner.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkPipelineRunner.java
index 0cdfa34072de3..f8b261a125ee3 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkPipelineRunner.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkPipelineRunner.java
@@ -107,8 +107,13 @@ public void leaveCompositeTransform(TransformTreeNode node) {
@Override
public void visitTransform(TransformTreeNode node) {
- PTransform, ?> transform = node.getTransform();
- TransformEvaluator evaluator = TransformTranslator.getTransformEvaluator(transform.getClass());
+ doVisitTransform(node.getTransform());
+ }
+
+ private void doVisitTransform(PT transform) {
+ @SuppressWarnings("unchecked")
+ TransformEvaluator evaluator = (TransformEvaluator)
+ TransformTranslator.getTransformEvaluator(transform.getClass());
LOG.info("Evaluating " + transform);
evaluator.evaluate(transform, ctxt);
}
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java
index 82bf4491e9fea..5b1e085d9a673 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java
@@ -42,7 +42,7 @@ class SparkRuntimeContext implements Serializable {
/**
* Map fo names to dataflow aggregators.
*/
- private final Map aggregators = new HashMap<>();
+ private final Map> aggregators = new HashMap<>();
public SparkRuntimeContext(JavaSparkContext jsc, Pipeline pipeline) {
this.accum = jsc.accumulator(new NamedAggregators(), new AggAccumParam());
@@ -77,12 +77,13 @@ public synchronized PipelineOptions getPipelineOptions() {
public synchronized Aggregator createAggregator(
String named,
SerializableFunction, Out> sfunc) {
- Aggregator aggregator = aggregators.get(named);
+ @SuppressWarnings("unchecked")
+ Aggregator aggregator = (Aggregator) aggregators.get(named);
if (aggregator == null) {
NamedAggregators.SerFunctionState state = new NamedAggregators
.SerFunctionState<>(sfunc);
accum.add(new NamedAggregators(named, state));
- aggregator = new SparkAggregator(state);
+ aggregator = new SparkAggregator<>(state);
aggregators.put(named, aggregator);
}
return aggregator;
@@ -100,12 +101,14 @@ public synchronized Aggregator createAggregator(
public synchronized Aggregator createAggregator(
String named,
Combine.CombineFn super In, Inter, Out> combineFn) {
- Aggregator aggregator = aggregators.get(named);
+ @SuppressWarnings("unchecked")
+ Aggregator aggregator = (Aggregator) aggregators.get(named);
if (aggregator == null) {
- NamedAggregators.CombineFunctionState super In, Inter, Out> state = new NamedAggregators
- .CombineFunctionState<>(combineFn);
+ @SuppressWarnings("unchecked")
+ NamedAggregators.CombineFunctionState state = new NamedAggregators
+ .CombineFunctionState<>((Combine.CombineFn) combineFn);
accum.add(new NamedAggregators(named, state));
- aggregator = new SparkAggregator(state);
+ aggregator = new SparkAggregator<>(state);
aggregators.put(named, aggregator);
}
return aggregator;
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
index a23d509e95a45..477a8441bebeb 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
@@ -19,9 +19,23 @@
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.io.AvroIO;
import com.google.cloud.dataflow.sdk.io.TextIO;
-import com.google.cloud.dataflow.sdk.transforms.*;
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import com.google.cloud.dataflow.sdk.transforms.Convert;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.CreatePObject;
+import com.google.cloud.dataflow.sdk.transforms.Flatten;
+import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.transforms.SeqDo;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
-import com.google.cloud.dataflow.sdk.values.*;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionList;
+import com.google.cloud.dataflow.sdk.values.PCollectionTuple;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.PObjectValueTuple;
+import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import org.apache.spark.api.java.JavaPairRDD;
@@ -44,7 +58,7 @@ private TransformTranslator() {
private static class FieldGetter {
private final Map fields;
- public FieldGetter(Class> clazz) {
+ FieldGetter(Class> clazz) {
this.fields = Maps.newHashMap();
for (Field f : clazz.getDeclaredFields()) {
f.setAccessible(true);
@@ -54,202 +68,250 @@ public FieldGetter(Class> clazz) {
public T get(String fieldname, Object value) {
try {
- return (T) fields.get(fieldname).get(value);
+ @SuppressWarnings("unchecked")
+ T fieldValue = (T) fields.get(fieldname).get(value);
+ return fieldValue;
} catch (IllegalAccessException e) {
throw new IllegalStateException(e);
}
}
}
- private static final TransformEvaluator FLATTEN = new TransformEvaluator() {
- @Override
- public void evaluate(Flatten transform, EvaluationContext context) {
- PCollectionList> pcs = (PCollectionList>) context.getPipeline().getInput(transform);
- JavaRDD[] rdds = new JavaRDD[pcs.size()];
- for (int i = 0; i < rdds.length; i++) {
- rdds[i] = (JavaRDD) context.getRDD(pcs.get(i));
+ private static TransformEvaluator> flatten() {
+ return new TransformEvaluator>() {
+ @SuppressWarnings("unchecked")
+ @Override
+ public void evaluate(Flatten transform, EvaluationContext context) {
+ PCollectionList pcs = (PCollectionList) context.getPipeline().getInput(transform);
+ JavaRDD[] rdds = new JavaRDD[pcs.size()];
+ for (int i = 0; i < rdds.length; i++) {
+ rdds[i] = (JavaRDD) context.getRDD(pcs.get(i));
+ }
+ JavaRDD rdd = context.getSparkContext().union(rdds);
+ context.setOutputRDD(transform, rdd);
}
- JavaRDD rdd = context.getSparkContext().union(rdds);
- context.setOutputRDD(transform, rdd);
- }
- };
- private static final TransformEvaluator GBK = new TransformEvaluator() {
- @Override
- public void evaluate(GroupByKey.GroupByKeyOnly transform, EvaluationContext context) {
- context.setOutputRDD(transform, fromPair(toPair(context.getInputRDD(transform)).groupByKey()));
- }
- };
+ };
+ }
+
+ private static TransformEvaluator> gbk() {
+ return new TransformEvaluator>() {
+ @Override
+ public void evaluate(GroupByKey.GroupByKeyOnly transform, EvaluationContext context) {
+ @SuppressWarnings("unchecked")
+ JavaRDDLike,?> inRDD =
+ (JavaRDDLike,?>) context.getInputRDD(transform);
+ context.setOutputRDD(transform, fromPair(toPair(inRDD).groupByKey()));
+ }
+ };
+ }
private static final FieldGetter GROUPED_FG = new FieldGetter(Combine.GroupedValues.class);
- private static final TransformEvaluator GROUPED = new TransformEvaluator() {
+ private static TransformEvaluator> grouped() {
+ return new TransformEvaluator>() {
+ @Override
+ public void evaluate(Combine.GroupedValues transform, EvaluationContext context) {
+ Combine.KeyedCombineFn keyed = GROUPED_FG.get("fn", transform);
+ @SuppressWarnings("unchecked")
+ JavaRDDLike>,?> inRDD =
+ (JavaRDDLike>,?>) context.getInputRDD(transform);
+ context.setOutputRDD(transform, inRDD.map(new KVFunction<>(keyed)));
+ }
+ };
+ }
+
+ private static final class KVFunction implements Function>, KV> {
+ private final Combine.KeyedCombineFn keyed;
+
+ KVFunction(Combine.KeyedCombineFn keyed) {
+ this.keyed = keyed;
+ }
+
@Override
- public void evaluate(Combine.GroupedValues transform, EvaluationContext context) {
- final Combine.KeyedCombineFn keyed = GROUPED_FG.get("fn", transform);
- context.setOutputRDD(transform, context.getInputRDD(transform).map(new Function() {
- @Override
- public Object call(Object input) throws Exception {
- KV