diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java
index 27f101171834b..bc870e440274b 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java
@@ -43,8 +43,8 @@ public RegisterShuffleIndex(
@Override
public boolean equals(Object other) {
- if (other != null && other instanceof UploadShufflePartitionStream) {
- UploadShufflePartitionStream o = (UploadShufflePartitionStream) other;
+ if (other != null && other instanceof RegisterShuffleIndex) {
+ RegisterShuffleIndex o = (RegisterShuffleIndex) other;
return Objects.equal(appId, o.appId)
&& shuffleId == o.shuffleId
&& mapId == o.mapId;
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java
index 374b399621aae..b11a02f6b9219 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java
@@ -43,8 +43,8 @@ public UploadShuffleIndex(
@Override
public boolean equals(Object other) {
- if (other != null && other instanceof UploadShufflePartitionStream) {
- UploadShufflePartitionStream o = (UploadShufflePartitionStream) other;
+ if (other != null && other instanceof UploadShuffleIndex) {
+ UploadShuffleIndex o = (UploadShuffleIndex) other;
return Objects.equal(appId, o.appId)
&& shuffleId == o.shuffleId
&& mapId == o.mapId;
diff --git a/core/pom.xml b/core/pom.xml
index 49b1a54e32598..544ae61279c4d 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -352,6 +352,11 @@
      <artifactId>py4j</artifactId>
      <version>0.10.8.1</version>
    </dependency>
+    <dependency>
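+      <!-- Provides scala.compat.java8.OptionConverters, used by the external shuffle read
+           path to convert scala.Option into java.util.Optional. -->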
+      <groupId>org.scala-lang.modules</groupId>
+      <artifactId>scala-java8-compat_${scala.binary.version}</artifactId>
+      <version>0.9.0</version>
+    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-tags_${scala.binary.version}</artifactId>
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java
new file mode 100644
index 0000000000000..7846fad70b159
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java
@@ -0,0 +1,23 @@
+package org.apache.spark.shuffle.api;
+
+import org.apache.spark.storage.ShuffleLocation;
+
+import java.util.Optional;
+
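+/**
+ * Result of committing a single map output partition through the pluggable shuffle API.
+ * Exposes the number of bytes written and, when the backing writer reports one, the shuffle
+ * location from which reducers can later fetch the partition.
+ */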
+public interface CommittedPartition {
+
+ /**
+ * Indicates the number of bytes written in a committed partition.
+ * Note that returning the length is mainly for backwards compatibility
+ * and should be removed in a more polished variant. After this method
+ * is called, the writer will be discarded; it's expected that the
+ * implementation will close any underlying resources.
+ */
+ long length();
+
+ /**
+ * Indicates the shuffle location to which this partition was written.
+ * Some implementations may not need to specify a shuffle location.
+ */
+  Optional<ShuffleLocation> shuffleLocation();
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
index ae9ada03e760d..bdc0fd45474fd 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
@@ -31,12 +31,9 @@ public interface ShufflePartitionWriter {
/**
* Indicate that the partition was written successfully and there are no more incoming bytes.
- * Returns the length of the partition that is written. Note that returning the length is
- * mainly for backwards compatibility and should be removed in a more polished variant.
- * After this method is called, the writer will be discarded; it's expected that the
- * implementation will close any underlying resources.
+ * Returns a {@link CommittedPartition} indicating information about that written partition.
*/
- long commitAndGetTotalLength();
+ CommittedPartition commitPartition();
/**
* Indicate that the write has failed for some reason and the implementation can handle the
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java
new file mode 100644
index 0000000000000..7e37659dbb3f9
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java
@@ -0,0 +1,32 @@
+package org.apache.spark.shuffle.external;
+
+import org.apache.spark.shuffle.api.CommittedPartition;
+import org.apache.spark.storage.ShuffleLocation;
+
+import java.util.Optional;
+
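+/**
+ * CommittedPartition produced by the external shuffle partition writer. Pairs the written
+ * length with the external shuffle service location the bytes were uploaded to, when known.
+ */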
+public class ExternalCommittedPartition implements CommittedPartition {
+
+ private final long length;
+  private final Optional<ShuffleLocation> shuffleLocation;
+
+ public ExternalCommittedPartition(long length) {
+ this.length = length;
+ this.shuffleLocation = Optional.empty();
+ }
+
+ public ExternalCommittedPartition(long length, ShuffleLocation shuffleLocation) {
+ this.length = length;
+ this.shuffleLocation = Optional.of(shuffleLocation);
+ }
+
+ @Override
+ public long length() {
+ return length;
+ }
+
+ @Override
+  public Optional<ShuffleLocation> shuffleLocation() {
+ return shuffleLocation;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java
index 22a1d3336615c..ac20d13de6f2c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java
@@ -1,5 +1,6 @@
package org.apache.spark.shuffle.external;
+import org.apache.spark.MapOutputTracker;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.network.TransportContext;
@@ -20,6 +21,7 @@ public class ExternalShuffleDataIO implements ShuffleDataIO {
private static SecurityManager securityManager;
private static String hostname;
private static int port;
+ private static MapOutputTracker mapOutputTracker;
public ExternalShuffleDataIO(
SparkConf sparkConf) {
@@ -35,14 +37,15 @@ public void initialize() {
securityManager = env.securityManager();
hostname = blockManager.getRandomShuffleHost();
port = blockManager.getRandomShufflePort();
+ mapOutputTracker = env.mapOutputTracker();
// TODO: Register Driver and Executor
}
@Override
public ShuffleReadSupport readSupport() {
return new ExternalShuffleReadSupport(
- conf, context, securityManager.isAuthenticationEnabled(),
- securityManager, hostname, port);
+ conf, context, securityManager.isAuthenticationEnabled(),
+ securityManager, mapOutputTracker);
}
@Override
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java
new file mode 100644
index 0000000000000..20ae8d376050c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java
@@ -0,0 +1,40 @@
+package org.apache.spark.shuffle.external;
+
+import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.storage.ShuffleLocation;
+
+import java.io.*;
+
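+/**
+ * ShuffleLocation that records the host and port of the external shuffle service holding a
+ * partition. It is serialized inside the MapStatus so reducers can connect to the right node.
+ */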
+public class ExternalShuffleLocation implements ShuffleLocation {
+
+ private String shuffleHostname;
+ private int shufflePort;
+
+ public ExternalShuffleLocation() { /* for serialization */ }
+
+ public ExternalShuffleLocation(String shuffleHostname, int shufflePort) {
+ this.shuffleHostname = shuffleHostname;
+ this.shufflePort = shufflePort;
+ }
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeUTF(shuffleHostname);
+ out.writeInt(shufflePort);
+ }
+
+ @Override
+ public void readExternal(ObjectInput in) throws IOException {
+ this.shuffleHostname = in.readUTF();
+ this.shufflePort = in.readInt();
+ }
+
+ public String getShuffleHostname() {
+ return this.shuffleHostname;
+ }
+
+ public int getShufflePort() {
+ return this.shufflePort;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java
index 34a11ce2b2a32..8866d14feca53 100644
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java
@@ -11,6 +11,7 @@
import java.nio.ByteBuffer;
+
public class ExternalShuffleMapOutputWriter implements ShuffleMapOutputWriter {
private final TransportClientFactory clientFactory;
@@ -79,8 +80,8 @@ public void commitAllPartitions() {
logger.info("clientid: " + client.getClientId() + " " + client.isActive());
client.sendRpcSync(uploadShuffleIndex, 60000);
} catch (Exception e) {
- client.close();
logger.error("Encountered error while creating transport client", e);
+ client.close();
throw new RuntimeException(e);
}
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java
index ee363f8bb41b5..8aefac239e97f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java
@@ -8,7 +8,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.*;
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Arrays;
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java
index d9b7d7ac515df..89bfe4407e5ac 100644
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java
@@ -6,13 +6,16 @@
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.shuffle.protocol.UploadShufflePartitionStream;
+import org.apache.spark.shuffle.api.CommittedPartition;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
+import org.apache.spark.storage.ShuffleLocation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.nio.ByteBuffer;
import java.util.Arrays;
+import java.util.Optional;
public class ExternalShufflePartitionWriter implements ShufflePartitionWriter {
@@ -51,7 +54,7 @@ public ExternalShufflePartitionWriter(
public OutputStream openPartitionStream() { return partitionBuffer; }
@Override
- public long commitAndGetTotalLength() {
+ public CommittedPartition commitPartition() {
RpcResponseCallback callback = new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
@@ -88,12 +91,11 @@ public void onFailure(Throwable e) {
} finally {
logger.info("Successfully sent partition to ESS");
}
- return totalLength;
+ return new ExternalCommittedPartition(totalLength, new ExternalShuffleLocation(hostName, port));
}
@Override
public void abort(Exception failureReason) {
- clientFactory.close();
try {
this.partitionBuffer.close();
} catch(IOException e) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java
index 2687c2a4e2379..9e7ff55f47741 100644
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java
@@ -1,6 +1,7 @@
package org.apache.spark.shuffle.external;
import com.google.common.collect.Lists;
+import org.apache.spark.MapOutputTracker;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
@@ -9,10 +10,13 @@
import org.apache.spark.network.util.TransportConf;
import org.apache.spark.shuffle.api.ShufflePartitionReader;
import org.apache.spark.shuffle.api.ShuffleReadSupport;
+import org.apache.spark.storage.ShuffleLocation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import scala.compat.java8.OptionConverters;
import java.util.List;
+import java.util.Optional;
public class ExternalShuffleReadSupport implements ShuffleReadSupport {
@@ -22,22 +26,19 @@ public class ExternalShuffleReadSupport implements ShuffleReadSupport {
private final TransportContext context;
private final boolean authEnabled;
private final SecretKeyHolder secretKeyHolder;
- private final String hostName;
- private final int port;
+ private final MapOutputTracker mapOutputTracker;
public ExternalShuffleReadSupport(
TransportConf conf,
TransportContext context,
boolean authEnabled,
SecretKeyHolder secretKeyHolder,
- String hostName,
- int port) {
+ MapOutputTracker mapOutputTracker) {
this.conf = conf;
this.context = context;
this.authEnabled = authEnabled;
this.secretKeyHolder = secretKeyHolder;
- this.hostName = hostName;
- this.port = port;
+ this.mapOutputTracker = mapOutputTracker;
}
@Override
@@ -47,10 +48,20 @@ public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, in
if (authEnabled) {
bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
}
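+    // Look up which external shuffle service instance holds this map output by asking the
+    // MapOutputTracker for the ShuffleLocation recorded when the partition was committed.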
+    Optional<ShuffleLocation> maybeShuffleLocation =
+        OptionConverters.toJava(mapOutputTracker.getShuffleLocation(shuffleId, mapId, 0));
+ assert maybeShuffleLocation.isPresent();
+ ExternalShuffleLocation externalShuffleLocation = (ExternalShuffleLocation) maybeShuffleLocation.get();
+ logger.info(String.format("Found external shuffle location on node: %s:%d",
+ externalShuffleLocation.getShuffleHostname(),
+ externalShuffleLocation.getShufflePort()));
TransportClientFactory clientFactory = context.createClientFactory(bootstraps);
try {
return new ExternalShufflePartitionReader(clientFactory,
- hostName, port, appId, shuffleId, mapId);
+ externalShuffleLocation.getShuffleHostname(),
+ externalShuffleLocation.getShufflePort(),
+ appId,
+ shuffleId,
+ mapId);
} catch (Exception e) {
clientFactory.close();
logger.error("Encountered creating transport client for partition reader");
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 823c36d051ddf..26b55aa70387c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -17,24 +17,8 @@
package org.apache.spark.shuffle.sort;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.OutputStream;
-import javax.annotation.Nullable;
-
-import scala.None$;
-import scala.Option;
-import scala.Product2;
-import scala.Tuple2;
-import scala.collection.Iterator;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
@@ -42,14 +26,27 @@
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.api.CommittedPartition;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.ShuffleWriteSupport;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.None$;
+import scala.Option;
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import javax.annotation.Nullable;
+import java.io.*;
+import java.util.Arrays;
+import java.util.stream.Collectors;
/**
* This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
@@ -94,7 +91,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private DiskBlockObjectWriter[] partitionWriters;
private FileSegment[] partitionWriterSegments;
@Nullable private MapStatus mapStatus;
- private long[] partitionLengths;
+ private CommittedPartition[] committedPartitions;
/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -131,7 +128,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
- partitionLengths = new long[numPartitions];
+ long[] partitionLengths = new long[numPartitions];
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
@@ -166,25 +163,30 @@ public void write(Iterator> records) throws IOException {
}
if (pluggableWriteSupport != null) {
- partitionLengths = combineAndWritePartitionsUsingPluggableWriter();
+ committedPartitions = combineAndWritePartitionsUsingPluggableWriter();
+ logger.info("Successfully wrote partitions with pluggable writer");
} else {
File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
File tmp = Utils.tempFileWith(output);
try {
- partitionLengths = combineAndWritePartitions(tmp);
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
+ committedPartitions = combineAndWritePartitions(tmp);
+ logger.info("Successfully wrote partitions without shuffle");
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId,
+ mapId,
+ Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray(),
+ tmp);
} finally {
if (tmp != null && tmp.exists() && !tmp.delete()) {
logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
}
}
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions);
}
@VisibleForTesting
long[] getPartitionLengths() {
- return partitionLengths;
+ return Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray();
}
/**
@@ -192,12 +194,12 @@ long[] getPartitionLengths() {
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
- private long[] combineAndWritePartitions(File outputFile) throws IOException {
+ private CommittedPartition[] combineAndWritePartitions(File outputFile) throws IOException {
// Track location of the partition starts in the output file
- final long[] lengths = new long[numPartitions];
+ final CommittedPartition[] partitions = new CommittedPartition[numPartitions];
if (partitionWriters == null) {
// We were passed an empty iterator
- return lengths;
+ return partitions;
}
assert(outputFile != null);
final FileOutputStream out = new FileOutputStream(outputFile, true);
@@ -210,7 +212,8 @@ private long[] combineAndWritePartitions(File outputFile) throws IOException {
final FileInputStream in = new FileInputStream(file);
boolean copyThrewException = true;
try {
- lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+ partitions[i] =
+ new LocalCommittedPartition(Utils.copyStream(in, out, false, transferToEnabled));
copyThrewException = false;
} finally {
Closeables.close(in, copyThrewException);
@@ -225,15 +228,15 @@ private long[] combineAndWritePartitions(File outputFile) throws IOException {
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
partitionWriters = null;
- return lengths;
+ return partitions;
}
- private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOException {
+ private CommittedPartition[] combineAndWritePartitionsUsingPluggableWriter() throws IOException {
// Track location of the partition starts in the output file
- final long[] lengths = new long[numPartitions];
+ final CommittedPartition[] partitions = new CommittedPartition[numPartitions];
if (partitionWriters == null) {
// We were passed an empty iterator
- return lengths;
+ return partitions;
}
assert(pluggableWriteSupport != null);
@@ -251,7 +254,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio
try (OutputStream out = writer.openPartitionStream()) {
Utils.copyStream(in, out, false, false);
}
- lengths[i] = writer.commitAndGetTotalLength();
+ partitions[i] = writer.commitPartition();
copyThrewException = false;
} catch (Exception e) {
try {
@@ -279,7 +282,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
partitionWriters = null;
- return lengths;
+ return partitions;
}
@Override
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java
new file mode 100644
index 0000000000000..817855d957966
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java
@@ -0,0 +1,25 @@
+package org.apache.spark.shuffle.sort;
+
+import org.apache.spark.shuffle.api.CommittedPartition;
+import org.apache.spark.storage.ShuffleLocation;
+
+import java.util.Optional;
+
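+/**
+ * CommittedPartition for data written to local disk by the built-in shuffle writers; it carries
+ * only the partition length and reports no external shuffle location.
+ */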
+public class LocalCommittedPartition implements CommittedPartition {
+
+ private final long length;
+
+ public LocalCommittedPartition(long length) {
+ this.length = length;
+ }
+
+ @Override
+ public long length() {
+ return length;
+ }
+
+ @Override
+  public Optional<ShuffleLocation> shuffleLocation() {
+ return Optional.empty();
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 32be620095110..ef086e21b04d1 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -20,8 +20,13 @@
import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
+import java.util.Arrays;
import java.util.Iterator;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.spark.shuffle.api.CommittedPartition;
+import org.apache.spark.storage.ShuffleLocation;
import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
@@ -236,12 +241,12 @@ void closeAndWriteOutput() throws IOException {
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
- final long[] partitionLengths;
+ final CommittedPartition[] committedPartitions;
final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
final File tmp = Utils.tempFileWith(output);
try {
try {
- partitionLengths = mergeSpills(spills, tmp);
+ committedPartitions = mergeSpills(spills, tmp);
} finally {
for (SpillInfo spill : spills) {
if (spill.file.exists() && ! spill.file.delete()) {
@@ -250,14 +255,17 @@ void closeAndWriteOutput() throws IOException {
}
}
if (pluggableWriteSupport == null) {
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId,
+ mapId,
+ Arrays.stream(committedPartitions).mapToLong(CommittedPartition::length).toArray(),
+ tmp);
}
} finally {
if (tmp.exists() && !tmp.delete()) {
logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
}
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions);
}
@VisibleForTesting
@@ -289,7 +297,7 @@ void forceSorterToSpill() throws IOException {
*
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
+ private CommittedPartition[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
final CompressionCodec compressionCodec =
compressionEnabled ? CompressionCodec$.MODULE$.createCodec(sparkConf) : null;
@@ -301,18 +309,18 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
try {
if (spills.length == 0) {
new FileOutputStream(outputFile).close(); // Create an empty file
- return new long[partitioner.numPartitions()];
+ return new CommittedPartition[partitioner.numPartitions()];
} else if (spills.length == 1) {
if (pluggableWriteSupport != null) {
- writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec);
+ return writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec);
} else {
// Here, we don't need to perform any metrics updates because the bytes written to this
// output file would have already been counted as shuffle bytes written.
Files.move(spills[0].file, outputFile);
}
- return spills[0].partitionLengths;
+ return toLocalCommittedPartition(spills[0].partitionLengths);
} else {
- final long[] partitionLengths;
+ final CommittedPartition[] committedPartitions;
// There are multiple spills to merge, so none of these spill files' lengths were counted
// towards our shuffle write count or shuffle write time. If we use the slow merge path,
// then the final output file's size won't necessarily be equal to the sum of the spill
@@ -324,21 +332,21 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
// shuffle write time, which appears to be consistent with the "not bypassing merge-sort"
// branch in ExternalSorter.
if (pluggableWriteSupport != null) {
- partitionLengths = mergeSpillsWithPluggableWriter(spills, compressionCodec);
+ committedPartitions = mergeSpillsWithPluggableWriter(spills, compressionCodec);
} else if (fastMergeEnabled && fastMergeIsSupported) {
// Compression is disabled or we are using an IO compression codec that supports
// decompression of concatenated compressed streams, so we can perform a fast spill merge
// that doesn't need to interpret the spilled bytes.
if (transferToEnabled && !encryptionEnabled) {
logger.debug("Using transferTo-based fast merge");
- partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+ committedPartitions = toLocalCommittedPartition(mergeSpillsWithTransferTo(spills, outputFile));
} else {
logger.debug("Using fileStream-based fast merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
+ committedPartitions = toLocalCommittedPartition(mergeSpillsWithFileStream(spills, outputFile, null));
}
} else {
logger.debug("Using slow merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+ committedPartitions = toLocalCommittedPartition(mergeSpillsWithFileStream(spills, outputFile, compressionCodec));
}
// When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
// in-memory records, we write out the in-memory records to a file but do not count that
@@ -349,7 +357,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
if (pluggableWriteSupport == null) {
writeMetrics.incBytesWritten(outputFile.length());
}
- return partitionLengths;
+ return committedPartitions;
}
} catch (IOException e) {
if (outputFile.exists() && !outputFile.delete()) {
@@ -359,6 +367,12 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
}
}
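+
+  /**
+   * Wraps raw partition lengths from the local merge paths into CommittedPartitions that carry
+   * no external shuffle location.
+   */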
+ private static CommittedPartition[] toLocalCommittedPartition(long[] partitionLengths) {
+    return Arrays.stream(partitionLengths)
+        .mapToObj(LocalCommittedPartition::new)
+        .toArray(CommittedPartition[]::new);
+ }
+
/**
* Merges spill files using Java FileStreams. This code path is typically slower than
* the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[],
@@ -512,13 +526,13 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
/**
* Merges spill files using the ShufflePartitionWriter API.
*/
- private long[] mergeSpillsWithPluggableWriter(
+ private CommittedPartition[] mergeSpillsWithPluggableWriter(
SpillInfo[] spills,
@Nullable CompressionCodec compressionCodec) throws IOException {
assert (spills.length >= 2);
assert(pluggableWriteSupport != null);
final int numPartitions = partitioner.numPartitions();
- final long[] partitionLengths = new long[numPartitions];
+ final CommittedPartition[] committedPartitions = new CommittedPartition[numPartitions];
final InputStream[] spillInputStreams = new InputStream[spills.length];
boolean threwException = true;
@@ -552,8 +566,8 @@ private long[] mergeSpillsWithPluggableWriter(
}
}
}
- partitionLengths[partition] = writer.commitAndGetTotalLength();
- writeMetrics.incBytesWritten(partitionLengths[partition]);
+ committedPartitions[partition] = writer.commitPartition();
+ writeMetrics.incBytesWritten(committedPartitions[partition].length());
} catch (Exception e) {
try {
writer.abort(e);
@@ -579,14 +593,15 @@ private long[] mergeSpillsWithPluggableWriter(
Closeables.close(stream, threwException);
}
}
- return partitionLengths;
+ return committedPartitions;
}
- private void writeSingleSpillFileUsingPluggableWriter(
+ private CommittedPartition[] writeSingleSpillFileUsingPluggableWriter(
SpillInfo spillInfo,
@Nullable CompressionCodec compressionCodec) throws IOException {
assert(pluggableWriteSupport != null);
final int numPartitions = partitioner.numPartitions();
+ final CommittedPartition[] committedPartitions = new CommittedPartition[numPartitions];
boolean threwException = true;
InputStream spillInputStream = new NioBufferedFileInputStream(
spillInfo.file,
@@ -617,7 +632,8 @@ private void writeSingleSpillFileUsingPluggableWriter(
} finally {
partitionInputStream.close();
}
- writeMetrics.incBytesWritten(writer.commitAndGetTotalLength());
+ committedPartitions[partition] = writer.commitPartition();
+ writeMetrics.incBytesWritten(committedPartitions[partition].length());
}
threwException = false;
} catch (Exception e) {
@@ -631,6 +647,7 @@ private void writeSingleSpillFileUsingPluggableWriter(
Closeables.close(spillInputStream, threwException);
}
writeMetrics.decBytesWritten(spillInfo.file.length());
+ return committedPartitions;
}
@Override
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index fb587f02256eb..b340ffbbc43dc 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -35,7 +35,7 @@ import org.apache.spark.internal.config._
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle._
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleLocation}
import org.apache.spark.util._
/**
@@ -303,6 +303,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]
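+
+  /**
+   * Returns the pluggable shuffle location recorded for the given map output block, if the
+   * corresponding MapStatus carries one.
+   */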
+ def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int) : Option[ShuffleLocation]
+
/**
* Deletes map output status information for the specified shuffle stage.
*/
@@ -676,6 +678,14 @@ private[spark] class MapOutputTrackerMaster(
trackerEndpoint = null
shuffleStatuses.clear()
}
+
+ override def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int):
+ Option[ShuffleLocation] = {
+ shuffleStatuses.get(shuffleId) match {
+ case Some(shuffleStatus) => shuffleStatus.mapStatuses(mapId).shuffleLocationForBlock(reduceId)
+ case None => Option.empty
+ }
+ }
}
/**
@@ -789,6 +799,14 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
}
}
}
+
+ override def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int):
+ Option[ShuffleLocation] = {
+ mapStatuses.get(shuffleId) match {
+ case Some(shuffleStatus) => shuffleStatus(mapId).shuffleLocationForBlock(reduceId)
+ case None => Option.empty
+ }
+ }
}
private[spark] object MapOutputTracker extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 64f0a060a247c..21613a5946f68 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -19,13 +19,13 @@ package org.apache.spark.scheduler
import java.io.{Externalizable, ObjectInput, ObjectOutput}
-import scala.collection.mutable
-
import org.roaringbitmap.RoaringBitmap
+import scala.collection.mutable
import org.apache.spark.SparkEnv
import org.apache.spark.internal.config
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.shuffle.api.CommittedPartition
+import org.apache.spark.storage.{BlockManagerId, ShuffleLocation}
import org.apache.spark.util.Utils
/**
@@ -36,6 +36,8 @@ private[spark] sealed trait MapStatus {
/** Location where this task was run. */
def location: BlockManagerId
+ def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation]
+
/**
* Estimated size for the reduce block, in bytes.
*
@@ -56,11 +58,29 @@ private[spark] object MapStatus {
.map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
.getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)
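+
+  /**
+   * Builds a MapStatus from per-partition commit results, extracting both the block lengths and
+   * any pluggable shuffle locations reported by the writer.
+   */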
+ def apply(loc: BlockManagerId, committedPartitions: Array[CommittedPartition]): MapStatus = {
+ val shuffleLocationsArray = committedPartitions.collect {
+ case partition if partition != null && partition.shuffleLocation().isPresent
+ => partition.shuffleLocation().get()
+ case _ => null
+ }
+ val lengthsArray = committedPartitions.collect {
+      case partition if partition != null => partition.length()
+      case _ => 0L
+    }
+ if (committedPartitions.length > minPartitionsToUseHighlyCompressMapStatus) {
+ HighlyCompressedMapStatus(loc, lengthsArray, shuffleLocationsArray)
+ } else {
+ new CompressedMapStatus(loc, lengthsArray, shuffleLocationsArray)
+ }
+ }
+
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
- HighlyCompressedMapStatus(loc, uncompressedSizes)
+ HighlyCompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation])
} else {
- new CompressedMapStatus(loc, uncompressedSizes)
+ new CompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation])
}
}
@@ -103,17 +123,28 @@ private[spark] object MapStatus {
*/
private[spark] class CompressedMapStatus(
private[this] var loc: BlockManagerId,
- private[this] var compressedSizes: Array[Byte])
+ private[this] var compressedSizes: Array[Byte],
+ private[this] var shuffleLocations: Array[ShuffleLocation])
extends MapStatus with Externalizable {
- protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
+ // For deserialization only
+ protected def this() = this(null, null.asInstanceOf[Array[Byte]], null)
- def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
- this(loc, uncompressedSizes.map(MapStatus.compressSize))
+ def this(loc: BlockManagerId, uncompressedSizes: Array[Long],
+ shuffleLocations: Array[ShuffleLocation]) {
+ this(loc, uncompressedSizes.map(MapStatus.compressSize), shuffleLocations)
}
override def location: BlockManagerId = loc
+ override def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] = {
+    if (shuffleLocations == null || reduceId >= shuffleLocations.length) {
+      Option.empty
+    } else {
+      Option(shuffleLocations(reduceId))
+    }
+ }
+
override def getSizeForBlock(reduceId: Int): Long = {
MapStatus.decompressSize(compressedSizes(reduceId))
}
@@ -122,6 +153,7 @@ private[spark] class CompressedMapStatus(
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
+ out.writeObject(shuffleLocations)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -129,6 +161,7 @@ private[spark] class CompressedMapStatus(
val len = in.readInt()
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
+ shuffleLocations = in.readObject().asInstanceOf[Array[ShuffleLocation]]
}
}
@@ -148,17 +181,26 @@ private[spark] class HighlyCompressedMapStatus private (
private[this] var numNonEmptyBlocks: Int,
private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long,
- private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte])
+ private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte],
+ private[this] var shuffleLocations: Array[ShuffleLocation])
extends MapStatus with Externalizable {
// loc could be null when the default constructor is called during deserialization
require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1, null) // For deserialization only
+ protected def this() = this(null, -1, null, -1, null, null) // For deserialization only
override def location: BlockManagerId = loc
+ override def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] = {
+    if (shuffleLocations == null || reduceId >= shuffleLocations.length) {
+      Option.empty
+    } else {
+      Option(shuffleLocations(reduceId))
+    }
+ }
+
override def getSizeForBlock(reduceId: Int): Long = {
assert(hugeBlockSizes != null)
if (emptyBlocks.contains(reduceId)) {
@@ -180,6 +222,7 @@ private[spark] class HighlyCompressedMapStatus private (
out.writeInt(kv._1)
out.writeByte(kv._2)
}
+ out.writeObject(shuffleLocations)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -195,11 +238,17 @@ private[spark] class HighlyCompressedMapStatus private (
hugeBlockSizesImpl(block) = size
}
hugeBlockSizes = hugeBlockSizesImpl
+ shuffleLocations = in.readObject().asInstanceOf[Array[ShuffleLocation]]
}
}
private[spark] object HighlyCompressedMapStatus {
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ apply(loc, uncompressedSizes, Array.empty[ShuffleLocation])
+ }
+
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long],
+ shuffleLocation: Array[ShuffleLocation]): HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a zero-sized
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
@@ -240,6 +289,6 @@ private[spark] object HighlyCompressedMapStatus {
emptyBlocks.trim()
emptyBlocks.runOptimize()
new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
- hugeBlockSizes)
+ hugeBlockSizes, shuffleLocation)
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 1c804c99d0e31..98388d80cbe5b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -70,13 +70,16 @@ private[spark] class SortShuffleWriter[K, V, C](
val tmp = Utils.tempFileWith(output)
try {
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
- val partitionLengths = pluggableWriteSupport.map { writeSupport =>
+ val committedPartitions = pluggableWriteSupport.map { writeSupport =>
sorter.writePartitionedToExternalShuffleWriteSupport(mapId, dep.shuffleId, writeSupport)
}.getOrElse(sorter.writePartitionedFile(blockId, tmp))
if (pluggableWriteSupport.isEmpty) {
- shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
+ shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId,
+ mapId,
+ committedPartitions.map(_.length()),
+ tmp)
}
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
+ mapStatus = MapStatus(blockManager.shuffleServerId, committedPartitions)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index fb1ed02c857a9..1575b076d3faf 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -18,7 +18,7 @@
package org.apache.spark.storage
import java.io._
-import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference}
+import java.lang.ref.{WeakReference, ReferenceQueue => JReferenceQueue}
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.util.Collections
@@ -31,12 +31,11 @@ import scala.concurrent.duration._
import scala.reflect.ClassTag
import scala.util.Random
import scala.util.control.NonFatal
-
import com.codahale.metrics.{MetricRegistry, MetricSet}
import org.apache.spark._
import org.apache.spark.executor.DataReadMethod
-import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.internal.{Logging, config}
import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.metrics.source.Source
import org.apache.spark.network._
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala
new file mode 100644
index 0000000000000..72846cb001c8a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import java.io.Externalizable
+
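+/**
+ * Marker trait for the location a shuffle partition was written to by a pluggable shuffle
+ * backend. Implementations must be Externalizable so locations can be shipped to reducers
+ * inside a MapStatus.
+ */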
+trait ShuffleLocation extends Externalizable {
+
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala
index b03276b2ce16f..baaee46f81237 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
import org.apache.spark.serializer.{SerializationStream, SerializerInstance}
import org.apache.spark.shuffle.ShufflePartitionWriterOutputStream
-import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
+import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleMapOutputWriter, ShufflePartitionWriter}
/**
* Replicates the concept of {@link DiskBlockObjectWriter}, but with some key differences:
@@ -51,15 +51,15 @@ private[spark] class ShufflePartitionObjectWriter(
objectOutputStream = serializerInstance.serializeStream(currentWriterStream)
}
- def commitCurrentPartition(): Long = {
+ def commitCurrentPartition(): CommittedPartition = {
require(objectOutputStream != null, "Cannot commit a partition that has not been started.")
require(currentWriter != null, "Cannot commit a partition that has not been started.")
objectOutputStream.close()
- val length = currentWriter.commitAndGetTotalLength()
+ val committedPartition = currentWriter.commitPartition()
buffer.reset()
currentWriter = null
objectOutputStream = null
- length
+ committedPartition
}
def abortCurrentPartition(throwable: Exception): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 569c8bd092f37..69077c644dc78 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -18,19 +18,19 @@
package org.apache.spark.util.collection
import java.io._
-import java.util.Comparator
+import java.util.{Comparator, Optional}
+import com.google.common.io.ByteStreams
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import com.google.common.io.ByteStreams
-
-import org.apache.spark.{util, _}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.serializer._
-import org.apache.spark.shuffle.api.ShuffleWriteSupport
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShufflePartitionObjectWriter}
+import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleWriteSupport}
+import org.apache.spark.shuffle.sort.LocalCommittedPartition
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShuffleLocation, ShufflePartitionObjectWriter}
+import org.apache.spark.{util, _}
/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -683,10 +683,10 @@ private[spark] class ExternalSorter[K, V, C](
*/
def writePartitionedFile(
blockId: BlockId,
- outputFile: File): Array[Long] = {
+ outputFile: File): Array[CommittedPartition] = {
// Track location of each range in the output file
- val lengths = new Array[Long](numPartitions)
+ val committedPartitions = new Array[CommittedPartition](numPartitions)
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics().shuffleWriteMetrics)
@@ -700,7 +700,7 @@ private[spark] class ExternalSorter[K, V, C](
it.writeNext(writer)
}
val segment = writer.commitAndGet()
- lengths(partitionId) = segment.length
+ committedPartitions(partitionId) = new LocalCommittedPartition(segment.length)
}
} else {
// We must perform merge-sort; get an iterator by partition and write everything directly.
@@ -710,7 +710,7 @@ private[spark] class ExternalSorter[K, V, C](
writer.write(elem._1, elem._2)
}
val segment = writer.commitAndGet()
- lengths(id) = segment.length
+ committedPartitions(id) = new LocalCommittedPartition(segment.length)
}
}
}
@@ -720,17 +720,17 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
- lengths
+ committedPartitions
}
/**
* Write all partitions to some backend that is pluggable.
*/
def writePartitionedToExternalShuffleWriteSupport(
- mapId: Int, shuffleId: Int, writeSupport: ShuffleWriteSupport): Array[Long] = {
+ mapId: Int, shuffleId: Int, writeSupport: ShuffleWriteSupport): Array[CommittedPartition] = {
// Track location of each range in the output file
- val lengths = new Array[Long](numPartitions)
+ val committedPartitions = new Array[CommittedPartition](numPartitions)
val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, shuffleId, mapId)
val writer = new ShufflePartitionObjectWriter(
Math.min(serializerBatchSize, Integer.MAX_VALUE).toInt,
@@ -749,7 +749,7 @@ private[spark] class ExternalSorter[K, V, C](
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(writer)
}
- lengths(partitionId) = writer.commitCurrentPartition()
+ committedPartitions(partitionId) = writer.commitCurrentPartition()
} catch {
case e: Exception =>
util.Utils.tryLogNonFatalError {
@@ -767,7 +767,7 @@ private[spark] class ExternalSorter[K, V, C](
for (elem <- elements) {
writer.write(elem._1, elem._2)
}
- lengths(id) = writer.commitCurrentPartition()
+ committedPartitions(id) = writer.commitCurrentPartition()
} catch {
case e: Exception =>
util.Utils.tryLogNonFatalError {
@@ -791,7 +791,7 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
- lengths
+ committedPartitions
}
def stop(): Unit = {
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 539336cd4fd89..93ab301a4cb9c 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -23,6 +23,7 @@
import java.nio.file.StandardOpenOption;
import java.util.*;
+import org.apache.spark.shuffle.api.CommittedPartition;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
@@ -675,7 +676,7 @@ public OutputStream openPartitionStream() {
}
@Override
- public long commitAndGetTotalLength() {
+ public CommittedPartition commitPartition() {
byte[] partitionBytes = byteBuffer.toByteArray();
try {
Files.write(mergedOutputFile.toPath(), partitionBytes, StandardOpenOption.APPEND);
@@ -684,7 +685,7 @@ public long commitAndGetTotalLength() {
}
int length = partitionBytes.length;
partitionSizesInMergedFile[partitionId] = length;
- return length;
+ return new LocalCommittedPartition(length);
}
@Override
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 21f481d477242..90f6c3523ece8 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark
import scala.collection.mutable.ArrayBuffer
-
import org.mockito.Matchers.any
import org.mockito.Mockito._
@@ -27,7 +26,7 @@ import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleLocation}
class MapOutputTrackerSuite extends SparkFunSuite {
private val conf = new SparkConf
@@ -84,9 +83,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val compressedSize1000 = MapStatus.compressSize(1000L)
val compressedSize10000 = MapStatus.compressSize(10000L)
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
- Array(compressedSize1000, compressedSize10000)))
+ Array[Long](compressedSize1000, compressedSize10000)))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
- Array(compressedSize10000, compressedSize1000)))
+ Array[Long](compressedSize10000, compressedSize1000)))
assert(tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
assert(0 == tracker.getNumCachedSerializedBroadcast)
@@ -107,9 +106,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val compressedSize1000 = MapStatus.compressSize(1000L)
val compressedSize10000 = MapStatus.compressSize(10000L)
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
- Array(compressedSize1000, compressedSize1000, compressedSize1000)))
+ Array[Long](compressedSize1000, compressedSize1000, compressedSize1000)))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
- Array(compressedSize10000, compressedSize1000, compressedSize1000)))
+ Array[Long](compressedSize10000, compressedSize1000, compressedSize1000)))
assert(0 == tracker.getNumCachedSerializedBroadcast)
// As if we had two simultaneous fetch failures
@@ -260,7 +259,8 @@ class MapOutputTrackerSuite extends SparkFunSuite {
masterTracker.registerShuffle(20, 100)
(0 until 100).foreach { i =>
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
- BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
+ BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0),
+ Array.empty[ShuffleLocation]))
}
val senderAddress = RpcAddress("localhost", 12345)
val rpcCallContext = mock(classOf[RpcCallContext])
diff --git a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala
index 3a68fded945b3..097d1e406dc04 100644
--- a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala
+++ b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala
@@ -19,9 +19,11 @@ package org.apache.spark
import java.io._
import java.nio.file.Paths
+import java.util.Optional
import javax.ws.rs.core.UriBuilder
import org.apache.spark.shuffle.api._
+import org.apache.spark.storage.ShuffleLocation
import org.apache.spark.util.Utils
class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO {
@@ -49,8 +51,14 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO {
new FileOutputStream(shuffleFile)
}
- override def commitAndGetTotalLength(): Long =
- resolvePartitionFile(appId, shuffleId, mapId, partitionId).length
+ override def commitPartition(): CommittedPartition = {
+ new CommittedPartition {
+ override def length(): Long =
+ resolvePartitionFile(appId, shuffleId, mapId, partitionId).length
+
+ override def shuffleLocation(): Optional[ShuffleLocation] = Optional.empty()
+ }
+ }
override def abort(failureReason: Exception): Unit = {}
}
@@ -64,7 +72,6 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO {
private def resolvePartitionFile(
appId: String, shuffleId: Int, mapId: Int, reduceId: Int): File = {
- import java.io.OutputStream
Paths.get(UriBuilder.fromUri(shuffleDir.toURI)
.path(appId)
.path(shuffleId.toString)
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 467e49026a029..2e1950bbc7f57 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -26,7 +26,6 @@ import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
import scala.reflect.ClassTag
-
import com.esotericsoftware.kryo.{Kryo, KryoException}
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import org.roaringbitmap.RoaringBitmap
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala
index f3c9e3e2741f2..b2a4fd42c60fa 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala
@@ -54,6 +54,6 @@ private[spark] class YarnClusterManager extends ExternalClusterManager {
override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = {
scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend)
}
- override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider =
+  def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider =
DefaultShuffleServiceAddressProvider
}