Issue apache#13: attempt to remove all generics warnings, or handle them explicitly
srowen authored and peihe committed Mar 14, 2016
1 parent 4fece22 commit 6b62463
Showing 8 changed files with 303 additions and 203 deletions.
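The recurring fix in the diffs below: where an unchecked cast is unavoidable, assign the cast result to a local variable so that @SuppressWarnings("unchecked") can be scoped to that single declaration rather than a whole method (Java does not allow the annotation on a return statement). A minimal sketch of the idiom, using a hypothetical Registry class:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical heterogeneous registry, purely to illustrate the annotation scoping.
class Registry {
  private final Map<String, Object> values = new HashMap<>();

  <T> T get(String key) {
    // Annotating the local variable suppresses exactly this one cast;
    // writing `return (T) values.get(key);` would force @SuppressWarnings
    // up to the method level, hiding any other unchecked operations too.
    @SuppressWarnings("unchecked")
    T value = (T) values.get(key);
    return value;
  }
}
```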
3 changes: 3 additions & 0 deletions runners/spark/pom.xml
@@ -32,6 +32,9 @@ License.
<configuration>
<source>1.7</source>
<target>1.7</target>
+ <compilerArgs>
+   <arg>-Xlint:all,-serial</arg>
+ </compilerArgs>
</configuration>
</plugin>
<plugin>
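`-Xlint:all,-serial` enables every javac lint category except `serial` (which would warn on each `Serializable` class lacking a `serialVersionUID`), so any generics warnings left after this change surface in the build output.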
@@ -64,7 +64,7 @@ public DoFnFunction(

@Override
public Iterable<O> call(Iterator<I> iter) throws Exception {
- ProcCtxt<I, O> ctxt = new ProcCtxt<>(mFunction);
+ ProcCtxt ctxt = new ProcCtxt(mFunction);
//setup
mFunction.startBundle(ctxt);
//operation
@@ -77,7 +77,7 @@ public Iterable<O> call(Iterator<I> iter) throws Exception {
return ctxt.outputs;
}

- private class ProcCtxt<I, O> extends DoFn<I, O>.ProcessContext {
+ private class ProcCtxt extends DoFn<I, O>.ProcessContext {

private final List<O> outputs = new LinkedList<>();
private I element;
@@ -93,7 +93,9 @@ public PipelineOptions getPipelineOptions() {

@Override
public <T> T sideInput(PCollectionView<T, ?> view) {
- return (T) mSideInputs.get(view.getTagInternal()).getValue();
+ @SuppressWarnings("unchecked")
+ T value = (T) mSideInputs.get(view.getTagInternal()).getValue();
+ return value;
}

@Override
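A side note on the ProcCtxt change: a non-static inner class already has the enclosing DoFnFunction<I, O>'s type variables in scope, so redeclaring <I, O> on ProcCtxt introduced distinct parameters that merely shadowed them; making the inner class non-generic lets I and O refer unambiguously to the outer declarations.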
@@ -45,9 +45,9 @@ public class EvaluationContext implements EvaluationResult {
private final JavaSparkContext jsc;
private final Pipeline pipeline;
private final SparkRuntimeContext runtime;
- private final Map<PValue, JavaRDDLike> rdds = Maps.newHashMap();
+ private final Map<PValue, JavaRDDLike<?,?>> rdds = Maps.newHashMap();
private final Set<PValue> multireads = Sets.newHashSet();
- private final Map<PObject, Object> pobjects = Maps.newHashMap();
+ private final Map<PObject<?>, Object> pobjects = Maps.newHashMap();

public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) {
this.jsc = jsc;
@@ -68,23 +68,27 @@ SparkRuntimeContext getRuntimeContext() {
}

<I extends PInput> I getInput(PTransform<I, ?> transform) {
- return (I) pipeline.getInput(transform);
+ @SuppressWarnings("unchecked")
+ I input = (I) pipeline.getInput(transform);
+ return input;
}

<O extends POutput> O getOutput(PTransform<?, O> transform) {
- return (O) pipeline.getOutput(transform);
+ @SuppressWarnings("unchecked")
+ O output = (O) pipeline.getOutput(transform);
+ return output;
}

- void setOutputRDD(PTransform transform, JavaRDDLike rdd) {
+ void setOutputRDD(PTransform<?,?> transform, JavaRDDLike<?,?> rdd) {
rdds.put((PValue) getOutput(transform), rdd);
}

- void setPObjectValue(PObject pobject, Object value) {
+ void setPObjectValue(PObject<?> pobject, Object value) {
pobjects.put(pobject, value);
}

- JavaRDDLike getRDD(PValue pvalue) {
-   JavaRDDLike rdd = rdds.get(pvalue);
+ JavaRDDLike<?,?> getRDD(PValue pvalue) {
+   JavaRDDLike<?,?> rdd = rdds.get(pvalue);
if (multireads.contains(pvalue)) {
// Ensure the RDD is marked as cached
rdd.rdd().cache();
@@ -94,11 +98,11 @@ JavaRDDLike getRDD(PValue pvalue) {
return rdd;
}

- void setRDD(PValue pvalue, JavaRDDLike rdd) {
+ void setRDD(PValue pvalue, JavaRDDLike<?,?> rdd) {
rdds.put(pvalue, rdd);
}

- JavaRDDLike getInputRDD(PTransform transform) {
+ JavaRDDLike<?,?> getInputRDD(PTransform transform) {
return getRDD((PValue) pipeline.getInput(transform));
}

@@ -111,11 +115,14 @@ <T> BroadcastHelper<T> getBroadcastHelper(PObject<T> value) {
@Override
public <T> T get(PObject<T> value) {
if (pobjects.containsKey(value)) {
-     return (T) pobjects.get(value);
+     @SuppressWarnings("unchecked")
+     T result = (T) pobjects.get(value);
+     return result;
}
if (rdds.containsKey(value)) {
-     JavaRDDLike rdd = rdds.get(value);
+     JavaRDDLike<?,?> rdd = rdds.get(value);
//TODO: need same logic from get() method below here for serialization of bytes
+     @SuppressWarnings("unchecked")
T res = (T) Iterables.getOnlyElement(rdd.collect());
pobjects.put(value, res);
return res;
@@ -130,10 +137,11 @@ public <T> T getAggregatorValue(String named, Class<T> resultType) {

@Override
public <T> Iterable<T> get(PCollection<T> pcollection) {
-   JavaRDDLike rdd = getRDD(pcollection);
-   final Coder coder = pcollection.getCoder();
-   JavaRDDLike bytes = rdd.map(CoderHelpers.toByteFunction(coder));
-   List clientBytes = bytes.collect();
+   @SuppressWarnings("unchecked")
+   JavaRDDLike<T,?> rdd = (JavaRDDLike<T,?>) getRDD(pcollection);
+   final Coder<T> coder = pcollection.getCoder();
+   JavaRDDLike<byte[],?> bytesRDD = rdd.map(CoderHelpers.toByteFunction(coder));
+   List<byte[]> clientBytes = bytesRDD.collect();
return Iterables.transform(clientBytes, new Function<byte[], T>() {
@Override
public T apply(byte[] bytes) {
@@ -142,16 +150,24 @@ public T apply(byte[] bytes) {
});
}

- PObjectValueTuple getPObjectTuple(PTransform transform) {
+ PObjectValueTuple getPObjectTuple(PTransform<?,?> transform) {
PObjectTuple pot = (PObjectTuple) pipeline.getInput(transform);
PObjectValueTuple povt = PObjectValueTuple.empty();
for (Map.Entry<TupleTag<?>, PObject<?>> e : pot.getAll().entrySet()) {
-     povt = povt.and((TupleTag) e.getKey(), get(e.getValue()));
+     povt = and(povt, e);
}
return povt;
}

- void setPObjectTuple(PTransform transform, PObjectValueTuple outputValues) {
+ private <T> PObjectValueTuple and(PObjectValueTuple povt, Map.Entry<TupleTag<?>, PObject<?>> e) {
+   @SuppressWarnings("unchecked")
+   TupleTag<T> ttKey = (TupleTag<T>) e.getKey();
+   @SuppressWarnings("unchecked")
+   PObject<T> potValue = (PObject<T>) e.getValue();
+   return povt.and(ttKey, get(potValue));
+ }
+
+ void setPObjectTuple(PTransform<?,?> transform, PObjectValueTuple outputValues) {
PObjectTuple pot = (PObjectTuple) pipeline.getOutput(transform);
for (Map.Entry<TupleTag<?>, PObject<?>> e : pot.getAll().entrySet()) {
pobjects.put(e.getValue(), outputValues.get(e.getKey()));
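The new `and` helper above is the standard generic-method idiom for wildcard capture: the key and value of a `Map.Entry<TupleTag<?>, PObject<?>>` have formally unrelated wildcard types, so `povt.and(key, get(value))` cannot type-check directly; a private method with its own type variable `T` plus two narrowly scoped unchecked casts asserts that the tag and object belong together. A stripped-down sketch with hypothetical stand-in types:

```java
import java.util.Map;

// Hypothetical stand-ins for TupleTag and PObject.
class Tag<T> {}
class Obj<T> {}

class CaptureSketch {
  // Unifies the entry's key and value under one type variable T. The compiler
  // cannot prove the two wildcards match, hence the two annotated casts.
  static <T> void accept(Map.Entry<Tag<?>, Obj<?>> e) {
    @SuppressWarnings("unchecked")
    Tag<T> tag = (Tag<T>) e.getKey();
    @SuppressWarnings("unchecked")
    Obj<T> obj = (Obj<T>) e.getValue();
    consume(tag, obj); // a call requiring matching type arguments now compiles
  }

  static <T> void consume(Tag<T> tag, Obj<T> obj) {}
}
```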
@@ -50,7 +50,7 @@ class MultiDoFnFunction<I, O> implements PairFlatMapFunction<Iterator<I>, TupleT

private final DoFn<I, O> mFunction;
private final SparkRuntimeContext mRuntimeContext;
- private final TupleTag<?> mMainOutputTag;
+ private final TupleTag<O> mMainOutputTag;
private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;

MultiDoFnFunction(
@@ -61,12 +61,12 @@ class MultiDoFnFunction<I, O> implements PairFlatMapFunction<Iterator<I>, TupleT
this.mFunction = fn;
this.mRuntimeContext = runtimeContext;
this.mMainOutputTag = mainOutputTag;
-   this. mSideInputs = sideInputs;
+   this.mSideInputs = sideInputs;
}

@Override
public Iterable<Tuple2<TupleTag<?>, Object>> call(Iterator<I> iter) throws Exception {
-   ProcCtxt<I, O> ctxt = new ProcCtxt(mFunction);
+   ProcCtxt ctxt = new ProcCtxt(mFunction);
mFunction.startBundle(ctxt);
while (iter.hasNext()) {
ctxt.element = iter.next();
@@ -82,7 +82,7 @@ public Tuple2<TupleTag<?>, Object> apply(Map.Entry<TupleTag<?>, Object> input) {
});
}

- private class ProcCtxt<I, O> extends DoFn<I, O>.ProcessContext {
+ private class ProcCtxt extends DoFn<I, O>.ProcessContext {

private final Multimap<TupleTag<?>, Object> outputs = LinkedListMultimap.create();
private I element;
@@ -98,7 +98,9 @@ public PipelineOptions getPipelineOptions() {

@Override
public <T> T sideInput(PCollectionView<T, ?> view) {
-     return (T) mSideInputs.get(view.getTagInternal()).getValue();
+     @SuppressWarnings("unchecked")
+     T value = (T) mSideInputs.get(view.getTagInternal()).getValue();
+     return value;
}

@Override
@@ -107,8 +107,13 @@ public void leaveCompositeTransform(TransformTreeNode node) {

@Override
public void visitTransform(TransformTreeNode node) {
-   PTransform<?, ?> transform = node.getTransform();
-   TransformEvaluator evaluator = TransformTranslator.getTransformEvaluator(transform.getClass());
+   doVisitTransform(node.getTransform());
+ }
+
+ private <PT extends PTransform> void doVisitTransform(PT transform) {
+   @SuppressWarnings("unchecked")
+   TransformEvaluator<PT> evaluator = (TransformEvaluator<PT>)
+       TransformTranslator.getTransformEvaluator(transform.getClass());
LOG.info("Evaluating " + transform);
evaluator.evaluate(transform, ctxt);
}
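This refactoring uses the same generic-method idiom as the `and` helper sketched earlier: `visitTransform` only has a wildcard-typed transform, so extracting `doVisitTransform` introduces the type variable `PT`, letting the (unchecked) cast to `TransformEvaluator<PT>` and the `evaluator.evaluate(transform, ctxt)` call agree on a single type.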
@@ -42,7 +42,7 @@ class SparkRuntimeContext implements Serializable {
/**
* Map of names to dataflow aggregators.
*/
- private final Map<String, Aggregator> aggregators = new HashMap<>();
+ private final Map<String, Aggregator<?>> aggregators = new HashMap<>();

public SparkRuntimeContext(JavaSparkContext jsc, Pipeline pipeline) {
this.accum = jsc.accumulator(new NamedAggregators(), new AggAccumParam());
@@ -77,12 +77,13 @@ public synchronized PipelineOptions getPipelineOptions() {
public synchronized <In, Out> Aggregator<In> createAggregator(
String named,
SerializableFunction<Iterable<In>, Out> sfunc) {
-   Aggregator aggregator = aggregators.get(named);
+   @SuppressWarnings("unchecked")
+   Aggregator<In> aggregator = (Aggregator<In>) aggregators.get(named);
if (aggregator == null) {
NamedAggregators.SerFunctionState<In, Out> state = new NamedAggregators
.SerFunctionState<>(sfunc);
accum.add(new NamedAggregators(named, state));
-     aggregator = new SparkAggregator(state);
+     aggregator = new SparkAggregator<>(state);
aggregators.put(named, aggregator);
}
return aggregator;
@@ -100,12 +101,14 @@ public synchronized <In, Out> Aggregator<In> createAggregator(
public synchronized <In, Inter, Out> Aggregator<In> createAggregator(
String named,
Combine.CombineFn<? super In, Inter, Out> combineFn) {
-   Aggregator aggregator = aggregators.get(named);
+   @SuppressWarnings("unchecked")
+   Aggregator<In> aggregator = (Aggregator<In>) aggregators.get(named);
if (aggregator == null) {
-     NamedAggregators.CombineFunctionState<? super In, Inter, Out> state = new NamedAggregators
-         .CombineFunctionState<>(combineFn);
+     @SuppressWarnings("unchecked")
+     NamedAggregators.CombineFunctionState<In, Inter, Out> state = new NamedAggregators
+         .CombineFunctionState<>((Combine.CombineFn<In, Inter, Out>) combineFn);
accum.add(new NamedAggregators(named, state));
-     aggregator = new SparkAggregator(state);
+     aggregator = new SparkAggregator<>(state);
aggregators.put(named, aggregator);
}
return aggregator;
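Note that `(Aggregator<In>) aggregators.get(named)` remains an unchecked cast by necessity: the map holds aggregators of mixed input types keyed only by name, so correctness relies on each name being used consistently with a single input type.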
(Diffs for the remaining two changed files are not shown here.)
