diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SideInputReader.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SideInputReader.java
index 7210add56a163..efc5cc4fcb281 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SideInputReader.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SideInputReader.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.runners.core;
-import javax.annotation.Nullable;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.PCollectionView;
@@ -32,7 +31,6 @@ public interface SideInputReader {
*
* <p>It is valid for a side input to be {@code null}. It is not valid for this to return
* {@code null} for any other reason.
*/
- @Nullable
<T> T get(PCollectionView<T> view, BoundedWindow window);
/** Returns true if the given {@link PCollectionView} is valid for this reader. */
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java
index 163ba55e6d1bd..2c9a77f6f0bee 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java
@@ -26,11 +26,14 @@
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.spark.api.java.JavaSparkContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import scala.Tuple2;
/** SparkPCollectionView is used to pass serialized views to lambdas. */
public class SparkPCollectionView implements Serializable {
+ private static final Logger LOG = LoggerFactory.getLogger(SparkPCollectionView.class);
// Holds the view --> broadcast mapping. Transient so it will be null from resume
private transient volatile Map<PCollectionView<?>, SideInputBroadcast<?>> broadcastHelperMap = null;
@@ -85,6 +88,13 @@ private SideInputBroadcast createBroadcastHelper(
PCollectionView<?> view, JavaSparkContext context) {
Tuple2<byte[], Coder<Iterable<WindowedValue<?>>>> tuple2 = pviews.get(view);
SideInputBroadcast helper = SideInputBroadcast.create(tuple2._1, tuple2._2);
+ String pCollectionName =
+ view.getPCollection() != null ? view.getPCollection().getName() : "UNKNOWN";
+ LOG.debug(
+ "Broadcasting [size={}B] view {} from pCollection {}",
+ helper.getBroadcastSizeEstimate(),
+ view,
+ pCollectionName);
helper.broadcast(context);
broadcastHelperMap.put(view, helper);
return helper;
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/CachedSideInputReader.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/CachedSideInputReader.java
index 49200c6c4b429..5d2e52141dece 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/CachedSideInputReader.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/CachedSideInputReader.java
@@ -17,17 +17,23 @@
*/
package org.apache.beam.runners.spark.util;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Objects;
+import java.util.concurrent.ExecutionException;
import javax.annotation.Nullable;
import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.runners.spark.util.SideInputStorage.Key;
+import org.apache.beam.runners.spark.util.SideInputStorage.Value;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.cache.Cache;
+import org.apache.spark.util.SizeEstimator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/** {@link SideInputReader} that caches materialized views. */
public class CachedSideInputReader implements SideInputReader {
+ private static final Logger LOG = LoggerFactory.getLogger(CachedSideInputReader.class);
+
/**
* Create a new cached {@link SideInputReader}.
*
@@ -38,46 +44,9 @@ public static CachedSideInputReader of(SideInputReader delegate) {
return new CachedSideInputReader(delegate);
}
- /**
- * Composite key of {@link PCollectionView} and {@link BoundedWindow} used to identify
- * materialized results.
- *
- * @param type of result
- */
- private static class Key<T> {
-
- private final PCollectionView<T> view;
- private final BoundedWindow window;
-
- Key(PCollectionView<T> view, BoundedWindow window) {
- this.view = view;
- this.window = window;
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- final Key<?> key = (Key<?>) o;
- return Objects.equals(view, key.view) && Objects.equals(window, key.window);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(view, window);
- }
- }
-
/** Wrapped {@link SideInputReader} which results will be cached. */
private final SideInputReader delegate;
- /** Materialized results. */
- private final Map<Key<?>, ?> materialized = new HashMap<>();
-
private CachedSideInputReader(SideInputReader delegate) {
this.delegate = delegate;
}
@@ -86,9 +55,28 @@ private CachedSideInputReader(SideInputReader delegate) {
@Override
public <T> T get(PCollectionView<T> view, BoundedWindow window) {
@SuppressWarnings("unchecked")
- final Map<Key<T>, T> materializedCasted = (Map) materialized;
- return materializedCasted.computeIfAbsent(
- new Key<>(view, window), key -> delegate.get(view, window));
+ final Cache<Key<T>, Value<T>> materializedCasted =
+ (Cache) SideInputStorage.getMaterializedSideInputs();
+
+ Key<T> sideInputKey = new Key<>(view, window);
+
+ try {
+ Value<T> cachedResult =
+ materializedCasted.get(
+ sideInputKey,
+ () -> {
+ final T result = delegate.get(view, window);
+ LOG.debug(
+ "Caching de-serialized side input for {} of size [{}B] in memory.",
+ sideInputKey,
+ SizeEstimator.estimate(result));
+
+ return new Value<>(result);
+ });
+ return cachedResult.getValue();
+ } catch (ExecutionException e) {
+ throw new RuntimeException(e.getCause());
+ }
}
@Override
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputBroadcast.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputBroadcast.java
index e0cb3de05771b..f42159d3a7d42 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputBroadcast.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputBroadcast.java
@@ -23,6 +23,7 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.util.SizeEstimator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -73,4 +74,8 @@ private T deserialize() {
}
return val;
}
+
+ public long getBroadcastSizeEstimate() {
+ return SizeEstimator.estimate(bytes);
+ }
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputStorage.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputStorage.java
new file mode 100644
index 0000000000000..482702f3b3e1b
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputStorage.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.util;
+
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.cache.Cache;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.cache.CacheBuilder;
+
+/**
+ * Cache deserialized side inputs for executor so every task doesn't need to deserialize them again.
+ * Side inputs are stored in {@link Cache} with 5 minutes expireAfterAccess.
+ */
+class SideInputStorage {
+
+ /** JVM deserialized side input cache. */
+ private static final Cache<Key<?>, Value<?>> materializedSideInputs =
+ CacheBuilder.newBuilder().expireAfterAccess(5, TimeUnit.MINUTES).build();
+
+ static Cache<Key<?>, Value<?>> getMaterializedSideInputs() {
+ return materializedSideInputs;
+ }
+
+ /**
+ * Composite key of {@link PCollectionView} and {@link BoundedWindow} used to identify
+ * materialized results.
+ *
+ * @param type of result
+ */
+ public static class Key<T> {
+
+ private final PCollectionView<T> view;
+ private final BoundedWindow window;
+
+ Key(PCollectionView<T> view, BoundedWindow window) {
+ this.view = view;
+ this.window = window;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Key<?> key = (Key<?>) o;
+ return Objects.equals(view, key.view) && Objects.equals(window, key.window);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(view, window);
+ }
+
+ @Override
+ public String toString() {
+ String pName = view.getPCollection() != null ? view.getPCollection().getName() : "Unknown";
+ return "Key{"
+ + "view="
+ + view.getTagInternal()
+ + " of PCollection["
+ + pName
+ + "], window="
+ + window
+ + '}';
+ }
+ }
+
+ /**
+ * Null value is not allowed in guava's Cache and is valid in SideInput so we use wrapper for
+ * cache value.
+ */
+ public static class Value<T> {
+ final T value;
+
+ Value(T value) {
+ this.value = value;
+ }
+
+ public T getValue() {
+ return value;
+ }
+ }
+}