org.apache.beam
beam-sdks-java-extensions-sorter
diff --git a/sdks/java/extensions/pom.xml b/sdks/java/extensions/pom.xml
index 1222476ec904f..9bad6f598fe88 100644
--- a/sdks/java/extensions/pom.xml
+++ b/sdks/java/extensions/pom.xml
@@ -36,6 +36,7 @@
jackson
join-library
protobuf
+ sketching
sorter
diff --git a/sdks/java/extensions/sketching/pom.xml b/sdks/java/extensions/sketching/pom.xml
new file mode 100644
index 0000000000000..d39cc650ee302
--- /dev/null
+++ b/sdks/java/extensions/sketching/pom.xml
@@ -0,0 +1,102 @@
+
+
+
+
+ 4.0.0
+
+
+ org.apache.beam
+ beam-sdks-java-extensions-parent
+ 2.2.0-SNAPSHOT
+ ../pom.xml
+
+
+ beam-sdks-java-extensions-sketching
+ Apache Beam :: SDKs :: Java :: Extensions :: Sketching
+
+
+ 2.9.5
+ 3.1
+ 3.2
+ 2.2.0
+
+
+
+
+ org.apache.beam
+ beam-sdks-java-core
+
+
+
+ com.clearspring.analytics
+ stream
+ ${streamlib.version}
+
+
+
+ com.tdunning
+ t-digest
+ ${t-digest.version}
+
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
+ org.apache.beam
+ beam-sdks-java-core
+ tests
+ test
+
+
+
+ org.apache.commons
+ commons-lang3
+ test
+
+
+
+ org.apache.commons
+ commons-math3
+ ${commons-math3.version}
+ test
+
+
+
+ org.apache.beam
+ beam-runners-direct-java
+ test
+
+
+
+ org.hamcrest
+ hamcrest-all
+ test
+
+
+
+ junit
+ junit
+ test
+
+
+
+
diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/cardinality/ApproximateDistinct.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/cardinality/ApproximateDistinct.java
new file mode 100644
index 0000000000000..45bfcbc8765d3
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/cardinality/ApproximateDistinct.java
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching.cardinality;
+
+import com.clearspring.analytics.stream.cardinality.CardinalityMergeException;
+import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@code PTransform}s for computing the number of distinct elements in a {@code PCollection}, or
+ * the number of distinct values associated with each key in a {@code PCollection} of {@code KV}s.
+ *
+ * This class uses the HyperLogLog algorithm, and more precisely
+ * the improved version of Google (HyperLogLog+).
+ *
+ *
The implementation comes from Addthis' library Stream-lib :
+ * https://github.com/addthis/stream-lib
+ *
+ *
The original paper of the HyperLogLog is available here :
+ * http://algo.inria.fr/flajolet/Publications/FlFuGaMe07.pdf
+ *
+ *
A paper from the same authors to have a clearer view of the algorithm is available here :
+ * http://cscubs.cs.uni-bonn.de/2016/proceedings/paper-03.pdf
+ *
+ *
Google's HyperLogLog+ version is detailed in this paper :
+ * https://research.google.com/pubs/pub40671.html
+ */
+public class ApproximateDistinct {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ApproximateDistinct.class); // NOTE(review): no live code visible in this class references LOG
+
+ // Holder for static factories and nested classes only; do not instantiate.
+ private ApproximateDistinct() {
+ }
+
+ /**
+ * A {@code PTransform} that takes an input {@code PCollection} of objects and returns a
+ * {@code PCollection} whose contents is a sketch which approximates
+ * the number of distinct element in the input {@code PCollection}.
+ *
+ * The parameter {@code p} controls the accuracy of the estimation. It represents
+ * the number of bits that will be used to index the elements,
+ * thus the number of different "buckets" in the HyperLogLog+ sketch.
+ *
In general, you can expect a relative error of about :
+ *
{@code 1.1 / sqrt(2^p)}
+ * For instance, the estimation {@code ApproximateDistinct.globally(12)}
+ * will have a relative error of about 2%.
+ *
Also keep in mind that {@code p} cannot be lower than 4,
+ * because the estimation would be too inaccurate.
+ *
See {@link ApproximateDistinctFn} for more details about the algorithm's principle.
+ *
+ * HyperLogLog+ version of Google uses a sparse representation in order to
+ * optimize memory and improve accuracy for small cardinalities.
+ * By calling this builder, you will not use the sparse representation.
+ * If you want to, see {@link ApproximateDistinctFn#withSparseRepresentation(int)}
+ *
+ *
Example of use
+ *
{@code PCollection input = ...;
+ * PCollection hllSketch = input
+ * .apply(ApproximateDistinct.globally(15));
+ * }
+ *
+ * @param type of elements being combined
+ * @param p number of bits for indexes in the HyperLogLogPlus
+ *
+ *
+ */
+ public static Combine.Globally globally(int p) {
+ return Combine.globally(ApproximateDistinctFn.create(p)); // create(p) throws IllegalArgumentException when p < 4
+ }
+
+ /**
+ * Do the same as {@link ApproximateDistinct#globally(int)},
+ * but with a default value of 18 for p.
+ *
+ * @param the type of the elements in the input {@code PCollection}
+ */
+ public static Combine.Globally globally() {
+ return globally(18); // default precision: p = 18
+ }
+
+ /**
+ * A {@code PTransform} that takes an input {@code PCollection>} and returns a
+ * {@code PCollection>} that contains an output element mapping each
+ * distinct key in the input {@code PCollection} to a structure wrapping a {@link HyperLogLogPlus}
+ * which approximates the number of distinct values associated with that key in the input
+ * {@code PCollection}.
+ *
+ * The parameter {@code p} controls the accuracy of the estimation. It represents
+ * the number of bits that will be used to index the elements,
+ * thus the number of different "buckets" in the HyperLogLog+ sketch.
+ *
In general, you can expect a relative error of about :
+ *
{@code 1.1 / sqrt(2^p)}
+ * For instance, the estimation {@code ApproximateDistinct.perKey(12)}
+ * will have a relative error of about 2%.
+ *
Also keep in mind that {@code p} cannot be lower than 4,
+ * because the estimation would be too inaccurate.
+ *
See {@link ApproximateDistinctFn} for more details about the algorithm's principle.
+ *
+ * HyperLogLog+ version of Google uses a sparse representation in order to
+ * optimize memory and improve accuracy for small cardinalities.
+ * By calling this builder, you will not use the sparse representation.
+ * If you want to, see {@link ApproximateDistinctFn#withSparseRepresentation(int)}
+ *
+ *
Example of use
+ *
{@code PCollection> input = ...;
+ * PCollection> hllSketch = input
+ * .apply(ApproximateDistinct.perKey(15));
+ * }
+ *
+ * @param p number of bits for indexes in the HyperLogLogPlus.
+ *
+ */
+ public static Combine.PerKey perKey(int p) {
+ return Combine.perKey(ApproximateDistinctFn.create(p)); // create(p) throws IllegalArgumentException when p < 4
+ }
+
+ /**
+ * Do the same as {@link ApproximateDistinct#perKey(int)},
+ * but with a default value of 18 for p.
+ *
+ * @param the type of the keys in the input and output {@code PCollection}s
+ * @param the type of values in the input {@code PCollection}
+ */
+ public static Combine.PerKey perKey() {
+ return perKey(18); // default precision: p = 18
+ }
+
+ /**
+ * A {@code Combine.CombineFn} that computes the stream into a {@link HyperLogLogPlus}
+ * sketch, useful as an argument to {@link Combine#globally} or {@link Combine#perKey}.
+ *
+ * The HyperLogLog algorithm relies on the principle that the overall cardinality
+ * can be estimated thanks to the longest run of starting 0s of the hashed elements.
+ * The longer the run is, the more unlikely it was to happen. Thus, the greater the number
+ * of elements that have been hashed before getting such a run.
+ *
Because this algorithm relies mainly on randomness, the stream is divided into buckets
+ * in order to reduce the variance of the estimation.
+ * Therefore, an estimation is applied on several samples and the overall estimation
+ * is then computed using an average (Harmonic mean).
+ *
+ * @param the type of the elements being combined
+ */
+ public static class ApproximateDistinctFn
+ extends Combine.CombineFn {
+
+ private final int p; // precision of the normal representation; first argument given to HyperLogLogPlus
+
+ private final int sp; // precision of the sparse representation; create(int) passes 0 here
+
+ private ApproximateDistinctFn(int p, int sp) { // private: instances are obtained via create(int) and withSparseRepresentation(int)
+ this.p = p;
+ this.sp = sp;
+ }
+
+ /**
+ * Returns an {@code ApproximateDistinctFn} combiner with the given precision value p.
+ * This means that the input elements will be dispatched into 2^p buckets
+ * in order to estimate the cardinality.
+ *
+ * @param p precision value for the normal representation
+ * @param the type of the input {@code Pcollection}'s elements being combined.
+ */
+ public static ApproximateDistinctFn create(int p) {
+ // p == 4 is the smallest accepted precision, so the message states the bound as inclusive.
+ if (p < 4) {
+ throw new IllegalArgumentException("p must be greater than or equal to 4, but was " + p);
+ }
+ // sp is set to 0 here; a sparse precision is only set through withSparseRepresentation(int).
+ return new ApproximateDistinctFn<>(p, 0);
+ }
+
+ /**
+ * Returns an {@code ApproximateDistinctFn} combiner with the precision value p
+ * of this combiner and the given precision value sp for the sparse representation,
+ * meaning that the combiner will be using a sparse representation at small cardinalities.
+ *
+ * The sparse representation does not initialize every bucket to 0 at the beginning.
+ * Indeed, for small cardinalities a lot of them will remain empty.
+ * Instead it builds a linked-list that grows when new indexes appear in the hashed values.
+ * To reduce collision of indexes and thus improve precision, we can define {@code sp > p}
+ * When the sparse representation would require more memory than the normal one,
+ * it is converted and the normal algorithm applies for the remaining elements.
+ *
+ *
WARNING : Choose sp such that {@code p <= sp <= 32}
+ *
+ *
Example of use :
+ *
{@code PCollection input = ...;
+ * PCollection hllSketch = input
+ * .apply(Combine.globally(ApproximateDistinct.ApproximateDistinctFn.create(p)
+ * .withSparseRepresentation(sp)));
+ * }
+ *
+ * @param sp the precision of HyperLogLog+' sparse representation
+ */
+ public ApproximateDistinctFn withSparseRepresentation(int sp) {
+ // Both bounds are inclusive: sp == p and sp == 32 are accepted by this check.
+ if (sp < p || sp > 32) {
+ throw new IllegalArgumentException(
+ "sp must satisfy p <= sp <= 32, but was sp=" + sp + " with p=" + p);
+ }
+ // Returns a new combiner; this instance is left unchanged.
+ return new ApproximateDistinctFn<>(this.p, sp);
+ }
+
+ @Override
+ public HyperLogLogPlus createAccumulator() {
+ return new HyperLogLogPlus(p, sp); // new sketch configured with this combiner's p and sp
+ }
+
+ @Override
+ public HyperLogLogPlus addInput(HyperLogLogPlus acc, InputT record) {
+ acc.offer(record); // feeds the element to the sketch
+ return acc; // the same accumulator instance that was passed in
+ }
+
+ /**
+ * Output the whole structure so it can be queried, reused or stored easily.
+ */
+ @Override
+ public HyperLogLogPlus extractOutput(HyperLogLogPlus accumulator) {
+ return accumulator; // the sketch itself is the output; no cardinality is computed here
+ }
+
+ @Override
+ public HyperLogLogPlus mergeAccumulators(Iterable accumulators) {
+ // Every input sketch is folded into a newly created one.
+ HyperLogLogPlus mergedAccum = createAccumulator();
+ for (HyperLogLogPlus accum : accumulators) {
+ try {
+ mergedAccum.addAll(accum);
+ } catch (CardinalityMergeException e) {
+ // Should never happen because only HyperLogLogPlus accumulators are instantiated.
+ // The original exception is attached as the cause so its stack trace is preserved.
+ throw new IllegalStateException(
+ "The accumulators cannot be merged : " + e.getMessage(), e);
+ }
+ }
+ return mergedAccum;
+ }
+
+ @Override
+ public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) {
+ return HyperLogLogPlusCoder.of(); // registry and inputCoder are not consulted
+ }
+
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ super.populateDisplayData(builder); // keep whatever the parent class registers
+ builder
+ .add(DisplayData.item("p", p)
+ .withLabel("precision"))
+ .add(DisplayData.item("sp", sp)
+ .withLabel("sparse representation precision")); // sp is 0 when the combiner came straight from create(int)
+ }
+ }
+
+ static class HyperLogLogPlusCoder extends CustomCoder { // coder for HyperLogLogPlus sketches; delegates the byte transport to ByteArrayCoder
+
+ private static final HyperLogLogPlusCoder INSTANCE = new HyperLogLogPlusCoder(); // single shared instance returned by of()
+
+ private static final ByteArrayCoder BYTE_ARRAY_CODER = ByteArrayCoder.of();
+
+ public static HyperLogLogPlusCoder of() {
+ return INSTANCE;
+ }
+
+ @Override public void encode(HyperLogLogPlus value, OutputStream outStream) throws IOException {
+ if (value == null) { // null sketches are rejected explicitly with a CoderException
+ throw new CoderException("cannot encode a null HyperLogLogPlus sketch");
+ }
+ BYTE_ARRAY_CODER.encode(value.getBytes(), outStream); // writes the sketch's getBytes() form
+ }
+
+ @Override public HyperLogLogPlus decode(InputStream inStream) throws IOException {
+ return HyperLogLogPlus.Builder.build(BYTE_ARRAY_CODER.decode(inStream)); // rebuilds the sketch from the bytes written by encode
+ }
+
+ @Override public boolean isRegisterByteSizeObserverCheap(HyperLogLogPlus value) {
+ return true; // the size below comes from sizeof(), not from encoding the value
+ }
+
+ @Override protected long getEncodedElementByteSize(HyperLogLogPlus value) throws IOException {
+ if (value == null) {
+ throw new CoderException("cannot encode a null HyperLogLogPlus sketch");
+ }
+ return value.sizeof(); // NOTE(review): sizeof() may differ from the number of bytes encode() writes via ByteArrayCoder — confirm
+ }
+ }
+
+ /**
+ * Computes the precision based on the desired relative error.
+ *
+ * According to the paper, the mean squared error is bounded by the following formula :
+ *
b(m) / sqrt(m)
+ * Where m is the number of buckets used (p = log2(m))
+ * and b(m) < 1.106 for m > 16 (p > 4).
+ *
+ *
+ *
WARNING :
+ *
This does not mean relative error in the estimation can't be higher.
+ *
This only means that on average the relative error will be
+ * lower than the desired relative error.
+ *
Nevertheless, the more elements arrive in the {@code PCollection}, the lower
+ * the variation will be.
+ *
Indeed, this is like rolling a die millions of times :
+ * The relative frequency of each different result {1,2,3,4,5,6} will get closer to 1/6.
+ *
+ * @param relativeError the mean squared error should be in the interval ]0,1]
+ * @return the minimum precision p in order to have the desired relative error on average.
+ */
+ static long precisionForRelativeError(double relativeError) {
+ // Ratio (1.106 / relativeError)^2, i.e. the bucket count m solving 1.106 / sqrt(m) = relativeError.
+ double buckets = Math.pow(1.106, 2.0) / Math.pow(relativeError, 2.0);
+ // p = log2(m), computed through natural logarithms.
+ double bits = Math.log(buckets) / Math.log(2);
+ // Round up to a whole number of bits and convert to long.
+ return Math.round(Math.ceil(bits));
+ }
+
+ /**
+ * @param p the precision i.e. the number of bits used for indexing the buckets
+ * @return the Mean squared error of the Estimation of cardinality to expect
+ * for the given value of p.
+ */
+ static double mseForP(int p) {
+ if (p < 4) {
+ return 1.0;
+ }
+ double betaM;
+ switch(p) {
+ case 4 : betaM = 1.156;
+ break;
+ case 5 : betaM = 1.2;
+ break;
+ case 6 : betaM = 1.104;
+ break;
+ case 7 : betaM = 1.096;
+ break;
+ default : betaM = 1.05;
+ break;
+ }
+ return betaM / Math.sqrt(Math.exp(p * Math.log(2)));
+ }
+}
diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/cardinality/package-info.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/cardinality/package-info.java
new file mode 100644
index 0000000000000..c97fc929e6bc9
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/cardinality/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Utilities for estimating cardinality with data sketching.
+ */
+package org.apache.beam.sdk.extensions.sketching.cardinality;
diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/frequency/KMostFrequent.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/frequency/KMostFrequent.java
new file mode 100644
index 0000000000000..beeb5cc66a928
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/frequency/KMostFrequent.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching.frequency;
+
+import com.clearspring.analytics.stream.Counter;
+import com.clearspring.analytics.stream.StreamSummary;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@code PTransform}s for finding the k most frequent elements in a {@code PCollection}, or
+ * the k most frequent values associated with each key in a {@code PCollection} of {@code KV}s.
+ *
+ * This class uses the Space-Saving algorithm, introduced in this paper :
+ * https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf
+ *
The implementation comes from Addthis' library Stream-lib : https://github.com/addthis/stream-lib
+ */
+public class KMostFrequent {
+
+ private static final Logger LOG = LoggerFactory.getLogger(KMostFrequent.class);
+
+ // Holder for static factories and nested classes only; do not instantiate.
+ private KMostFrequent() {
+ }
+
+ /**
+ * A {@code PTransform} that takes a {@code PCollection} and returns a
+ * {@code PCollection>} whose contents is a sketch which contains
+ * the most frequent elements in the input {@code PCollection}.
+ *
+ * The {@code capacity} parameter controls the maximum number of elements the sketch
+ * can contain. Once this capacity is reached the least frequent element is dropped each
+ * time an incoming element is not already present in the sketch.
+ * Each element in the sketch is associated to a counter, that keeps track of the estimated
+ * frequency as well as the maximal potential error.
+ *
See {@link KMostFrequentFn} for more details.
+ *
+ *
Example of use
+ *
{@code PCollection input = ...;
+ * PCollection> ssSketch = input
+ * .apply(KMostFrequent.globally(10000));
+ * }
+ *
+ * @param capacity the maximum number of distinct elements that the Stream Summary can keep
+ * track of at the same time
+ * @param the type of the elements in the input {@code PCollection}
+ */
+ public static Combine.Globally> globally(int capacity) {
+ return Combine.>globally(KMostFrequentFn.create(capacity)); // create(capacity) throws IllegalArgumentException when capacity <= 0
+ }
+
+ /**
+ * A {@code PTransform} that takes an input {@code PCollection>} and returns a
+ * {@code PCollection>} that contains an output element mapping each
+ * distinct key in the input {@code PCollection} to a sketch which contains the most frequent
+ * values associated with that key in the input {@code PCollection}.
+ *
+ * The {@code capacity} parameter controls the maximum number of elements the sketch
+ * can contain. Once this capacity is reached the least frequent element is dropped each
+ * time an incoming element is not already present in the sketch.
+ * Each element in the sketch is associated to a counter, that keeps track of the estimated
+ * frequency as well as the maximal potential error.
+ *
See {@link KMostFrequentFn} for more details.
+ *
+ *
Example of use
+ *
{@code PCollection> input = ...;
+ * PCollection>> ssSketch = input
+ * .apply(KMostFrequent.perKey(10000));
+ * }
+ *
+ * @param capacity the maximum number of distinct elements that the Stream Summary can keep
+ * track of at the same time
+ * @param the type of the keys in the input and output {@code PCollection}s
+ * @param the type of values in the input {@code PCollection}
+ */
+ public static Combine.PerKey> perKey(int capacity) {
+ if (capacity < 1) { // NOTE(review): KMostFrequentFn.create performs the same check with a different message
+ throw new IllegalArgumentException("The capacity must be strictly positive");
+ }
+ return Combine.>perKey(KMostFrequentFn.create(capacity));
+ }
+
+ /**
+ * A {@code Combine.CombineFn} that computes the stream into a {@link StreamSummary}
+ * sketch, useful as an argument to {@link Combine#globally} or {@link Combine#perKey}.
+ *
+ * The Space-Saving algorithm summarizes the stream by using a doubly linked-list of buckets
+ * ordered by the frequency value they represent. Each of these buckets contains a linked-list
+ * of counters which estimate the {@code count} for an element as well as the maximum
+ * overestimation {@code e} associated to it. The frequency cannot be underestimated.
+ *
+ *
An element is guaranteed to be in the top K most frequent if its guaranteed number of hits,
+ * i.e. {@code count - e}, is greater than the count of the element at the position k+1.
+ *
+ * @param the type of the elements being combined
+ */
+ public static class KMostFrequentFn
+ extends Combine.CombineFn, StreamSummary> {
+
+ private int capacity;
+
+ private KMostFrequentFn(int capacity) {
+ this.capacity = capacity;
+ }
+
+ public static KMostFrequentFn create(int capacity) {
+ if (capacity <= 0) { // zero and negative capacities are rejected
+ throw new IllegalArgumentException("Capacity must be greater than 0.");
+ }
+ return new KMostFrequentFn<>(capacity);
+ }
+
+ @Override
+ public StreamSummary createAccumulator() {
+ return new StreamSummary<>(this.capacity); // new sketch sized with this combiner's capacity
+ }
+
+ @Override
+ public StreamSummary addInput(StreamSummary accumulator, T element) {
+ accumulator.offer(element, 1); // each input element counts for one occurrence
+ return accumulator; // the same accumulator instance that was passed in
+ }
+
+ @Override
+ public StreamSummary mergeAccumulators(
+ Iterable> accumulators) {
+ Iterator> it = accumulators.iterator();
+ if (it.hasNext()) {
+ StreamSummary mergedAccum = it.next();
+ while (it.hasNext()) {
+ StreamSummary other = it.next();
+ List> top = other.topK(capacity);
+ for (Counter counter : top) {
+ mergedAccum.offer(counter.getItem(), (int) counter.getCount());
+ }
+ }
+ return mergedAccum;
+ }
+ return null;
+ }
+
+ @Override
+ public StreamSummary extractOutput(StreamSummary accumulator) {
+ return accumulator; // the sketch itself is the output
+ }
+
+ @Override
+ public Coder> getAccumulatorCoder(CoderRegistry registry,
+ Coder inputCoder) {
+ return new StreamSummaryCoder<>(); // registry and inputCoder are not consulted
+ }
+
+ @Override
+ public Coder> getDefaultOutputCoder(CoderRegistry registry,
+ Coder inputCoder) {
+ return new StreamSummaryCoder<>(); // same coder type as the accumulator coder, since the output is the accumulator
+ }
+ }
+
+ // Coder for StreamSummary sketches: the sketch's toBytes() form is transported with ByteArrayCoder.
+ static class StreamSummaryCoder extends CustomCoder> {
+
+ private static final Coder BYTE_ARRAY_CODER = ByteArrayCoder.of();
+
+ // Writes the byte form returned by value.toBytes().
+ @Override
+ public void encode(StreamSummary value, OutputStream outStream) throws IOException {
+ BYTE_ARRAY_CODER.encode(value.toBytes(), outStream);
+ }
+
+ // Rebuilds the sketch from the bytes written by encode.
+ // Throws IOException when the bytes cannot be read or turned back into a sketch.
+ @Override
+ public StreamSummary decode(InputStream inStream) throws IOException {
+ try {
+ return new StreamSummary<>(BYTE_ARRAY_CODER.decode(inStream));
+ } catch (ClassNotFoundException e) {
+ // Fail the decode instead of logging and returning null: a null element would only
+ // surface later as a NullPointerException far from the real cause.
+ throw new IOException(
+ "The Stream Summary sketch can't be decoded from the input stream", e);
+ }
+ }
+ }
+}
diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/frequency/SketchFrequencies.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/frequency/SketchFrequencies.java
new file mode 100644
index 0000000000000..91268ad46170c
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/frequency/SketchFrequencies.java
@@ -0,0 +1,357 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching.frequency;
+
+import com.clearspring.analytics.stream.frequency.CountMinSketch;
+import com.clearspring.analytics.stream.frequency.FrequencyMergeException;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Iterator;
+
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@code PTransform}s that record an estimation of the frequency of each element in a
+ * {@code PCollection}, or the occurrences of values associated with each key in a
+ * {@code PCollection} of {@code KV}s.
+ *
+ * <p>This class uses the Count-min Sketch structure. The papers and other useful information
+ * about it are available on this website : https://sites.google.com/site/countminsketch/
+ *
+ * <p>The implementation comes from the stream-lib library (com.clearspring.analytics) :
+ * https://github.com/addthis/stream-lib
+ */
+public class SketchFrequencies {
+
+ private static final Logger LOG = LoggerFactory.getLogger(SketchFrequencies.class);
+
+ // Utility class exposing only static factory methods and nested types: do not instantiate.
+ private SketchFrequencies() {
+ }
+
+ /**
+ * A {@code PTransform} that takes an input {@code PCollection} and returns a
+ * {@code PCollection} whose contents is a Count-min sketch that allows to query
+ * the number of hits for a specific element in the input {@code PCollection}.
+ *
+ * <p>The {@code seed} parameter is used to randomly generate the hash functions.
+ * Thus, the results can differ for the same stream if different seeds are used.
+ * More precisely, the {@code seed} is used to generate a and b for each hash function.
+ *
The Count-min sketch size is constant through the process so the memory use is fixed.
+ * However, the dimensions are directly linked to the accuracy.
+ *
+ * <p>By default, the relative error is set to 0.1% with 1% probability that the estimation
+ * breaks this limit.
+ *
Also keep in mind that this algorithm works well on highly skewed data but gives poor
+ * results if the elements are evenly distributed.
+ *
+ *
See {@link CountMinSketchFn#withAccuracy(double, double)} in order to tune the parameters.
+ *
Also see {@link CountMinSketchFn} for more details about the algorithm's principle.
+ *
+ *
Example of use:
+ *
{@code
+ * PCollection pc = ...;
+ * PCollection countMinSketch =
+ * pc.apply(SketchFrequencies.globally(1234));
+ * }
+ *
+ * Also see {@link CountMinSketchFn} for more details about the algorithm's principle.
+ *
+ * @param seed the seed used for generating randomly different hash functions
+ */
+ public static Combine.Globally globally(int seed) {
+   // Default accuracy is pinned explicitly: epsilon = 0.001, confidence = 0.99.
+   CountMinSketchFn sketchFn = CountMinSketchFn.create(seed).withAccuracy(0.001, 0.99);
+   return Combine.globally(sketchFn);
+ }
+
+ /**
+ * A {@code PTransform} that takes an input {@code PCollection>} and
+ * returns a {@code PCollection>} that contains an output element mapping
+ * each distinct key in the input {@code PCollection} to a structure that allows to query the
+ * count of a specific element associated with that key in the input {@code PCollection}.
+ *
+ * <p>The {@code seed} parameter is used to randomly generate the hash functions.
+ * Thus, the results can differ for the same stream if different seeds are used.
+ * More precisely, the {@code seed} is used to generate a and b for each hash function.
+ *
The Count-min sketch size is constant through the process so the memory use is fixed.
+ * However, the dimensions are directly linked to the accuracy.
+ *
+ * <p>By default, the relative error is set to 0.1% with 1% probability that the estimation
+ * breaks this limit.
+ *
also keep in mind that this algorithm works well on highly skewed data but gives poor
+ * results if the elements are evenly distributed.
+ *
+ *
See {@link CountMinSketchFn#withAccuracy(double, double)} in order to tune the parameters.
+ *
Also see {@link CountMinSketchFn} for more details about the algorithm's principle.
+ *
+ *
Example of use:
+ *
{@code
+ * PCollection> pc = ...;
+ * PCollection> countMinSketch =
+ * pc.apply(SketchFrequencies.perKey(1234));
+ * }
+ *
+ * @param seed the seed used for generating different hash functions
+ * @param the type of the keys in the input and output {@code PCollection}s
+ */
+ public static Combine.PerKey perKey(int seed) {
+   // Default accuracy is pinned explicitly: epsilon = 0.001, confidence = 0.99.
+   CountMinSketchFn sketchFn = CountMinSketchFn.create(seed).withAccuracy(0.001, 0.99);
+   return Combine.perKey(sketchFn);
+ }
+
+ /**
+ * A {@code Combine.CombineFn} that computes the {@link CountMinSketch} Structure
+ * of an {@code Iterable} of Strings, useful as an argument to {@link Combine#globally} or
+ * {@link Combine#perKey}.
+ *
+ * When an element is added to the Count-min sketch, it is mapped to one column in each
+ * row using different hash functions, and a counter is updated in each column.
+ *
Collisions will happen as far as the number of distinct elements in the stream is greater
+ * than the width of the sketch. Each counter might be associated to many items so the frequency
+ * of an element is always overestimated. On average the relative error on a counter is bounded,
+ * but some counters can be very inaccurate.
+ *
That's why different hash functions are used to map the same element to different
+ * counters. Thus, the overestimation for each counter will differ as there will be different
+ * collisions, and one will probably be less inaccurate than the average.
+ *
+ *
Both the average relative error and the probability to have an estimation overcoming this
+ * error can be computed by knowing the dimensions of the sketch, and vice-versa.
+ * Thus, for a Count-min sketch with 10 000 columns and 7 rows, the relative error should be no
+ * more than 0.02% in 99% of the cases.
+ *
+ */
+ public static class CountMinSketchFn
+ extends Combine.CombineFn {
+
+ // Number of rows of the sketch, i.e. number of hash functions.
+ private final int depth;
+
+ // Number of columns of the sketch, i.e. number of counters in each row.
+ private final int width;
+
+ // Seed handed to the CountMinSketch constructor when an accumulator is created.
+ private final int seed;
+
+ private CountMinSketchFn(double eps, double confidence, int seed) {
+   // width = ceil(2 / eps) and depth = ceil(-ln(1 - confidence) / ln(2)), truncated to int.
+   final double columns = Math.ceil(2 / eps);
+   final double rows = Math.ceil(-Math.log(1 - confidence) / Math.log(2));
+   this.seed = seed;
+   this.width = (int) columns;
+   this.depth = (int) rows;
+ }
+
+ private CountMinSketchFn(int width, int depth, int seed) {
+   // Plain field initialisation: the dimensions are stored exactly as given.
+   this.seed = seed;
+   this.depth = depth;
+   this.width = width;
+ }
+
+ /**
+ * Returns a {@code CountMinSketchFn} combiner that will have a Count-min sketch
+ * which will estimate the frequencies with about 0.1% of error guaranteed at 99%.
+ * The resulting dimensions are 2000 x 7. They will stay constant during all the aggregation.
+ *
+ * the {@code seed} parameters is used to generate different hash functions of the form :
+ *
a * i + b % p % width ,
+ * where a, b are chosen randomly and p is a prime number larger than the maximum i value.
+ *
+ * Example of use:
+ *
1) Globally :
+ *
{@code
+ * PCollection pc = ...;
+ * PCollection countMinSketch =
+ * pc.apply(Combine.globally(CountMinSketchFn.create(1234));
+ * }
+ *
2) Per key :
+ * {@code
+ * PCollection> pc = ...;
+ * PCollection> countMinSketch =
+ * pc.apply(Combine.perKey(CountMinSketchFn.create(1234));
+ * }
+ *
+ * @param seed the seed used for generating different hash functions
+ */
+ public static CountMinSketchFn create(int seed) {
+   // Defaults used when the caller does not tune the accuracy.
+   final double defaultEpsilon = 0.001;
+   final double defaultConfidence = 0.99;
+   return new CountMinSketchFn(defaultEpsilon, defaultConfidence, seed);
+ }
+
+ /**
+ * Returns an {@code CountMinSketchFn} combiner that will have a Count-min sketch of
+ * dimensions {@code width x depth}, that will stay constant during all the aggregation.
+ * This method can only be applied from a {@link CountMinSketchFn} already created with the
+ * method {@link CountMinSketchFn#create(int)}.
+ *
+ * The greater the {@code width}, the lower the expected relative error {@code epsilon} :
+ *
{@code epsilon = 2 / width}
+ *
+ * The greater the {@code depth}, the lower the probability to actually have
+ * a greater relative error than expected.
+ *
{@code confidence = 1 - 2^-depth}
+ *
+ * Example of use:
+ *
1) Globally :
+ *
{@code
+ * PCollection pc = ...;
+ * PCollection countMinSketch =
+ * pc.apply(Combine.globally(CountMinSketchFn.create(1234)
+ * withDimensions(10000, 7));
+ * }
+ *
2) Per key :
+ * {@code
+ * PCollection> pc = ...;
+ * PCollection> countMinSketch =
+ * pc.apply(Combine.perKey(CountMinSketchFn.create(1234)
+ * withDimensions(10000, 7)););
+ * }
+ *
+ * @param width Number of columns, i.e. number of counters for the stream.
+ * @param depth Number of lines, i.e. number of hash functions
+ */
+ public CountMinSketchFn withDimensions(int width, int depth) {
+   final boolean dimensionsArePositive = width > 0 && depth > 0;
+   if (!dimensionsArePositive) {
+     throw new IllegalArgumentException("depth and width must be positive.");
+   }
+   // The seed is carried over from the current instance.
+   return new CountMinSketchFn(width, depth, seed);
+ }
+
+ /**
+ * Returns an {@code CountMinSketchFn} combiner that will be as accurate as specified. The
+ * relative error {@code epsilon} can be guaranteed only with a certain {@code confidence},
+ * which has to be between 0 and 1 (1 being of course impossible). Those parameters will
+ * determine the size of the Count-min sketch in which the elements will be aggregated.
+ * This method can only be applied to a {@link CountMinSketchFn} already created with the
+ * method {@link CountMinSketchFn#create(int)}.
+ *
+ * <p>The lower the {@code epsilon} value, the greater the width.
+ *
+ * <p>{@code width = ceil(2 / epsilon)}
+ *
+ * <p>The greater the confidence, the greater the depth.
+ *
+ * <p>{@code depth = ceil(-log2(1 - confidence))}
+ *
+ * Example of use:
+ *
1) Globally :
+ *
{@code
+ * PCollection pc = ...;
+ * PCollection countMinSketch =
+ * pc.apply(Combine.globally(CountMinSketchFn.create(1234)
+ * .withAccuracy(0.001, 0.99)));
+ * }
+ *
2) Per key :
+ * {@code
+ * PCollection> pc = ...;
+ * PCollection> countMinSketch =
+ * pc.apply(Combine.perKey(CountMinSketchFn.create(1234)
+ * withAccuracy(0.001, 0.99)););
+ * }
+ *
+ *
+ * @param epsilon the relative error of the result
+ * @param confidence the confidence in the result to not overcome the relative error
+ */
+ public CountMinSketchFn withAccuracy(double epsilon, double confidence) {
+   // The dimensions are derived as ceil(2 / epsilon) and ceil(-log2(1 - confidence)): reject the
+   // inputs for which those are not positive, finite values. The negated comparisons also
+   // reject NaN. Mirrors the argument check done by withDimensions(int, int).
+   if (!(epsilon > 0) || Double.isInfinite(epsilon)) {
+     throw new IllegalArgumentException(
+         "epsilon must be a positive, finite number but was " + epsilon);
+   }
+   if (!(confidence > 0 && confidence < 1)) {
+     throw new IllegalArgumentException(
+         "confidence must be strictly between 0 and 1 but was " + confidence);
+   }
+   return new CountMinSketchFn(epsilon, confidence, this.seed);
+ }
+
+ @Override public CountMinSketch createAccumulator() {
+   // Empty sketch built from this combiner's settings, passed as (depth, width, seed).
+   CountMinSketch emptySketch = new CountMinSketch(depth, width, seed);
+   return emptySketch;
+ }
+
+ /**
+ * Records one occurrence of {@code element} in the given sketch and returns that same sketch.
+ */
+ @Override public CountMinSketch addInput(CountMinSketch accumulator, String element) {
+ // Each input element counts for exactly one hit.
+ accumulator.add(element, 1);
+ return accumulator;
+ }
+
+ /**
+  * Merges the given sketches, pair by pair, into a single sketch.
+  *
+  * <p>When there is nothing to merge, an empty sketch with this combiner's dimensions and seed
+  * is returned.
+  *
+  * @throws IllegalStateException if two sketches cannot be merged; the original
+  *     {@link FrequencyMergeException} is kept as the cause
+  */
+ @Override public CountMinSketch mergeAccumulators(Iterable<CountMinSketch> accumulators) {
+   Iterator<CountMinSketch> it = accumulators.iterator();
+   if (!it.hasNext()) {
+     // Delegating avoids repeating the (depth, width, seed) argument list, which was previously
+     // passed here in the wrong order.
+     return createAccumulator();
+   }
+   CountMinSketch merged = it.next();
+   try {
+     while (it.hasNext()) {
+       merged = CountMinSketch.merge(merged, it.next());
+     }
+   } catch (FrequencyMergeException e) {
+     // All the accumulators are expected to come from this combiner, so this signals corrupted
+     // or foreign sketches. Returning the partial result would silently under-count.
+     throw new IllegalStateException("Count-min sketches could not be merged", e);
+   }
+   return merged;
+ }
+
+ /**
+ * Returns the accumulator unchanged: the sketch itself is the output of this combiner.
+ */
+ @Override public CountMinSketch extractOutput(CountMinSketch accumulator) {
+ return accumulator;
+ }
+
+ /**
+ * Returns a new {@code CountMinSketchCoder} for the accumulators; neither the registry nor the
+ * input coder is consulted.
+ */
+ @Override public Coder getAccumulatorCoder(CoderRegistry registry,
+ Coder inputCoder) {
+ return new CountMinSketchCoder();
+ }
+
+ /**
+ * Returns a new {@code CountMinSketchCoder} for the output, i.e. the same kind of coder as the
+ * one used for the accumulators; neither the registry nor the input coder is consulted.
+ */
+ @Override public Coder getDefaultOutputCoder(CoderRegistry registry,
+ Coder inputCoder) throws CannotProvideCoderException {
+ return new CountMinSketchCoder();
+ }
+
+ /**
+  * Returns the value used when there is no input: an empty sketch that has the dimensions and
+  * the seed configured on this combiner.
+  */
+ @Override public CountMinSketch defaultValue() {
+   // A hard-coded 1 x 1 sketch would not match the dimensions and seed of the sketches this
+   // combiner produces for non-empty inputs.
+   return createAccumulator();
+ }
+ }
+
+ static class CountMinSketchCoder extends CustomCoder {
+
+   private static final ByteArrayCoder BYTE_ARRAY_CODER = ByteArrayCoder.of();
+
+   /** Serializes the sketch and writes the resulting byte array; a null sketch is rejected. */
+   @Override public void encode(CountMinSketch value, OutputStream outStream) throws IOException {
+     if (value == null) {
+       throw new CoderException("cannot encode a null Count-min Sketch");
+     }
+     byte[] serialized = CountMinSketch.serialize(value);
+     BYTE_ARRAY_CODER.encode(serialized, outStream);
+   }
+
+   /** Reads a byte array from the stream and deserializes it into a sketch. */
+   @Override public CountMinSketch decode(InputStream inStream) throws IOException {
+     byte[] serialized = BYTE_ARRAY_CODER.decode(inStream);
+     return CountMinSketch.deserialize(serialized);
+   }
+
+   @Override public boolean consistentWithEquals() {
+     return false;
+   }
+
+   @Override public boolean isRegisterByteSizeObserverCheap(CountMinSketch value) {
+     return true;
+   }
+
+   /**
+    * Computes the size from the sketch dimensions, which are themselves recomputed from the
+    * relative error and the confidence reported by the sketch.
+    */
+   @Override protected long getEncodedElementByteSize(CountMinSketch value) throws IOException {
+     if (value == null) {
+       throw new CoderException("cannot encode a null Count-min Sketch");
+     }
+     int columns = (int) Math.ceil(2 / value.getRelativeError());
+     int rows = (int) Math.ceil(-Math.log(1 - value.getConfidence()) / Math.log(2));
+
+     // 8 bytes for the sketch's size (long) plus 2 x 4 bytes for depth and width (ints).
+     long headerBytes = 8L + 4L * 2;
+     // table is long[depth][width] and hashA is long[depth], hence depth * (width + 1) longs.
+     long tableAndHashBytes = 8L * rows * (columns + 1);
+     return headerBytes + tableAndHashBytes;
+   }
+ }
+}
diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/frequency/package-info.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/frequency/package-info.java
new file mode 100644
index 0000000000000..bcdfc4d819b56
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/frequency/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Utilities for estimating data frequencies with data sketching.
+ */
+package org.apache.beam.sdk.extensions.sketching.frequency;
diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/package-info.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/package-info.java
new file mode 100644
index 0000000000000..481d86ff8d5d6
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Utilities for estimating data sketches with Beam.
+ */
+package org.apache.beam.sdk.extensions.sketching;
diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/quantiles/TDigestQuantiles.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/quantiles/TDigestQuantiles.java
new file mode 100644
index 0000000000000..6dcca660fe7a0
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/quantiles/TDigestQuantiles.java
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching.quantiles;
+
+import com.tdunning.math.stats.MergingDigest;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.transforms.Combine;
+
+/**
+ * {@code PTransform}s for getting information about quantiles in the input {@code PCollection},
+ * or about the quantiles of the values associated with each key in a {@code PCollection} of
+ * {@code KV}s.
+ *
+ * This class uses the T-Digest structure, an improvement of Q-Digest made by Ted Dunning.
+ * The paper and implementation are available on his Github profile :
+ * https://github.com/tdunning/t-digest
+ *
+ *
For Your Information :
+ *
The current release of t-digest (3.1) has non-serializable implementations. This problem
+ * has been issued and corrected on the master. The new release should be available soon.
+ *
Until then a wrapper is used, see {@link SerializableTDigest}.
+ */
+public class TDigestQuantiles {
+
+ // do not instantiate
+ private TDigestQuantiles() {
+ }
+
+ /**
+ * A {@code PTransform} that takes an input {@code PCollection} and returns a
+ * {@code PCollection} whose contents is a TDigest sketch for querying
+ * the quantiles of the set of input {@code PCollection}'s elements.
+ *
+ * The compression factor controls the accuracy of the queries. For a compression equal to C
+ * the relative error will be at most 3/C.
+ *
+ *
Example of use :
+ *
{@code PCollection input = ...
+ * PCollection sketch = input.apply(TDigestQuantiles.globally(1000));
+ * }
+ *
+ * @param compression the compression factor guarantees a relative error of at most
+ * {@code 3 / compression} on quantiles.
+ */
+ public static Combine.Globally globally(int compression) {
+   // The factory validates the compression factor before the combiner is wrapped.
+   TDigestQuantilesFn quantilesFn = TDigestQuantilesFn.create(compression);
+   return Combine.globally(quantilesFn);
+ }
+
+ /**
+ * A {@code PTransform} that takes an input {@code PCollection>} and returns a
+ * {@code PCollection>} mapping each distinct key in the input
+ * {@code PCollection} to the TDigest sketch for querying the quantiles of the set of
+ * elements associated with that key in the input {@code PCollection}.
+ *
+ * The compression factor controls the accuracy of the queries. For a compression equal to C
+ * the relative error will be at most 3/C.
+ *
+ *
Example of use :
+ *
{@code PCollection> input = ...
+ * PCollection> sketch = input
+ * .apply(TDigestQuantiles.perKey(1000));
+ * }
+ *
+ * @param compression the compression factor guarantees a relative error of at most
+ * {@code 3 / compression} on quantiles.
+ * @param the type of the keys
+ */
+ public static Combine.PerKey perKey(int compression) {
+   // The factory validates the compression factor before the combiner is wrapped.
+   TDigestQuantilesFn quantilesFn = TDigestQuantilesFn.create(compression);
+   return Combine.perKey(quantilesFn);
+ }
+
+ /**
+ * A {@code Combine.CombineFn} that computes the {@link SerializableTDigest} structure
+ * of an {@code Iterable} of Doubles, useful as an argument to {@link Combine#globally} or
+ * {@link Combine#perKey}.
+ */
+ public static class TDigestQuantilesFn
+     extends Combine.CombineFn<Double, SerializableTDigest, SerializableTDigest> {
+
+   // Compression factor given to every sketch created by this combiner.
+   private final int compression;
+
+   private TDigestQuantilesFn(int compression) {
+     this.compression = compression;
+   }
+
+   /**
+    * Returns a {@code TDigestQuantilesFn} whose sketches use the given compression factor.
+    *
+    * @param compression the compression factor, must be greater than 0
+    * @throws IllegalArgumentException if {@code compression} is not greater than 0
+    */
+   public static TDigestQuantilesFn create(int compression) {
+     if (compression > 0) {
+       return new TDigestQuantilesFn(compression);
+     }
+     throw new IllegalArgumentException("Compression factor should be greater than 0.");
+   }
+
+   /** Creates an empty sketch with the configured compression factor. */
+   @Override public SerializableTDigest createAccumulator() {
+     return new SerializableTDigest(compression);
+   }
+
+   /** Adds {@code value} to the sketch and returns that same sketch. */
+   @Override public SerializableTDigest addInput(SerializableTDigest accum, Double value) {
+     accum.add(value);
+     return accum;
+   }
+
+   /** Returns the accumulator unchanged: the sketch itself is the output. */
+   @Override public SerializableTDigest extractOutput(SerializableTDigest accum) {
+     return accum;
+   }
+
+   /**
+    * Merges the sketches with {@link SerializableTDigest#merge(Iterable)}; when there is nothing
+    * to merge an empty sketch is returned instead of {@code null}.
+    */
+   @Override public SerializableTDigest mergeAccumulators(
+       Iterable<SerializableTDigest> accumulators) {
+     SerializableTDigest merged = SerializableTDigest.merge(accumulators);
+     return merged != null ? merged : createAccumulator();
+   }
+
+   @Override public Coder<SerializableTDigest> getAccumulatorCoder(CoderRegistry registry,
+       Coder<Double> inputCoder) {
+     return SerializableTDigestCoder.of();
+   }
+
+   @Override public Coder<SerializableTDigest> getDefaultOutputCoder(CoderRegistry registry,
+       Coder<Double> inputCoder) {
+     return SerializableTDigestCoder.of();
+   }
+
+   /**
+    * Returns the value used when there is no input: an empty sketch with the configured
+    * compression factor (and not a hard-coded one).
+    */
+   @Override public SerializableTDigest defaultValue() {
+     return createAccumulator();
+   }
+ }
+
+ /**
+ * This class is a wrapper for MergingDigest class because it is not serializable.
+ * The problem has been issued and corrected on 3.2 version of Ted Dunning's implementation :
+ * https://github.com/tdunning/t-digest
+ * However, this version has not been released yet so the issue is still up-to-date.
+ */
+ public static class SerializableTDigest implements Serializable {
+
+   // MergingDigest is not Serializable, so the field is transient and its content is written
+   // and restored by writeObject / readObject below.
+   private transient MergingDigest sketch;
+
+   /** Creates an empty sketch with the given compression factor. */
+   public SerializableTDigest(int compression) {
+     sketch = new MergingDigest(compression);
+   }
+
+   private SerializableTDigest(MergingDigest sketch) {
+     this.sketch = sketch;
+   }
+
+   /** Adds one occurrence of {@code input} to the sketch. */
+   public void add(Double input) {
+     this.sketch.add(input, 1);
+   }
+
+   /** Writes the compact ("small") byte representation of the sketch to {@code out}. */
+   public void encode(OutputStream out) throws IOException {
+     ByteBuffer buf = ByteBuffer.allocate(sketch.smallByteSize());
+     sketch.asSmallBytes(buf);
+     ByteArrayCoder.of().encode(buf.array(), out);
+   }
+
+   /** Reads a sketch previously written by {@link #encode(OutputStream)}. */
+   public static SerializableTDigest decode(InputStream in) throws IOException {
+     byte[] bytes = ByteArrayCoder.of().decode(in);
+     return new SerializableTDigest(MergingDigest.fromBytes(ByteBuffer.wrap(bytes)));
+   }
+
+   /**
+    * Merges the given sketches into the first one of the list and returns it.
+    *
+    * @return the first sketch with the others added to it, or {@code null} if the list is empty
+    */
+   public static SerializableTDigest merge(Iterable<SerializableTDigest> list) {
+     Iterator<SerializableTDigest> it = list.iterator();
+     if (!it.hasNext()) {
+       return null;
+     }
+     SerializableTDigest mergedDigest = it.next();
+     while (it.hasNext()) {
+       SerializableTDigest next = it.next();
+       // NOTE(review): this also skips sketches holding exactly one centroid, which looks like
+       // it drops their data; presumably meant to skip empty sketches only. Confirm against the
+       // t-digest version in use before changing it.
+       if (next.getSketch().centroids().size() > 1) {
+         mergedDigest.sketch.add(next.sketch);
+       }
+     }
+     return mergedDigest;
+   }
+
+   public MergingDigest getSketch() {
+     return this.sketch;
+   }
+
+   // Java serialization hook: the transient sketch is appended in its encoded form.
+   private void writeObject(ObjectOutputStream out) throws IOException {
+     out.defaultWriteObject();
+     encode(out);
+   }
+
+   // Java serialization hook: rebuilds the transient sketch written by writeObject.
+   private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+     in.defaultReadObject();
+     byte[] bytes = ByteArrayCoder.of().decode(in);
+     this.sketch = MergingDigest.fromBytes(ByteBuffer.wrap(bytes));
+   }
+ }
+
+ // Coder for SerializableTDigest that delegates to the sketch's own encode / decode methods.
+ static class SerializableTDigestCoder extends CustomCoder {
+
+ // The coder holds no state, so a single shared instance is handed out by of().
+ private static final SerializableTDigestCoder INSTANCE = new SerializableTDigestCoder();
+
+ public static SerializableTDigestCoder of() {
+ return INSTANCE;
+ }
+
+ // Rejects null sketches, otherwise lets the sketch write itself to the stream.
+ @Override public void encode(SerializableTDigest value, OutputStream outStream)
+ throws IOException {
+ if (value == null) {
+ throw new CoderException("cannot encode a null T-Digest sketch");
+ }
+ value.encode(outStream);
+ }
+
+ @Override public SerializableTDigest decode(InputStream inStream) throws IOException {
+ return SerializableTDigest.decode(inStream);
+ }
+
+ @Override public boolean isRegisterByteSizeObserverCheap(SerializableTDigest value) {
+ return true;
+ }
+
+ // Reports the size of the sketch's compact ("small") byte representation.
+ // NOTE(review): encode() goes through ByteArrayCoder, which may add its own framing on top of
+ // these bytes - confirm whether that should be counted here.
+ @Override protected long getEncodedElementByteSize(SerializableTDigest value)
+ throws IOException {
+ if (value == null) {
+ throw new CoderException("cannot encode a null T-Digest sketch");
+ }
+ return value.getSketch().smallByteSize();
+ }
+ }
+}
diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/quantiles/package-info.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/quantiles/package-info.java
new file mode 100644
index 0000000000000..75c5b4c47dbec
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/quantiles/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Utilities for estimating quantiles with data sketching.
+ */
+package org.apache.beam.sdk.extensions.sketching.quantiles;
diff --git a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/cardinality/ApproximateDistinctTest.java b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/cardinality/ApproximateDistinctTest.java
new file mode 100644
index 0000000000000..a78061961e927
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/cardinality/ApproximateDistinctTest.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching.cardinality;
+
+import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Tests for {@link ApproximateDistinct}.
+ */
+public class ApproximateDistinctTest implements Serializable {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ApproximateDistinctTest.class);
+
+ @Rule
+ public final transient TestPipeline tp = TestPipeline.create();
+
+ /**
+  * Checks that the estimate computed on 1000 distinct integers is within the expected relative
+  * error of the real cardinality.
+  */
+ @Test
+ public void smallCardinality() {
+   final int smallCard = 1000;
+   List small = new ArrayList<>();
+   for (int i = 0; i < smallCard; i++) {
+     small.add(i);
+   }
+
+   final int p = 6;
+   // NOTE(review): the other tests of this class use 1.04 as numerator; 1.104 may be a typo.
+   final Double expectedErr = 1.104 / Math.sqrt(p);
+
+   PCollection cardinality = tp.apply(Create. of(small))
+       .apply(ApproximateDistinct.globally(p))
+       .apply(ParDo.of(new RetrieveDistinct()));
+
+   PAssert.thatSingleton("Not Accurate Enough", cardinality)
+       .satisfies(new SerializableFunction() {
+         @Override
+         public Void apply(Long input) {
+           // Divide as doubles: with integer operands the quotient is truncated to 0 for any
+           // error below 100%, which made this assertion pass unconditionally.
+           boolean isAccurate = Math.abs(input - smallCard) / (double) smallCard < expectedErr;
+           Assert.assertTrue("not accurate enough : \nExpected Cardinality : "
+               + smallCard + "\nComputed Cardinality : " + input,
+               isAccurate);
+           return null;
+         }
+       });
+   tp.run();
+ }
+
+ /**
+ * Runs an {@code ApproximateDistinctFn} created with precision {@code p} and a sparse
+ * representation {@code sp} on 15000 distinct values and checks the estimate with
+ * {@code VerifyAccuracy}.
+ */
+ @Test
+ public void createSparse4BigCardinality() {
+ final int cardinality = 15000;
+ final int p = 15;
+ final int sp = 20;
+ // Tolerance on the relative error handed to VerifyAccuracy.
+ final Double expectedErr = 1.04 / Math.sqrt(p);
+
+ // Every value from 1 to cardinality appears twice, in random order.
+ List stream = new ArrayList<>();
+ for (int i = 1; i <= cardinality; i++) {
+ stream.addAll(Collections.nCopies(2, i));
+ }
+ Collections.shuffle(stream);
+
+ PCollection res = tp.apply(Create.of(stream))
+ .apply(Combine.globally(ApproximateDistinct.ApproximateDistinctFn.create(p)
+ .withSparseRepresentation(sp)))
+ .apply(ParDo.of(new RetrieveDistinct()));
+
+ PAssert.that("Verify Accuracy", res)
+ .satisfies(new VerifyAccuracy(cardinality, expectedErr));
+
+ tp.run();
+ }
+
+ /**
+ * Runs {@code ApproximateDistinct.perKey} on 1000 distinct values that all share the same key
+ * and checks the estimate with {@code VerifyAccuracy}.
+ */
+ @Test
+ public void perKey() {
+ final int cardinality = 1000;
+ final int p = 15;
+ // Tolerance on the relative error handed to VerifyAccuracy.
+ final Double expectedErr = 1.04 / Math.sqrt(p);
+ // Every value from 1 to cardinality appears twice, in random order.
+ List stream = new ArrayList<>();
+ for (int i = 1; i <= cardinality; i++) {
+ stream.addAll(Collections.nCopies(2, i));
+ }
+ Collections.shuffle(stream);
+
+ // All the elements get the key 1; only the values of the result are kept for the check.
+ PCollection results = tp.apply(Create.of(stream))
+ .apply(WithKeys.of(1))
+ .apply(ApproximateDistinct.perKey(p))
+ .apply(Values.create())
+ .apply(ParDo.of(new RetrieveDistinct()));
+
+ PAssert.that("Verify Accuracy", results)
+ .satisfies(new VerifyAccuracy(cardinality, expectedErr));
+ tp.run();
+ }
+
+ // DoFn that outputs the cardinality() reported by each input element, logging it at debug level.
+ static class RetrieveDistinct extends DoFn {
+ @ProcessElement
+ public void apply(ProcessContext c) {
+ Long card = c.element().cardinality();
+ LOG.debug("Number of distinct Elements : " + card);
+ c.output(card);
+ }
+ }
+
+ /**
+ * Checks with {@code CoderProperties.coderDecodeEncodeEqual} that a sketch holding ten values
+ * is left equal by an encode / decode round trip through {@code HyperLogLogPlusCoder}.
+ */
+ @Test
+ public void testCoder() throws Exception {
+ HyperLogLogPlus hllp = new HyperLogLogPlus(12, 18);
+ for (int i = 0; i < 10; i++) {
+ hllp.offer(i);
+ }
+ CoderProperties.coderDecodeEncodeEqual(
+ ApproximateDistinct.HyperLogLogPlusCoder.of(), hllp);
+ }
+
+ /**
+ * Checks the display data of a transform built with {@code globally(23)}: item "p" must be 23
+ * and item "sp" must be 0.
+ */
+ @Test
+ public void testDisplayData() {
+ final Combine.Globally specifiedPrecision =
+ ApproximateDistinct.globally(23);
+
+ assertThat(DisplayData.from(specifiedPrecision), hasDisplayItem("p", 23));
+ assertThat(DisplayData.from(specifiedPrecision), hasDisplayItem("sp", 0));
+
+ }
+
+ /**
+  * Asserts that every estimate is within {@code expectedError} (relative error) of
+  * {@code expectedCard}.
+  *
+  * <p>Declared static so that serializing the function does not drag the enclosing test
+  * instance along.
+  */
+ static class VerifyAccuracy implements SerializableFunction<Iterable<Long>, Void> {
+
+   private final int expectedCard;
+
+   private final double expectedError;
+
+   VerifyAccuracy(int expectedCard, double expectedError) {
+     this.expectedCard = expectedCard;
+     this.expectedError = expectedError;
+   }
+
+   @Override
+   public Void apply(Iterable<Long> input) {
+     for (Long estimate : input) {
+       // Divide as doubles: with integer operands the quotient is truncated to 0 for any error
+       // below 100%, which made this assertion pass unconditionally.
+       double relativeError = Math.abs(estimate - expectedCard) / (double) expectedCard;
+       boolean isAccurate = relativeError < expectedError;
+       Assert.assertTrue(
+           "not accurate enough : \nExpected Cardinality : " + expectedCard
+               + "\nComputed Cardinality : " + estimate,
+           isAccurate);
+     }
+     return null;
+   }
+ }
+}
diff --git a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/frequency/KMostFrequentTest.java b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/frequency/KMostFrequentTest.java
new file mode 100644
index 0000000000000..e6a0d9fbdb99d
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/frequency/KMostFrequentTest.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching.frequency;
+
+import com.clearspring.analytics.stream.Counter;
+import com.clearspring.analytics.stream.StreamSummary;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
+import org.apache.beam.sdk.extensions.sketching.frequency.KMostFrequent.StreamSummaryCoder;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Tests for {@link KMostFrequent}.
+ */
+public class KMostFrequentTest {
+
+ // Test pipeline rule; each pipeline-based test applies its transforms to it and calls run().
+ @Rule
+ public final transient TestPipeline tp = TestPipeline.create();
+
+ // Input data: each integer v in 1..10 appears exactly v times (55 elements in total).
+ private List smallStream = Arrays.asList(
+ 1,
+ 2, 2,
+ 3, 3, 3,
+ 4, 4, 4, 4,
+ 5, 5, 5, 5, 5,
+ 6, 6, 6, 6, 6, 6,
+ 7, 7, 7, 7, 7, 7, 7,
+ 8, 8, 8, 8, 8, 8, 8, 8,
+ 9, 9, 9, 9, 9, 9, 9, 9, 9,
+ 10, 10, 10, 10, 10, 10, 10, 10, 10, 10);
+
+ // Shuffles smallStream with a fixed seed, applies KMostFrequent.globally(10), and asserts that
+ // the items emitted by OutputTopK(3) are 10, 9 and 8 in any order.
+ @Test
+ public void smallStream() {
+ // Shuffles the field in place; the fixed seed keeps the order identical on every run.
+ Collections.shuffle(smallStream, new Random(1234));
+ PCollection col = tp.apply(Create.of(smallStream))
+ .apply(KMostFrequent.globally(10))
+ .apply("For print Big Top", ParDo.of(new OutputTopK(3)))
+ .setCoder(BigEndianIntegerCoder.of());
+ PAssert.that(col).containsInAnyOrder(10, 9, 8);
+ tp.run();
+ }
+
+ @Test
+ public void bigStream() {
+ List bigStream = new ArrayList<>();
+ // 1000 * 1, 2000 * 2, 3000 * 3, etc
+ for (int i = 1; i < 11; i++) {
+ bigStream.addAll(Collections.nCopies(i * 1000, i));
+ }
+ Collections.shuffle(bigStream, new Random(1234));
+ PCollection col = tp.apply(Create.of(bigStream))
+ .apply(KMostFrequent.globally(10))
+ .apply("For print Big Top", ParDo.of(new OutputTopK(5)))
+ .setCoder(BigEndianIntegerCoder.of());
+ PAssert.that(col).containsInAnyOrder(10, 9, 8, 7, 6);
+ tp.run();
+ }
+
+ @Test
+ public void testCoder() throws Exception {
+ StreamSummary ssSketch = new StreamSummary<>(5);
+ for (Integer i : smallStream) {
+ ssSketch.offer(i);
+ }
+ Assert.assertTrue(encodeDecode(ssSketch));
+ }
+
+ /**
+  * Encodes {@code ss} with a {@code StreamSummaryCoder}, decodes the resulting bytes, and reports
+  * whether the {@code toString()} of the decoded sketch equals that of the original.
+  */
+ private boolean encodeDecode(StreamSummary ss) throws IOException {
+   ByteArrayOutputStream encodedOut = new ByteArrayOutputStream();
+   StreamSummaryCoder coder = new StreamSummaryCoder<>();
+   coder.encode(ss, encodedOut);
+
+   ByteArrayInputStream encodedIn = new ByteArrayInputStream(encodedOut.toByteArray());
+   StreamSummary roundTripped = coder.decode(encodedIn);
+
+   String expected = ss.toString();
+   String actual = roundTripped.toString();
+   return expected.equals(actual);
+ }
+
+ /**
+  * Emits the item of each counter returned by {@code topK(k)} on every incoming
+  * {@link StreamSummary}.
+  *
+  * @param <T> the type of the items tracked by the sketch
+  */
+ private static class OutputTopK<T> extends DoFn<StreamSummary<T>, T> {
+
+   // Number of counters requested from each sketch; fixed at construction.
+   private final int k;
+
+   private OutputTopK(int k) {
+     this.k = k;
+   }
+
+   @ProcessElement
+   public void apply(ProcessContext c) {
+     List<Counter<T>> li = c.element().topK(k);
+     for (Counter<T> counter : li) {
+       c.output(counter.getItem());
+     }
+   }
+ }
+}
diff --git a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/frequency/SketchFrequenciesTest.java b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/frequency/SketchFrequenciesTest.java
new file mode 100644
index 0000000000000..a706fbd1b5ec6
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/frequency/SketchFrequenciesTest.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching.frequency;
+
+import com.clearspring.analytics.stream.frequency.CountMinSketch;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.beam.sdk.extensions.sketching.frequency.SketchFrequencies.CountMinSketchFn;
+import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.ToString;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Tests for {@link SketchFrequencies}.
+ */
+public class SketchFrequenciesTest {
+
+ // Test pipeline rule; each pipeline-based test applies its transforms to it and calls run().
+ @Rule public final transient TestPipeline tp = TestPipeline.create();
+
+ // Input data: each long v in 1..10 appears exactly v times (55 elements in total).
+ private List smallStream = Arrays.asList(
+ 1L,
+ 2L, 2L,
+ 3L, 3L, 3L,
+ 4L, 4L, 4L, 4L,
+ 5L, 5L, 5L, 5L, 5L,
+ 6L, 6L, 6L, 6L, 6L, 6L,
+ 7L, 7L, 7L, 7L, 7L, 7L, 7L,
+ 8L, 8L, 8L, 8L, 8L, 8L, 8L, 8L,
+ 9L, 9L, 9L, 9L, 9L, 9L, 9L, 9L, 9L,
+ 10L, 10L, 10L, 10L, 10L, 10L, 10L, 10L, 10L, 10L);
+
+ // Applies ToString.elements() and SketchFrequencies.globally(1234) to smallStream, then asserts
+ // that the counts QueryFrequencies emits for "1".."10" are 1..10 in any order.
+ @Test
+ public void defaultConstruct() {
+ PCollection col = tp.apply(Create.of(smallStream))
+ .apply(ToString.elements())
+ .apply(SketchFrequencies.globally(1234))
+ .apply(ParDo.of(new QueryFrequencies()));
+ PAssert.that(col).containsInAnyOrder(10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L);
+ tp.run();
+ }
+
+ // Same check as defaultConstruct, but the combine function is built explicitly with
+ // CountMinSketchFn.create(1234).withDimensions(200, 3) and applied via Combine.globally.
+ @Test
+ public void createDimensions() {
+ CountMinSketchFn cmsFn = CountMinSketchFn.create(1234).withDimensions(200, 3);
+ PCollection col = tp.apply(Create.of(smallStream))
+ .apply(ToString.elements())
+ .apply(Combine.globally(cmsFn))
+ .apply(ParDo.of(new QueryFrequencies()));
+ PAssert.that(col).containsInAnyOrder(10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L);
+ tp.run();
+ }
+
+ // Same check as defaultConstruct, but the combine function is built explicitly with
+ // CountMinSketchFn.create(1234).withAccuracy(0.01, 0.80) and applied via Combine.globally.
+ @Test
+ public void createAccuracy() {
+ CountMinSketchFn cmsFn = CountMinSketchFn.create(1234).withAccuracy(0.01, 0.80);
+ PCollection col = tp.apply(Create.of(smallStream))
+ .apply(ToString.elements())
+ .apply(Combine.globally(cmsFn))
+ .apply(ParDo.of(new QueryFrequencies()));
+ PAssert.that(col).containsInAnyOrder(10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L);
+ tp.run();
+ }
+
+ // Keys every element with the constant 1, combines per key with CountMinSketchFn.create(1234),
+ // drops the key again, and asserts the counts emitted for "1".."10" are 1..10 in any order.
+ @Test
+ public void createPerKey() {
+ CountMinSketchFn cmsFn = CountMinSketchFn.create(1234);
+ PCollection col = tp.apply(Create.of(smallStream))
+ .apply(ToString.elements())
+ .apply(WithKeys.of(1))
+ .apply(Combine.perKey(cmsFn))
+ .apply(Values.create())
+ .apply(ParDo.of(new QueryFrequencies()));
+ PAssert.that(col).containsInAnyOrder(10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L);
+ tp.run();
+ }
+
+ @Test
+ public void testCoder() throws Exception {
+ CountMinSketch cMSketch = new CountMinSketch(200, 7, 12345);
+ for (long i = 0L; i < 10L; i++) {
+ cMSketch.add(i, 1);
+ }
+ CoderProperties.coderDecodeEncodeEqual(
+ new SketchFrequencies.CountMinSketchCoder(), cMSketch);
+ }
+
+ /** Emits the input sketch's {@code estimateCount} for each of the strings "1" through "10". */
+ static class QueryFrequencies extends DoFn {
+   @ProcessElement public void processElement(ProcessContext context) {
+     final CountMinSketch sketch = context.element();
+     for (long item = 1L; item <= 10L; item++) {
+       context.output(sketch.estimateCount(String.valueOf(item)));
+     }
+   }
+ }
+}
diff --git a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/quantiles/TDigestQuantilesTest.java b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/quantiles/TDigestQuantilesTest.java
new file mode 100644
index 0000000000000..b5b762ab9fea4
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/quantiles/TDigestQuantilesTest.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching.quantiles;
+
+import com.tdunning.math.stats.Centroid;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.beam.sdk.extensions.sketching.quantiles.TDigestQuantiles.SerializableTDigest;
+import org.apache.beam.sdk.extensions.sketching.quantiles.TDigestQuantiles.SerializableTDigestCoder;
+import org.apache.beam.sdk.extensions.sketching.quantiles.TDigestQuantiles.TDigestQuantilesFn;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+/**
+ * Tests for {@link TDigestQuantiles}.
+ */
+public class TDigestQuantilesTest {
+
+ // Test pipeline rule; each pipeline-based test applies its transforms to it and calls run().
+ @Rule public final transient TestPipeline tp = TestPipeline.create();
+
+ // Shared input, built once at class initialization. generateStream() reads size, which is
+ // declared further down; that is safe because size is a compile-time constant.
+ private static final List stream = generateStream();
+
+ // Number of input values; generateStream() produces 1.0 .. size.
+ private static final int size = 999;
+
+ // Argument passed to TDigestQuantiles.globally/perKey; VerifyAccuracy tolerates an error of
+ // 3 / compression.
+ private static final int compression = 100;
+
+ /**
+  * Builds the test input: the values 1.0, 2.0, ..., {@code size} in a pseudo-random order.
+  *
+  * <p>The shuffle uses a fixed seed, like the other tests in this file, so every run feeds the
+  * sketch the same sequence and a failure can be reproduced.
+  *
+  * @return a shuffled list of {@code size} values
+  */
+ private static List<Double> generateStream() {
+   List<Double> li = new ArrayList<>();
+   for (int i = 1; i <= size; i++) {
+     li.add((double) i);
+   }
+   Collections.shuffle(li, new Random(1234));
+   return li;
+ }
+
+ // Applies TDigestQuantiles.globally(compression) to stream, queries the quantiles 0.25, 0.5,
+ // 0.75 and 0.99 with RetrieveQuantiles, and checks the results with VerifyAccuracy.
+ @Test
+ public void globally() {
+ PCollection> col = tp.apply(Create.of(stream))
+ .apply(TDigestQuantiles.globally(compression))
+ .apply(ParDo.of(new RetrieveQuantiles(0.25, 0.5, 0.75, 0.99)));
+
+ PAssert.that("Verify Accuracy", col).satisfies(new VerifyAccuracy());
+ tp.run();
+ }
+
+ // Same check as globally(), but every element is keyed with the constant 1, sketched with
+ // TDigestQuantiles.perKey(compression), and the key is dropped before querying.
+ @Test
+ public void perKey() {
+ PCollection> col = tp.apply(Create.of(stream))
+ .apply(WithKeys.of(1))
+ .apply(TDigestQuantiles.perKey(compression))
+ .apply(Values.create())
+ .apply(ParDo.of(new RetrieveQuantiles(0.25, 0.5, 0.75, 0.99)));
+
+ PAssert.that("Verify Accuracy", col).satisfies(new VerifyAccuracy());
+
+ tp.run();
+ }
+
+ @Test
+ public void testCoder() throws Exception {
+ SerializableTDigest tDigest = new SerializableTDigest(1000);
+ for (int i = 0; i < 10; i++) {
+ tDigest.getSketch().add((float) (2.4 + i));
+ }
+ Assert.assertTrue("Encode and Decode", encodeDecode(tDigest));
+ }
+
+ static class VerifyAccuracy implements SerializableFunction>, Void> {
+
+ public Void apply(Iterable> input) {
+ for (KV pair : input) {
+ double expectedError = 3D / compression;
+ double expectedValue = pair.getKey() * size;
+ boolean isAccurate = Math.abs(pair.getValue() - expectedValue)
+ / size <= expectedError;
+ Assert.assertTrue("not accurate enough : \nQuantile " + pair.getKey()
+ + " is " + expectedValue + " and not " + pair.getValue(),
+ isAccurate);
+ }
+ return null;
+ }
+ }
+
+ // Rule used by testNaN to declare the exception type and message it expects.
+ @Rule
+ public ExpectedException thrown = ExpectedException.none();
+
+ /**
+  * Expects that adding one freshly created sketch to another raises an
+  * {@link IllegalArgumentException} with the message "Cannot add NaN to t-digest".
+  */
+ @Test
+ public void testNaN() throws IllegalArgumentException {
+   thrown.expect(IllegalArgumentException.class);
+   thrown.expectMessage("Cannot add NaN to t-digest");
+
+   final SerializableTDigest target = new SerializableTDigest(10);
+   final SerializableTDigest untouched = new SerializableTDigest(10);
+
+   // Nothing has been added to either sketch at this point.
+   target.getSketch().add(untouched.getSketch());
+ }
+
+ @Test
+ public void testMergeAccum() {
+ Random rd = new Random(1234);
+ List accums = new ArrayList<>();
+ for (int i = 0; i < 3; i++) {
+ SerializableTDigest std = new SerializableTDigest(100);
+ for (int j = 0; j < 1000; j++) {
+ std.add(rd.nextDouble());
+ }
+ accums.add(std);
+ }
+ TDigestQuantilesFn fn = TDigestQuantilesFn.create(100);
+ SerializableTDigest res = fn.mergeAccumulators(accums);
+ }
+
+ /**
+  * Round-trips {@code tDigest} through a {@link SerializableTDigestCoder} and compares the
+  * decoded sketch with the original, centroid by centroid.
+  *
+  * @param tDigest the sketch to encode and decode
+  * @return {@code true} if both sketches hold the same number of centroids and every pair of
+  *     centroids agrees on its count and on its mean at float precision
+  * @throws IOException if encoding or decoding fails
+  */
+ private boolean encodeDecode(SerializableTDigest tDigest) throws IOException {
+   ByteArrayOutputStream baos = new ByteArrayOutputStream();
+   SerializableTDigestCoder tDigestCoder = new SerializableTDigestCoder();
+
+   tDigestCoder.encode(tDigest, baos);
+   byte[] bytes = baos.toByteArray();
+
+   ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
+   SerializableTDigest decoded = tDigestCoder.decode(bais);
+
+   // A decoded sketch with missing or extra centroids is not equal to the original. Without
+   // this check an empty decoded sketch would be reported as equal, and a longer one would
+   // run the second iterator past its end.
+   if (decoded.getSketch().centroids().size() != tDigest.getSketch().centroids().size()) {
+     return false;
+   }
+
+   // The only way to compare the two sketches is to compare them centroid by centroid.
+   // Indeed, the means are doubles but are encoded as float and cast during decoding.
+   // This entails a small approximation that makes the centroids different after decoding.
+   Iterator<Centroid> it1 = decoded.getSketch().centroids().iterator();
+   Iterator<Centroid> it2 = tDigest.getSketch().centroids().iterator();
+
+   while (it1.hasNext() && it2.hasNext()) {
+     Centroid c1 = it1.next();
+     Centroid c2 = it2.next();
+     if ((float) c1.mean() != (float) c2.mean() || c1.count() != c2.count()) {
+       return false;
+     }
+   }
+   return true;
+ }
+
+ /**
+  * Queries each incoming sketch for a fixed set of quantiles and emits one
+  * (quantile, value) pair per query.
+  */
+ static class RetrieveQuantiles extends DoFn<SerializableTDigest, KV<Double, Double>> {
+   private final double quantile;
+   private final double[] otherQ;
+
+   /**
+    * @param q the first quantile to query
+    * @param otherQ any further quantiles to query, in output order
+    */
+   public RetrieveQuantiles(double q, double... otherQ) {
+     this.quantile = q;
+     this.otherQ = otherQ;
+   }
+
+   @ProcessElement public void processElement(ProcessContext c) {
+     c.output(KV.of(quantile, c.element().getSketch().quantile(quantile)));
+     // Iterate the primitive array directly instead of boxing each entry to Double first.
+     for (double q : otherQ) {
+       c.output(KV.of(q, c.element().getSketch().quantile(q)));
+     }
+   }
+ }
+}
diff --git a/sdks/java/javadoc/pom.xml b/sdks/java/javadoc/pom.xml
index 1fb6e410818ca..8df8871ed3706 100644
--- a/sdks/java/javadoc/pom.xml
+++ b/sdks/java/javadoc/pom.xml
@@ -92,6 +92,11 @@
beam-sdks-java-extensions-join-library
+