diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java index 27f101171834b..bc870e440274b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java @@ -43,8 +43,8 @@ public RegisterShuffleIndex( @Override public boolean equals(Object other) { - if (other != null && other instanceof UploadShufflePartitionStream) { - UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; + if (other != null && other instanceof RegisterShuffleIndex) { + RegisterShuffleIndex o = (RegisterShuffleIndex) other; return Objects.equal(appId, o.appId) && shuffleId == o.shuffleId && mapId == o.mapId; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java index 374b399621aae..b11a02f6b9219 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java @@ -43,8 +43,8 @@ public UploadShuffleIndex( @Override public boolean equals(Object other) { - if (other != null && other instanceof UploadShufflePartitionStream) { - UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; + if (other != null && other instanceof UploadShuffleIndex) { + UploadShuffleIndex o = (UploadShuffleIndex) other; return Objects.equal(appId, o.appId) && shuffleId == o.shuffleId && mapId == o.mapId; diff --git a/core/pom.xml b/core/pom.xml index 49b1a54e32598..544ae61279c4d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -352,6 +352,11 @@ py4j 0.10.8.1 + + org.scala-lang.modules + scala-java8-compat_${scala.binary.version} + 0.9.0 + org.apache.spark spark-tags_${scala.binary.version} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java new file mode 100644 index 0000000000000..7846fad70b159 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java @@ -0,0 +1,23 @@ +package org.apache.spark.shuffle.api; + +import org.apache.spark.storage.ShuffleLocation; + +import java.util.Optional; + +public interface CommittedPartition { + + /** + * Indicates the number of bytes written in a committed partition. + * Note that returning the length is mainly for backwards compatibility + * and should be removed in a more polished variant. After this method + * is called, the writer will be discarded; it's expected that the + * implementation will close any underlying resources. + */ + long length(); + + /** + * Indicates the shuffle location to which this partition was written. + * Some implementations may not need to specify a shuffle location. 
+ */ + Optional<ShuffleLocation> shuffleLocation(); +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java index ae9ada03e760d..bdc0fd45474fd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -31,12 +31,9 @@ public interface ShufflePartitionWriter { /** * Indicate that the partition was written successfully and there are no more incoming bytes. - * Returns the length of the partition that is written. Note that returning the length is - * mainly for backwards compatibility and should be removed in a more polished variant. - * After this method is called, the writer will be discarded; it's expected that the - * implementation will close any underlying resources. + * Returns a {@link CommittedPartition} indicating information about that written partition. */ - long commitAndGetTotalLength(); + CommittedPartition commitPartition(); /** * Indicate that the write has failed for some reason and the implementation can handle the diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java new file mode 100644 index 0000000000000..7e37659dbb3f9 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java @@ -0,0 +1,32 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.shuffle.api.CommittedPartition; +import org.apache.spark.storage.ShuffleLocation; + +import java.util.Optional; + +public class ExternalCommittedPartition implements CommittedPartition { + + private final long length; + private final Optional<ShuffleLocation> shuffleLocation; + + public ExternalCommittedPartition(long length) { + this.length = length; + this.shuffleLocation = Optional.empty(); + } + + public ExternalCommittedPartition(long length, ShuffleLocation shuffleLocation) { + this.length = length; + this.shuffleLocation = Optional.of(shuffleLocation); + } + + @Override + public long length() { + return length; + } + + @Override + public Optional<ShuffleLocation> shuffleLocation() { + return shuffleLocation; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index 22a1d3336615c..ac20d13de6f2c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -1,5 +1,6 @@ package org.apache.spark.shuffle.external; +import org.apache.spark.MapOutputTracker; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.network.TransportContext; @@ -20,6 +21,7 @@ public class ExternalShuffleDataIO implements ShuffleDataIO { private static SecurityManager securityManager; private static String hostname; private static int port; + private static MapOutputTracker mapOutputTracker; public ExternalShuffleDataIO( SparkConf sparkConf) { @@ -35,14 +37,15 @@ public void initialize() { securityManager = env.securityManager(); hostname = blockManager.getRandomShuffleHost(); port = blockManager.getRandomShufflePort(); + mapOutputTracker = env.mapOutputTracker(); // TODO: Register Driver and Executor } @Override public ShuffleReadSupport readSupport() { return new ExternalShuffleReadSupport( - conf, 
context, securityManager.isAuthenticationEnabled(), - securityManager, hostname, port); + conf, context, securityManager.isAuthenticationEnabled(), + securityManager, mapOutputTracker); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java new file mode 100644 index 0000000000000..20ae8d376050c --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java @@ -0,0 +1,40 @@ +package org.apache.spark.shuffle.external; + +import org.apache.hadoop.mapreduce.task.reduce.Shuffle; +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.storage.ShuffleLocation; + +import java.io.*; + +public class ExternalShuffleLocation implements ShuffleLocation { + + private String shuffleHostname; + private int shufflePort; + + public ExternalShuffleLocation() { /* for serialization */ } + + public ExternalShuffleLocation(String shuffleHostname, int shufflePort) { + this.shuffleHostname = shuffleHostname; + this.shufflePort = shufflePort; + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeUTF(shuffleHostname); + out.writeInt(shufflePort); + } + + @Override + public void readExternal(ObjectInput in) throws IOException { + this.shuffleHostname = in.readUTF(); + this.shufflePort = in.readInt(); + } + + public String getShuffleHostname() { + return this.shuffleHostname; + } + + public int getShufflePort() { + return this.shufflePort; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java index 34a11ce2b2a32..8866d14feca53 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java @@ -11,6 +11,7 @@ import java.nio.ByteBuffer; + public class ExternalShuffleMapOutputWriter implements ShuffleMapOutputWriter { private final TransportClientFactory clientFactory; @@ -79,8 +80,8 @@ public void commitAllPartitions() { logger.info("clientid: " + client.getClientId() + " " + client.isActive()); client.sendRpcSync(uploadShuffleIndex, 60000); } catch (Exception e) { - client.close(); logger.error("Encountered error while creating transport client", e); + client.close(); throw new RuntimeException(e); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index ee363f8bb41b5..8aefac239e97f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -8,7 +8,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.*; +import java.io.ByteArrayInputStream; +import java.io.InputStream; import java.nio.ByteBuffer; import java.util.Arrays; diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index d9b7d7ac515df..89bfe4407e5ac 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ 
b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -6,13 +6,16 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.shuffle.protocol.UploadShufflePartitionStream; +import org.apache.spark.shuffle.api.CommittedPartition; import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.storage.ShuffleLocation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.*; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.Optional; public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { @@ -51,7 +54,7 @@ public ExternalShufflePartitionWriter( public OutputStream openPartitionStream() { return partitionBuffer; } @Override - public long commitAndGetTotalLength() { + public CommittedPartition commitPartition() { RpcResponseCallback callback = new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { @@ -88,12 +91,11 @@ public void onFailure(Throwable e) { } finally { logger.info("Successfully sent partition to ESS"); } - return totalLength; + return new ExternalCommittedPartition(totalLength, new ExternalShuffleLocation(hostName, port)); } @Override public void abort(Exception failureReason) { - clientFactory.close(); try { this.partitionBuffer.close(); } catch(IOException e) { diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index 2687c2a4e2379..9e7ff55f47741 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -1,6 +1,7 @@ package org.apache.spark.shuffle.external; import com.google.common.collect.Lists; +import org.apache.spark.MapOutputTracker; import org.apache.spark.network.TransportContext; import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; @@ -9,10 +10,13 @@ import org.apache.spark.network.util.TransportConf; import org.apache.spark.shuffle.api.ShufflePartitionReader; import org.apache.spark.shuffle.api.ShuffleReadSupport; +import org.apache.spark.storage.ShuffleLocation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import scala.compat.java8.OptionConverters; import java.util.List; +import java.util.Optional; public class ExternalShuffleReadSupport implements ShuffleReadSupport { @@ -22,22 +26,19 @@ public class ExternalShuffleReadSupport implements ShuffleReadSupport { private final TransportContext context; private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; - private final String hostName; - private final int port; + private final MapOutputTracker mapOutputTracker; public ExternalShuffleReadSupport( TransportConf conf, TransportContext context, boolean authEnabled, SecretKeyHolder secretKeyHolder, - String hostName, - int port) { + MapOutputTracker mapOutputTracker) { this.conf = conf; this.context = context; this.authEnabled = authEnabled; this.secretKeyHolder = secretKeyHolder; - this.hostName = hostName; - this.port = port; + this.mapOutputTracker = mapOutputTracker; } @Override @@ -47,10 +48,20 @@ public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, in if (authEnabled) { bootstraps.add(new AuthClientBootstrap(conf, appId, 
secretKeyHolder)); } + Optional<ShuffleLocation> maybeShuffleLocation = OptionConverters.toJava(mapOutputTracker.getShuffleLocation(shuffleId, mapId, 0)); + assert maybeShuffleLocation.isPresent(); + ExternalShuffleLocation externalShuffleLocation = (ExternalShuffleLocation) maybeShuffleLocation.get(); + logger.info(String.format("Found external shuffle location on node: %s:%d", + externalShuffleLocation.getShuffleHostname(), + externalShuffleLocation.getShufflePort())); TransportClientFactory clientFactory = context.createClientFactory(bootstraps); try { return new ExternalShufflePartitionReader(clientFactory, - hostName, port, appId, shuffleId, mapId); + externalShuffleLocation.getShuffleHostname(), + externalShuffleLocation.getShufflePort(), + appId, + shuffleId, + mapId); } catch (Exception e) { clientFactory.close(); logger.error("Encountered error creating transport client for partition reader"); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 823c36d051ddf..26b55aa70387c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -17,24 +17,8 @@ package org.apache.spark.shuffle.sort; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; -import javax.annotation.Nullable; - -import scala.None$; -import scala.Option; -import scala.Product2; -import scala.Tuple2; -import scala.collection.Iterator; - import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; @@ -42,14 +26,27 @@ import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.api.CommittedPartition; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.ShuffleWriteSupport; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.None$; +import scala.Option; +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import javax.annotation.Nullable; +import java.io.*; +import java.util.Arrays; +import java.util.stream.Collectors; /** * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path @@ -94,7 +91,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { private DiskBlockObjectWriter[] partitionWriters; private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; - private long[] partitionLengths; + private CommittedPartition[] committedPartitions; /** * Are we in the process of stopping? 
Because map tasks can call stop() with success = true @@ -131,7 +128,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { public void write(Iterator<Product2<K, V>> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { - partitionLengths = new long[numPartitions]; + long[] partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; @@ -166,25 +163,30 @@ public void write(Iterator<Product2<K, V>> records) throws IOException { } if (pluggableWriteSupport != null) { - partitionLengths = combineAndWritePartitionsUsingPluggableWriter(); + committedPartitions = combineAndWritePartitionsUsingPluggableWriter(); + logger.info("Successfully wrote partitions with pluggable writer"); } else { File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); File tmp = Utils.tempFileWith(output); try { - partitionLengths = combineAndWritePartitions(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + committedPartitions = combineAndWritePartitions(tmp); + logger.info("Successfully wrote partitions without shuffle"); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, + mapId, + Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray(), + tmp); } finally { if (tmp != null && tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions); } @VisibleForTesting long[] getPartitionLengths() { - return partitionLengths; + return Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray(); } /** @@ -192,12 +194,12 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). 
*/ - private long[] combineAndWritePartitions(File outputFile) throws IOException { + private CommittedPartition[] combineAndWritePartitions(File outputFile) throws IOException { // Track location of the partition starts in the output file - final long[] lengths = new long[numPartitions]; + final CommittedPartition[] partitions = new CommittedPartition[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator - return lengths; + return partitions; } assert(outputFile != null); final FileOutputStream out = new FileOutputStream(outputFile, true); @@ -210,7 +212,8 @@ private long[] combineAndWritePartitions(File outputFile) throws IOException { final FileInputStream in = new FileInputStream(file); boolean copyThrewException = true; try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + partitions[i] = + new LocalCommittedPartition(Utils.copyStream(in, out, false, transferToEnabled)); copyThrewException = false; } finally { Closeables.close(in, copyThrewException); @@ -225,15 +228,15 @@ private long[] combineAndWritePartitions(File outputFile) throws IOException { writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; - return lengths; + return partitions; } - private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOException { + private CommittedPartition[] combineAndWritePartitionsUsingPluggableWriter() throws IOException { // Track location of the partition starts in the output file - final long[] lengths = new long[numPartitions]; + final CommittedPartition[] partitions = new CommittedPartition[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator - return lengths; + return partitions; } assert(pluggableWriteSupport != null); @@ -251,7 +254,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio try (OutputStream out = writer.openPartitionStream()) { Utils.copyStream(in, out, false, false); } - lengths[i] = writer.commitAndGetTotalLength(); + partitions[i] = writer.commitPartition(); copyThrewException = false; } catch (Exception e) { try { @@ -279,7 +282,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; - return lengths; + return partitions; } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java new file mode 100644 index 0000000000000..817855d957966 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java @@ -0,0 +1,25 @@ +package org.apache.spark.shuffle.sort; + +import org.apache.spark.shuffle.api.CommittedPartition; + +import org.apache.spark.storage.ShuffleLocation; + +import java.util.Optional; + +public class LocalCommittedPartition implements CommittedPartition { + + private final long length; + + public LocalCommittedPartition(long length) { + this.length = length; + } + + @Override + public long length() { + return length; + } + + @Override + public Optional<ShuffleLocation> shuffleLocation() { + return Optional.empty(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 32be620095110..ef086e21b04d1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ 
b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -20,8 +20,13 @@ import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; +import java.util.Arrays; import java.util.Iterator; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.spark.shuffle.api.CommittedPartition; +import org.apache.spark.storage.ShuffleLocation; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -236,12 +241,12 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; - final long[] partitionLengths; + final CommittedPartition[] committedPartitions; final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); final File tmp = Utils.tempFileWith(output); try { try { - partitionLengths = mergeSpills(spills, tmp); + committedPartitions = mergeSpills(spills, tmp); } finally { for (SpillInfo spill : spills) { if (spill.file.exists() && ! spill.file.delete()) { @@ -250,14 +255,17 @@ void closeAndWriteOutput() throws IOException { } } if (pluggableWriteSupport == null) { - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, + mapId, + Arrays.stream(committedPartitions).mapToLong(CommittedPartition::length).toArray(), + tmp); } } finally { if (tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions); } @VisibleForTesting @@ -289,7 +297,7 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. */ - private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { + private CommittedPartition[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); final CompressionCodec compressionCodec = compressionEnabled ? CompressionCodec$.MODULE$.createCodec(sparkConf) : null; @@ -301,18 +309,18 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file - return new long[partitioner.numPartitions()]; + return new CommittedPartition[partitioner.numPartitions()]; } else if (spills.length == 1) { if (pluggableWriteSupport != null) { - writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec); + return writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec); } else { // Here, we don't need to perform any metrics updates because the bytes written to this // output file would have already been counted as shuffle bytes written. Files.move(spills[0].file, outputFile); } - return spills[0].partitionLengths; + return toLocalCommittedPartition(spills[0].partitionLengths); } else { - final long[] partitionLengths; + final CommittedPartition[] committedPartitions; // There are multiple spills to merge, so none of these spill files' lengths were counted // towards our shuffle write count or shuffle write time. 
If we use the slow merge path, // then the final output file's size won't necessarily be equal to the sum of the spill @@ -324,21 +332,21 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" // branch in ExternalSorter. if (pluggableWriteSupport != null) { - partitionLengths = mergeSpillsWithPluggableWriter(spills, compressionCodec); + committedPartitions = mergeSpillsWithPluggableWriter(spills, compressionCodec); } else if (fastMergeEnabled && fastMergeIsSupported) { // Compression is disabled or we are using an IO compression codec that supports // decompression of concatenated compressed streams, so we can perform a fast spill merge // that doesn't need to interpret the spilled bytes. if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + committedPartitions = toLocalCommittedPartition(mergeSpillsWithTransferTo(spills, outputFile)); } else { logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + committedPartitions = toLocalCommittedPartition(mergeSpillsWithFileStream(spills, outputFile, null)); } } else { logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + committedPartitions = toLocalCommittedPartition(mergeSpillsWithFileStream(spills, outputFile, compressionCodec)); } // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has // in-memory records, we write out the in-memory records to a file but do not count that @@ -349,7 +357,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti if (pluggableWriteSupport == null) { writeMetrics.incBytesWritten(outputFile.length()); } - return partitionLengths; + return committedPartitions; } } catch (IOException e) { if (outputFile.exists() && !outputFile.delete()) { @@ -359,6 +367,12 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti } } + private static CommittedPartition[] toLocalCommittedPartition(long[] partitionLengths) { + return Arrays.stream(partitionLengths) + .mapToObj(length -> new LocalCommittedPartition(length)) + .collect(Collectors.toList()).toArray(new CommittedPartition[partitionLengths.length]); + } + /** * Merges spill files using Java FileStreams. This code path is typically slower than * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], @@ -512,13 +526,13 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th /** * Merges spill files using the ShufflePartitionWriter API. 
*/ - private long[] mergeSpillsWithPluggableWriter( + private CommittedPartition[] mergeSpillsWithPluggableWriter( SpillInfo[] spills, @Nullable CompressionCodec compressionCodec) throws IOException { assert (spills.length >= 2); assert(pluggableWriteSupport != null); final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; + final CommittedPartition[] committedPartitions = new CommittedPartition[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; boolean threwException = true; @@ -552,8 +566,8 @@ private long[] mergeSpillsWithPluggableWriter( } } } - partitionLengths[partition] = writer.commitAndGetTotalLength(); - writeMetrics.incBytesWritten(partitionLengths[partition]); + committedPartitions[partition] = writer.commitPartition(); + writeMetrics.incBytesWritten(committedPartitions[partition].length()); } catch (Exception e) { try { writer.abort(e); @@ -579,14 +593,15 @@ private long[] mergeSpillsWithPluggableWriter( Closeables.close(stream, threwException); } } - return partitionLengths; + return committedPartitions; } - private void writeSingleSpillFileUsingPluggableWriter( + private CommittedPartition[] writeSingleSpillFileUsingPluggableWriter( SpillInfo spillInfo, @Nullable CompressionCodec compressionCodec) throws IOException { assert(pluggableWriteSupport != null); final int numPartitions = partitioner.numPartitions(); + final CommittedPartition[] committedPartitions = new CommittedPartition[numPartitions]; boolean threwException = true; InputStream spillInputStream = new NioBufferedFileInputStream( spillInfo.file, @@ -617,7 +632,8 @@ private void writeSingleSpillFileUsingPluggableWriter( } finally { partitionInputStream.close(); } - writeMetrics.incBytesWritten(writer.commitAndGetTotalLength()); + committedPartitions[partition] = writer.commitPartition(); + writeMetrics.incBytesWritten(committedPartitions[partition].length()); } threwException = false; } catch (Exception e) { @@ -631,6 +647,7 @@ private void writeSingleSpillFileUsingPluggableWriter( Closeables.close(spillInputStream, threwException); } writeMetrics.decBytesWritten(spillInfo.file.length()); + return committedPartitions; } @Override diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index fb587f02256eb..b340ffbbc43dc 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -35,7 +35,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle._ -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleLocation} import org.apache.spark.util._ /** @@ -303,6 +303,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] + def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int) : Option[ShuffleLocation] + /** * Deletes map output status information for the specified shuffle stage. 
*/ @@ -676,6 +678,14 @@ private[spark] class MapOutputTrackerMaster( trackerEndpoint = null shuffleStatuses.clear() } + + override def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int): + Option[ShuffleLocation] = { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => shuffleStatus.mapStatuses(mapId).shuffleLocationForBlock(reduceId) + case None => Option.empty + } + } } /** @@ -789,6 +799,14 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } } } + + override def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int): + Option[ShuffleLocation] = { + mapStatuses.get(shuffleId) match { + case Some(shuffleStatus) => shuffleStatus(mapId).shuffleLocationForBlock(reduceId) + case None => Option.empty + } + } } private[spark] object MapOutputTracker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 64f0a060a247c..21613a5946f68 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,13 +19,13 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} -import scala.collection.mutable - import org.roaringbitmap.RoaringBitmap +import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.internal.config -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.api.CommittedPartition +import org.apache.spark.storage.{BlockManagerId, ShuffleLocation} import org.apache.spark.util.Utils /** @@ -36,6 +36,8 @@ private[spark] sealed trait MapStatus { /** Location where this task was run. */ def location: BlockManagerId + def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] + /** * Estimated size for the reduce block, in bytes. 
* @@ -56,11 +58,29 @@ private[spark] object MapStatus { .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) + def apply(loc: BlockManagerId, committedPartitions: Array[CommittedPartition]): MapStatus = { + val shuffleLocationsArray = committedPartitions.collect { + case partition if partition != null && partition.shuffleLocation().isPresent + => partition.shuffleLocation().get() + case _ => null + } + val lengthsArray = committedPartitions.collect { + case partition if partition != null => partition.length() + case _ => 0 + + } + if (committedPartitions.length > minPartitionsToUseHighlyCompressMapStatus) { + HighlyCompressedMapStatus(loc, lengthsArray, shuffleLocationsArray) + } else { + new CompressedMapStatus(loc, lengthsArray, shuffleLocationsArray) + } + } + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation]) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation]) } } @@ -103,17 +123,28 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var shuffleLocations: Array[ShuffleLocation]) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], null) - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], + shuffleLocations: Array[ShuffleLocation]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), shuffleLocations) } override def location: BlockManagerId = loc + override def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] = { + if (shuffleLocations.apply(reduceId) == null) { + Option.empty + } else { + Option.apply(shuffleLocations.apply(reduceId)) + } + } + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } @@ -122,6 +153,7 @@ private[spark] class CompressedMapStatus( loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) + out.writeObject(shuffleLocations) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -129,6 +161,7 @@ private[spark] class CompressedMapStatus( val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) + shuffleLocations = in.readObject().asInstanceOf[Array[ShuffleLocation]] } } @@ -148,17 +181,26 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte]) + private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte], + private[this] var shuffleLocations: Array[ShuffleLocation]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == 
null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, -1, null, -1, null, null) // For deserialization only override def location: BlockManagerId = loc + override def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] = { + if (shuffleLocations.apply(reduceId) == null) { + Option.empty + } else { + Option.apply(shuffleLocations.apply(reduceId)) + } + } + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -180,6 +222,7 @@ private[spark] class HighlyCompressedMapStatus private ( out.writeInt(kv._1) out.writeByte(kv._2) } + out.writeObject(shuffleLocations) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -195,11 +238,17 @@ private[spark] class HighlyCompressedMapStatus private ( hugeBlockSizesImpl(block) = size } hugeBlockSizes = hugeBlockSizesImpl + shuffleLocations = in.readObject().asInstanceOf[Array[ShuffleLocation]] } } private[spark] object HighlyCompressedMapStatus { def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + apply(loc, uncompressedSizes, Array.empty[ShuffleLocation]) + } + + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], + shuffleLocation: Array[ShuffleLocation]): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -240,6 +289,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes) + hugeBlockSizes, shuffleLocation) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 1c804c99d0e31..98388d80cbe5b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,13 +70,16 @@ private[spark] class SortShuffleWriter[K, V, C]( val tmp = Utils.tempFileWith(output) try { val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = pluggableWriteSupport.map { writeSupport => + val committedPartitions = pluggableWriteSupport.map { writeSupport => sorter.writePartitionedToExternalShuffleWriteSupport(mapId, dep.shuffleId, writeSupport) }.getOrElse(sorter.writePartitionedFile(blockId, tmp)) if (pluggableWriteSupport.isEmpty) { - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, + mapId, + committedPartitions.map(_.length()), + tmp) } - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, committedPartitions) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fb1ed02c857a9..1575b076d3faf 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.storage import java.io._ -import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} +import java.lang.ref.{WeakReference, ReferenceQueue => JReferenceQueue} import java.nio.ByteBuffer import java.nio.channels.Channels import java.util.Collections @@ -31,12 +31,11 @@ import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal - import com.codahale.metrics.{MetricRegistry, MetricSet} import org.apache.spark._ import org.apache.spark.executor.DataReadMethod -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.{Logging, config} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source import org.apache.spark.network._ diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala new file mode 100644 index 0000000000000..72846cb001c8a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.storage + +import java.io.Externalizable + +trait ShuffleLocation extends Externalizable { + +} diff --git a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala index b03276b2ce16f..baaee46f81237 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.serializer.{SerializationStream, SerializerInstance} import org.apache.spark.shuffle.ShufflePartitionWriterOutputStream -import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} +import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleMapOutputWriter, ShufflePartitionWriter} /** * Replicates the concept of {@link DiskBlockObjectWriter}, but with some key differences: @@ -51,15 +51,15 @@ private[spark] class ShufflePartitionObjectWriter( objectOutputStream = serializerInstance.serializeStream(currentWriterStream) } - def commitCurrentPartition(): Long = { + def commitCurrentPartition(): CommittedPartition = { require(objectOutputStream != null, "Cannot commit a partition that has not been started.") require(currentWriter != null, "Cannot commit a partition that has not been started.") objectOutputStream.close() - val length = currentWriter.commitAndGetTotalLength() + val committedPartition = currentWriter.commitPartition() buffer.reset() currentWriter = null objectOutputStream = null - length + committedPartition } def abortCurrentPartition(throwable: Exception): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 569c8bd092f37..69077c644dc78 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -18,19 +18,19 @@ package org.apache.spark.util.collection import java.io._ -import java.util.Comparator +import java.util.{Comparator, Optional} +import com.google.common.io.ByteStreams import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.io.ByteStreams - -import org.apache.spark.{util, _} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer._ -import org.apache.spark.shuffle.api.ShuffleWriteSupport -import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShufflePartitionObjectWriter} +import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleWriteSupport} +import org.apache.spark.shuffle.sort.LocalCommittedPartition +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShuffleLocation, ShufflePartitionObjectWriter} +import org.apache.spark.{util, _} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -683,10 +683,10 @@ private[spark] class ExternalSorter[K, V, C]( */ def writePartitionedFile( blockId: BlockId, - outputFile: File): Array[Long] = { + outputFile: File): Array[CommittedPartition] = { // Track location of each range in the output file - val lengths = new Array[Long](numPartitions) + val committedPartitions = new Array[CommittedPartition](numPartitions) val writer = blockManager.getDiskWriter(blockId, outputFile, 
serInstance, fileBufferSize, context.taskMetrics().shuffleWriteMetrics) @@ -700,7 +700,7 @@ private[spark] class ExternalSorter[K, V, C]( it.writeNext(writer) } val segment = writer.commitAndGet() - lengths(partitionId) = segment.length + committedPartitions(partitionId) = new LocalCommittedPartition(segment.length) } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. @@ -710,7 +710,7 @@ private[spark] class ExternalSorter[K, V, C]( writer.write(elem._1, elem._2) } val segment = writer.commitAndGet() - lengths(id) = segment.length + committedPartitions(id) = new LocalCommittedPartition(segment.length) } } } @@ -720,17 +720,17 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - lengths + committedPartitions } /** * Write all partitions to some backend that is pluggable. */ def writePartitionedToExternalShuffleWriteSupport( - mapId: Int, shuffleId: Int, writeSupport: ShuffleWriteSupport): Array[Long] = { + mapId: Int, shuffleId: Int, writeSupport: ShuffleWriteSupport): Array[CommittedPartition] = { // Track location of each range in the output file - val lengths = new Array[Long](numPartitions) + val committedPartitions = new Array[CommittedPartition](numPartitions) val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, shuffleId, mapId) val writer = new ShufflePartitionObjectWriter( Math.min(serializerBatchSize, Integer.MAX_VALUE).toInt, @@ -749,7 +749,7 @@ private[spark] class ExternalSorter[K, V, C]( while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(writer) } - lengths(partitionId) = writer.commitCurrentPartition() + committedPartitions(partitionId) = writer.commitCurrentPartition() } catch { case e: Exception => util.Utils.tryLogNonFatalError { @@ -767,7 +767,7 @@ private[spark] class ExternalSorter[K, V, C]( for (elem <- elements) { writer.write(elem._1, elem._2) } - lengths(id) = writer.commitCurrentPartition() + committedPartitions(id) = writer.commitCurrentPartition() } catch { case e: Exception => util.Utils.tryLogNonFatalError { @@ -791,7 +791,7 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - lengths + committedPartitions } def stop(): Unit = { diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 539336cd4fd89..93ab301a4cb9c 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -23,6 +23,7 @@ import java.nio.file.StandardOpenOption; import java.util.*; +import org.apache.spark.shuffle.api.CommittedPartition; import scala.Option; import scala.Product2; import scala.Tuple2; @@ -675,7 +676,7 @@ public OutputStream openPartitionStream() { } @Override - public long commitAndGetTotalLength() { + public CommittedPartition commitPartition() { byte[] partitionBytes = byteBuffer.toByteArray(); try { Files.write(mergedOutputFile.toPath(), partitionBytes, StandardOpenOption.APPEND); @@ -684,7 +685,7 @@ public long commitAndGetTotalLength() { } int length = partitionBytes.length; partitionSizesInMergedFile[partitionId] = length; - return length; + return new LocalCommittedPartition(length); } @Override diff --git 
a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d477242..90f6c3523ece8 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer - import org.mockito.Matchers.any import org.mockito.Mockito._ @@ -27,7 +26,7 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleLocation} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf @@ -84,9 +83,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) + Array[Long](compressedSize1000, compressedSize10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + Array[Long](compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +106,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array[Long](compressedSize1000, compressedSize1000, compressedSize1000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array[Long](compressedSize10000, compressedSize1000, compressedSize1000))) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -260,7 +259,8 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), + Array.empty[ShuffleLocation])) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) diff --git a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala index 3a68fded945b3..097d1e406dc04 100644 --- a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala +++ b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala @@ -19,9 +19,11 @@ package org.apache.spark import java.io._ import java.nio.file.Paths +import java.util.Optional import javax.ws.rs.core.UriBuilder import org.apache.spark.shuffle.api._ +import org.apache.spark.storage.ShuffleLocation import org.apache.spark.util.Utils class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { @@ -49,8 +51,14 @@ 
class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { new FileOutputStream(shuffleFile) } - override def commitAndGetTotalLength(): Long = - resolvePartitionFile(appId, shuffleId, mapId, partitionId).length + override def commitPartition(): CommittedPartition = { + new CommittedPartition { + override def length(): Long = + resolvePartitionFile(appId, shuffleId, mapId, partitionId).length + + override def shuffleLocation(): Optional[ShuffleLocation] = Optional.empty() + } + } override def abort(failureReason: Exception): Unit = {} } @@ -64,7 +72,6 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { private def resolvePartitionFile( appId: String, shuffleId: Int, mapId: Int, reduceId: Int): File = { - import java.io.OutputStream Paths.get(UriBuilder.fromUri(shuffleDir.toURI) .path(appId) .path(shuffleId.toString) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 467e49026a029..2e1950bbc7f57 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -26,7 +26,6 @@ import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.reflect.ClassTag - import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import org.roaringbitmap.RoaringBitmap diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala index f3c9e3e2741f2..b2a4fd42c60fa 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala @@ -54,6 +54,6 @@ private[spark] class YarnClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) } - override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = + def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = DefaultShuffleServiceAddressProvider }
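
For reference, below is a minimal, hypothetical sketch of what a third-party writer plugin could look like against the interfaces touched in this diff (ShufflePartitionWriter, CommittedPartition, ShuffleLocation). The ExampleShuffleLocation and ExampleShufflePartitionWriter names, the package, and the in-memory buffering are illustrative only and not part of this change; it also assumes openPartitionStream(), commitPartition(), and abort(Exception) are the only methods a writer must provide, which is all that is visible in this diff.

```java
// Hypothetical plugin code, not part of this diff: names and package are examples only.
package org.example.shuffle;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.OutputStream;
import java.util.Optional;

import org.apache.spark.shuffle.api.CommittedPartition;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.storage.ShuffleLocation;

/** Example location that only records the backend node the partition was shipped to. */
class ExampleShuffleLocation implements ShuffleLocation {
  private String node;

  public ExampleShuffleLocation() { /* for deserialization */ }

  public ExampleShuffleLocation(String node) {
    this.node = node;
  }

  @Override
  public void writeExternal(ObjectOutput out) throws IOException {
    out.writeUTF(node);
  }

  @Override
  public void readExternal(ObjectInput in) throws IOException {
    this.node = in.readUTF();
  }
}

/** Example partition writer: buffers bytes locally and reports a location on commit. */
public class ExampleShufflePartitionWriter implements ShufflePartitionWriter {
  private final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
  private final String targetNode;

  public ExampleShufflePartitionWriter(String targetNode) {
    this.targetNode = targetNode;
  }

  @Override
  public OutputStream openPartitionStream() {
    return buffer;
  }

  @Override
  public CommittedPartition commitPartition() {
    // A real backend would upload buffer.toByteArray() here before reporting success.
    final long length = buffer.size();
    final ShuffleLocation location = new ExampleShuffleLocation(targetNode);
    return new CommittedPartition() {
      @Override
      public long length() {
        return length;
      }

      @Override
      public Optional<ShuffleLocation> shuffleLocation() {
        return Optional.of(location);
      }
    };
  }

  @Override
  public void abort(Exception failureReason) {
    // Drop the buffered bytes; nothing has been shipped yet in this sketch.
    buffer.reset();
  }
}
```

On the read side, a location committed this way would flow back through MapStatus.shuffleLocationForBlock and MapOutputTracker.getShuffleLocation, which is how the ExternalShuffleReadSupport changes above pick the host and port for the partition reader.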