diff --git a/.github/workflows/sparkucx-ci.yml b/.github/workflows/sparkucx-ci.yml
index 3f627287..4bf27cae 100755
--- a/.github/workflows/sparkucx-ci.yml
+++ b/.github/workflows/sparkucx-ci.yml
@@ -9,7 +9,7 @@ jobs:
build-sparkucx:
strategy:
matrix:
- spark_version: ["2.1", "2.4", "3.0"]
+ spark_version: ["2.1", "2.4", "3.0", "3.1"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
diff --git a/.github/workflows/sparkucx-release.yml b/.github/workflows/sparkucx-release.yml
index cfa93c58..842a3832 100644
--- a/.github/workflows/sparkucx-release.yml
+++ b/.github/workflows/sparkucx-release.yml
@@ -13,7 +13,7 @@ jobs:
release:
strategy:
matrix:
- spark_version: ["2.1", "2.4", "3.0"]
+ spark_version: ["2.1", "2.4", "3.0", "3.1"]
runs-on: ubuntu-latest
steps:
- name: Checkout code
diff --git a/README.md b/README.md
index b9cb6da3..5e61196f 100755
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@ that are supported by [UCX](https://github.com/openucx/ucx#supported-transports)
This open-source project is developed, maintained and supported by the [UCF consortium](http://www.ucfconsortium.org/).
## Runtime requirements
-* Apache Spark 2.3/2.4/3.0
+* Apache Spark 2.3/2.4/3.0/3.1
* Java 8+
* Installed UCX of version 1.10+, and [UCX supported transport hardware](https://github.com/openucx/ucx#supported-transports).
@@ -34,9 +34,9 @@ to spark (e.g. in $SPARK_HOME/conf/spark-defaults.conf):
```
spark.shuffle.manager org.apache.spark.shuffle.UcxShuffleManager
```
-For spark-3.0 version add SparkUCX ShuffleIO plugin:
+For Spark 3.0 or Spark 3.1, also add the SparkUCX ShuffleIO plugin, using the `spark_3_0` or `spark_3_1` package that matches your Spark version:
```
-spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO
+spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_(3_0|3_1).UcxLocalDiskShuffleDataIO
```
### Build
diff --git a/pom.xml b/pom.xml
index b23b54ac..0bb569ff 100755
--- a/pom.xml
+++ b/pom.xml
@@ -43,6 +43,7 @@ See file LICENSE for terms.
maven-compiler-plugin
+ **/spark_3_1/**
**/spark_3_0/**
**/spark_2_4/**
@@ -53,6 +54,7 @@ See file LICENSE for terms.
scala-maven-plugin
+ **/spark_3_1/**
**/spark_3_0/**
**/spark_2_4/**
@@ -62,7 +64,7 @@ See file LICENSE for terms.
2.1.0
- **/spark_3_0/**, **/spark_2_4/**
+ **/spark_3_1/**, **/spark_3_0/**, **/spark_2_4/**
2.11.12
2.11
@@ -76,6 +78,7 @@ See file LICENSE for terms.
maven-compiler-plugin
+ **/spark_3_1/**
**/spark_3_0/**
**/spark_2_1/**
@@ -86,6 +89,7 @@ See file LICENSE for terms.
scala-maven-plugin
+ **/spark_3_1/**
**/spark_2_1/**
**/spark_3_0/**
@@ -95,13 +99,48 @@ See file LICENSE for terms.
2.4.0
- **/spark_3_0/**, **/spark_2_1/**
+ **/spark_3_1/**, **/spark_3_0/**, **/spark_2_1/**
2.11.12
2.11
spark-3.0
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+
+ **/spark_3_1/**
+ **/spark_2_1/**
+ **/spark_2_4/**
+
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+
+
+ **/spark_3_1/**
+ **/spark_2_1/**
+ **/spark_2_4/**
+
+
+
+
+
+
+ 3.0.1
+ 2.12.10
+ 2.12
+ **/spark_3_1/**, **/spark_2_1/**, **/spark_2_4/**
+
+
+
+ spark-3.1
true
@@ -112,6 +151,7 @@ See file LICENSE for terms.
maven-compiler-plugin
+ **/spark_3_0/**
**/spark_2_1/**
**/spark_2_4/**
@@ -122,6 +162,7 @@ See file LICENSE for terms.
scala-maven-plugin
+ **/spark_3_0/**
**/spark_2_1/**
**/spark_2_4/**
@@ -130,10 +171,10 @@ See file LICENSE for terms.
- 3.0.1
+ 3.1.2
2.12.10
2.12
- **/spark_2_1/**, **/spark_2_4/**
+ **/spark_3_0/**, **/spark_2_1/**, **/spark_2_4/**
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/OnOffsetsFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/OnOffsetsFetchCallback.java
new file mode 100755
index 00000000..14612111
--- /dev/null
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/OnOffsetsFetchCallback.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
+package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_1;
+
+import org.apache.spark.network.shuffle.BlockFetchingListener;
+import org.apache.spark.shuffle.UcxWorkerWrapper;
+import org.apache.spark.shuffle.ucx.UnsafeUtils;
+import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
+import org.apache.spark.shuffle.ucx.reducer.ReducerCallback;
+import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.ShuffleBlockBatchId;
+import org.apache.spark.storage.ShuffleBlockId;
+import org.openucx.jucx.UcxUtils;
+import org.openucx.jucx.ucp.UcpEndpoint;
+import org.openucx.jucx.ucp.UcpRemoteKey;
+import org.openucx.jucx.ucp.UcpRequest;
+
+import java.nio.ByteBuffer;
+import java.util.Map;
+
+/**
+ * Reducer callback invoked once the index (offset) entries of all requested blocks have been
+ * fetched: it derives each block's size, then issues the data fetches into one pooled buffer.
+ */
+public class OnOffsetsFetchCallback extends ReducerCallback {
+    private final RegisteredMemory offsetMemory;
+    private final long[] dataAddresses;
+    private final Map<Integer, UcpRemoteKey> dataRkeysCache;
+    private final Map<Long, Integer> mapId2PartitionId;
+
+    /** Stores the fetched-offsets buffer, per-block remote data addresses, rkey cache and mapId mapping. */
+    public OnOffsetsFetchCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
+                                  RegisteredMemory offsetMemory, long[] dataAddresses,
+                                  Map<Integer, UcpRemoteKey> dataRkeysCache,
+                                  Map<Long, Integer> mapId2PartitionId) {
+        super(blockIds, endpoint, listener);
+        this.offsetMemory = offsetMemory;
+        this.dataAddresses = dataAddresses;
+        this.dataRkeysCache = dataRkeysCache;
+        this.mapId2PartitionId = mapId2PartitionId;
+    }
+
+    /** Parses the fetched offsets, advances {@code dataAddresses} in place and submits the data reads. */
+    @Override
+    public void onSuccess(UcpRequest request) {
+        ByteBuffer resultOffset = offsetMemory.getBuffer();
+        long totalSize = 0;
+        int[] sizes = new int[blockIds.length];
+        int offset = 0; // index of the current (start, end) entry pair in the offsets buffer
+        int offsetSize = UnsafeUtils.LONG_SIZE;
+        for (int i = 0; i < blockIds.length; i++) {
+            // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
+            long blockOffset = resultOffset.getLong(offset * 2 * offsetSize);
+            long blockLength;
+            if (blockIds[i] instanceof ShuffleBlockBatchId) {
+                ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blockIds[i];
+                int blocksInBatch = blockBatchId.endReduceId() - blockBatchId.startReduceId();
+                // The end offset of a batch lies blocksInBatch entries after its start offset.
+                blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize * blocksInBatch) - blockOffset;
+                offset += blocksInBatch;
+            } else {
+                blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize) - blockOffset;
+                offset++;
+            }
+
+            assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
+            sizes[i] = (int) blockLength;
+            totalSize += blockLength;
+            dataAddresses[i] += blockOffset;
+        }
+
+        assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
+        mempool.put(offsetMemory); // offsets are fully consumed; hand the buffer back to mempool
+        RegisteredMemory blocksMemory = mempool.get((int) totalSize);
+
+        offset = 0; // from here on: byte position inside blocksMemory
+        // Submits N fetch blocks requests
+        for (int i = 0; i < blockIds.length; i++) {
+            int mapPartitionId = (blockIds[i] instanceof ShuffleBlockId) ?
+                mapId2PartitionId.get(((ShuffleBlockId) blockIds[i]).mapId()) :
+                mapId2PartitionId.get(((ShuffleBlockBatchId) blockIds[i]).mapId());
+            endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(mapPartitionId),
+                UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
+            offset += sizes[i];
+        }
+
+        // Process blocks once fetched: flush guarantees the callback runs after all fetch requests complete.
+        endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
+    }
+}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/UcxShuffleClient.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/UcxShuffleClient.java
new file mode 100755
index 00000000..c83cc3e1
--- /dev/null
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/UcxShuffleClient.java
@@ -0,0 +1,136 @@
+/*
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
+package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_1;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.executor.TempShuffleReadMetrics;
+import org.apache.spark.network.shuffle.BlockFetchingListener;
+import org.apache.spark.network.shuffle.BlockStoreClient;
+import org.apache.spark.network.shuffle.DownloadFileManager;
+import org.apache.spark.shuffle.DriverMetadata;
+import org.apache.spark.shuffle.UcxShuffleManager;
+import org.apache.spark.shuffle.UcxWorkerWrapper;
+import org.apache.spark.shuffle.ucx.UnsafeUtils;
+import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
+import org.apache.spark.storage.*;
+import org.openucx.jucx.UcxUtils;
+import org.openucx.jucx.ucp.UcpEndpoint;
+import org.openucx.jucx.ucp.UcpRemoteKey;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.Option;
+
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class UcxShuffleClient extends BlockStoreClient {
+ private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
+ private final UcxWorkerWrapper workerWrapper;
+ private final Map mapId2PartitionId;
+ private final TempShuffleReadMetrics shuffleReadMetrics;
+ private final int shuffleId;
+ final HashMap offsetRkeysCache = new HashMap<>();
+ final HashMap dataRkeysCache = new HashMap<>();
+
+
+ public UcxShuffleClient(int shuffleId, UcxWorkerWrapper workerWrapper,
+ Map mapId2PartitionId, TempShuffleReadMetrics shuffleReadMetrics) {
+ this.workerWrapper = workerWrapper;
+ this.shuffleId = shuffleId;
+ this.mapId2PartitionId = mapId2PartitionId;
+ this.shuffleReadMetrics = shuffleReadMetrics;
+ }
+
+ /**
+ * Submits n non blocking fetch offsets to get needed offsets for n blocks.
+ */
+ private void submitFetchOffsets(UcpEndpoint endpoint, BlockId[] blockIds,
+ RegisteredMemory offsetMemory,
+ long[] dataAddresses) {
+ DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(shuffleId);
+ long offset = 0;
+ int startReduceId;
+ long size;
+
+ for (int i = 0; i < blockIds.length; i++) {
+ BlockId blockId = blockIds[i];
+ int mapIdpartition;
+
+ if (blockId instanceof ShuffleBlockId) {
+ ShuffleBlockId shuffleBlockId = (ShuffleBlockId) blockId;
+ mapIdpartition = mapId2PartitionId.get(shuffleBlockId.mapId());
+ size = 2L * UnsafeUtils.LONG_SIZE;
+ startReduceId = shuffleBlockId.reduceId();
+ } else {
+ ShuffleBlockBatchId shuffleBlockBatchId = (ShuffleBlockBatchId) blockId;
+ mapIdpartition = mapId2PartitionId.get(shuffleBlockBatchId.mapId());
+ size = (shuffleBlockBatchId.endReduceId() - shuffleBlockBatchId.startReduceId())
+ * 2L * UnsafeUtils.LONG_SIZE;
+ startReduceId = shuffleBlockBatchId.startReduceId();
+ }
+
+ long offsetAddress = driverMetadata.offsetAddress(mapIdpartition);
+ dataAddresses[i] = driverMetadata.dataAddress(mapIdpartition);
+
+ offsetRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
+ endpoint.unpackRemoteKey(driverMetadata.offsetRkey(mapIdpartition)));
+
+ dataRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
+ endpoint.unpackRemoteKey(driverMetadata.dataRkey(mapIdpartition)));
+
+ endpoint.getNonBlockingImplicit(
+ offsetAddress + startReduceId * UnsafeUtils.LONG_SIZE,
+ offsetRkeysCache.get(mapIdpartition),
+ UcxUtils.getAddress(offsetMemory.getBuffer()) + offset,
+ size);
+
+ offset += size;
+ }
+ }
+
+ @Override
+ public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener,
+ DownloadFileManager downloadFileManager) {
+ long startTime = System.currentTimeMillis();
+ BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
+ UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
+ long[] dataAddresses = new long[blockIds.length];
+ int totalBlocks = 0;
+
+ BlockId[] blocks = new BlockId[blockIds.length];
+
+ for (int i = 0; i < blockIds.length; i++) {
+ blocks[i] = BlockId.apply(blockIds[i]);
+ if (blocks[i] instanceof ShuffleBlockId) {
+ totalBlocks += 1;
+ } else {
+ ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId)blocks[i];
+ totalBlocks += (blockBatchId.endReduceId() - blockBatchId.startReduceId());
+ }
+ }
+
+ RegisteredMemory offsetMemory = ((UcxShuffleManager)SparkEnv.get().shuffleManager())
+ .ucxNode().getMemoryPool().get(totalBlocks * 2 * UnsafeUtils.LONG_SIZE);
+ // Submits N implicit get requests without callback
+ submitFetchOffsets(endpoint, blocks, offsetMemory, dataAddresses);
+
+ // flush guarantees that all that requests completes when callback is called.
+ // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
+ workerWrapper.worker().flushNonBlocking(
+ new OnOffsetsFetchCallback(blocks, endpoint, listener, offsetMemory,
+ dataAddresses, dataRkeysCache, mapId2PartitionId));
+
+ shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
+ }
+
+ @Override
+ public void close() {
+ offsetRkeysCache.values().forEach(UcpRemoteKey::close);
+ dataRkeysCache.values().forEach(UcpRemoteKey::close);
+ logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
+ }
+
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleDataIO.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleDataIO.scala
new file mode 100755
index 00000000..47c6e448
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleDataIO.scala
@@ -0,0 +1,20 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.compat.spark_3_1
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO
+
+/**
+ * Ucx local disk IO plugin to handle logic of writing to local disk and shuffle memory registration.
+ */
+case class UcxLocalDiskShuffleDataIO(sparkConf: SparkConf) extends LocalDiskShuffleDataIO(sparkConf) with Logging {
+ /** Returns new UCX executor components built from the same SparkConf this plugin was created with. */
+ override def executor(): ShuffleExecutorComponents = {
+ new UcxLocalDiskShuffleExecutorComponents(sparkConf)
+ }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleExecutorComponents.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleExecutorComponents.scala
new file mode 100755
index 00000000..088377de
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleExecutorComponents.scala
@@ -0,0 +1,47 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.compat.spark_3_1
+
+import java.util
+import java.util.Optional
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter}
+import org.apache.spark.shuffle.UcxShuffleManager
+import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter}
+
+/**
+ * Entry point to UCX executor: creates Spark's local-disk map output writers on top of the UCX block resolver.
+ */
+class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf)
+ extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging{
+ // Resolver taken from the UcxShuffleManager in initializeExecutor; null until that method has run.
+ private var blockResolver: UcxShuffleBlockResolver = _
+ /** Calls startUcxNodeIfMissing() on the active UcxShuffleManager and picks up its shuffleBlockResolver. */
+ override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = {
+ val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
+ ucxShuffleManager.startUcxNodeIfMissing()
+ blockResolver = ucxShuffleManager.shuffleBlockResolver
+ }
+ /** Creates a LocalDiskShuffleMapOutputWriter using the UCX resolver; throws IllegalStateException before initializeExecutor. */
+ override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = {
+ if (blockResolver == null) {
+ throw new IllegalStateException(
+ "Executor components must be initialized before getting writers.")
+ }
+ new LocalDiskShuffleMapOutputWriter(
+ shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf)
+ }
+ /** Creates a LocalDiskSingleSpillMapOutputWriter using the UCX resolver; throws IllegalStateException before initializeExecutor. */
+ override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): Optional[SingleSpillShuffleMapOutputWriter] = {
+ if (blockResolver == null) {
+ throw new IllegalStateException(
+ "Executor components must be initialized before getting writers.")
+ }
+ Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver))
+ }
+
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleBlockResolver.scala
new file mode 100755
index 00000000..1fa8c912
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleBlockResolver.scala
@@ -0,0 +1,52 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.compat.spark_3_1
+
+import java.io.{File, RandomAccessFile}
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.network.shuffle.ExecutorDiskUtils
+import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager}
+import org.apache.spark.storage.ShuffleIndexBlockId
+
+/**
+ * Mapper entry point for UcxShuffle plugin. Performs memory registration
+ * of data and index files and publishes addresses to driver metadata buffer.
+ */
+class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
+ extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
+ /** Resolves the shuffle index file: via ExecutorDiskUtils under `dirs` when given, else through the disk block manager. */
+ override def getIndexFile(
+ shuffleId: Int,
+ mapId: Long,
+ dirs: Option[Array[String]] = None): File = {
+ val blockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
+ val blockManager = SparkEnv.get.blockManager
+ dirs
+ .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name))
+ .getOrElse(blockManager.diskBlockManager.getFile(blockId))
+ }
+ /** Commits through the parent resolver, then passes the opened data and index files to writeIndexFileAndCommitCommon. */
+ override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long,
+ lengths: Array[Long], dataTmp: File): Unit = {
+ super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
+ // Since Spark 3.0 the mapId is a long that is unique among all jobs in spark, so the task's
+ // partitionId is used instead as the offset in the metadata buffer.
+ val partitionId = TaskContext.getPartitionId()
+ val dataFile = getDataFile(shuffleId, mapId)
+ val dataBackFile = new RandomAccessFile(dataFile, "rw")
+ // Empty data file: close it and skip the writeIndexFileAndCommitCommon step entirely.
+ if (dataBackFile.length() == 0) {
+ dataBackFile.close()
+ return
+ }
+
+ val indexFile = getIndexFile(shuffleId, mapId)
+ val indexBackFile = new RandomAccessFile(indexFile, "rw")
+ // Neither file is closed here; both open handles are handed over to the common implementation.
+ writeIndexFileAndCommitCommon(shuffleId, partitionId, lengths, dataTmp, indexBackFile, dataBackFile)
+ }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleManager.scala
new file mode 100755
index 00000000..64a55726
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleManager.scala
@@ -0,0 +1,75 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.compat.spark_3_1.{UcxShuffleBlockResolver, UcxShuffleReader}
+import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, SortShuffleWriter, UnsafeShuffleWriter}
+import org.apache.spark.util.ShutdownHookManager
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
+
+/**
+ * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin
+ * and injects needed logic in override methods.
+ */
+class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) {
+ ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop)
+ private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) // created on first use in getWriter
+
+ override val shuffleBlockResolver = new UcxShuffleBlockResolver(this)
+ /** Driver side: registers the shuffle with the parent manager and passes the base handle to registerShuffleCommon. */
+ override def registerShuffle[K, V, C](shuffleId: ShuffleId, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ assume(isDriver)
+ val numMaps = dependency.partitioner.numPartitions // NOTE(review): partitioner (reduce-side) partition count, not the map task count - confirm this is what registerShuffleCommon expects
+ val baseHandle = super.registerShuffle(shuffleId, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]]
+ registerShuffleCommon(baseHandle, shuffleId, numMaps)
+ }
+ /** Remembers the handle for its shuffleId and builds an UnsafeShuffleWriter or SortShuffleWriter depending on the base handle type. */
+ override def getWriter[K, V](handle: ShuffleHandle, mapId: Long, context: TaskContext,
+ metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+ shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, V, _]])
+ val env = SparkEnv.get
+ handle.asInstanceOf[UcxShuffleHandle[K, V, _]].baseHandle match {
+ case unsafeShuffleHandle: SerializedShuffleHandle[K@unchecked, V@unchecked] =>
+ new UnsafeShuffleWriter(
+ env.blockManager,
+ context.taskMemoryManager(),
+ unsafeShuffleHandle,
+ mapId,
+ context,
+ env.conf,
+ metrics,
+ shuffleExecutorComponents)
+ case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] =>
+ new SortShuffleWriter(
+ shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents)
+ }
+ }
+ /** Calls startUcxNodeIfMissing(), remembers the handle and returns a UcxShuffleReader with batch fetch requested. */
+ override def getReader[K, C](handle: ShuffleHandle, startMapIndex: Int, endMapIndex: Int,
+ startPartition: MapId, endPartition: MapId, context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+
+ startUcxNodeIfMissing()
+ shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, _, C]])
+ new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startMapIndex, endMapIndex, startPartition, endPartition,
+ context, readMetrics = metrics, shouldBatchFetch = true)
+ }
+
+ /** Loads the configured ShuffleDataIO plugin and initializes its executor components with app id, executor id and the prefixed shuffle configs. */
+ private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
+ val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
+ val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX)
+ .toMap
+ executorComponents.initializeExecutor(
+ conf.getAppId,
+ SparkEnv.get.executorId,
+ extraConfigs.asJava)
+ executorComponents
+ }
+
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleReader.scala
new file mode 100755
index 00000000..6ca31966
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleReader.scala
@@ -0,0 +1,191 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.compat.spark_3_1
+
+import java.io.InputStream
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.shuffle.ucx.reducer.compat.spark_3_1.UcxShuffleClient
+import org.apache.spark.shuffle.{ShuffleReadMetricsReporter, ShuffleReader, UcxShuffleHandle, UcxShuffleManager}
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockBatchId, ShuffleBlockFetcherIterator, ShuffleBlockId}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.{InterruptibleIterator, SparkEnv, SparkException, TaskContext}
+
+
+/**
+ * Extension of Spark's shuffle reader that injects UcxShuffleClient as the block store client
+ * and progresses the UCX worker lazily, only when the fetcher's result queue is empty.
+ */
+class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C],
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ readMetrics: ShuffleReadMetricsReporter,
+ shouldBatchFetch: Boolean = false) extends ShuffleReader[K, C] with Logging {
+ // Shuffle dependency of the wrapped base handle; supplies serializer, aggregator and key ordering below.
+ private val dep = handle.baseHandle.dependency
+
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ val (blocksByAddressIterator1, blocksByAddressIterator2) = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
+ handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition).duplicate // two copies: one for the fetcher, one for the map below
+ val mapIdToBlockIndex = blocksByAddressIterator2.flatMap{
+ case (_, blocks) => blocks.map {
+ case (blockId, _, mapIdx) => blockId match {
+ case x: ShuffleBlockId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer])
+ case x: ShuffleBlockBatchId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer])
+ case _ => throw new SparkException("Unknown block")
+ }
+ }
+ }.toMap
+ // Worker obtained via getThreadLocalWorker; used by the client and by the progress call in next() below.
+ val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
+ .ucxNode.getThreadLocalWorker
+ val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
+ val shuffleClient = new UcxShuffleClient(handle.shuffleId, workerWrapper, mapIdToBlockIndex.asJava, shuffleMetrics)
+ val shuffleIterator = new ShuffleBlockFetcherIterator(
+ context,
+ shuffleClient,
+ blockManager,
+ blocksByAddressIterator1,
+ serializerManager.wrapStream,
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
+ SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+ SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
+ readMetrics,
+ fetchContinuousBlocksInBatch)
+
+ val wrappedStreams = shuffleIterator.toCompletionIterator
+
+ // Ucx shuffle logic
+ // Java reflection to get access to private results queue
+ val queueField = shuffleIterator.getClass.getDeclaredField(
+ "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
+ queueField.setAccessible(true)
+ val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]]
+
+ // Do progress if queue is empty before calling next on ShuffleIterator
+ val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
+ override def next(): (BlockId, InputStream) = {
+ val startTime = System.currentTimeMillis()
+ workerWrapper.fillQueueWithBlocks(resultQueue)
+ readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
+ wrappedStreams.next()
+ }
+ // Closes the UCX client as soon as the wrapped iterator reports that no elements are left.
+ override def hasNext: Boolean = {
+ val result = wrappedStreams.hasNext
+ if (!result) {
+ shuffleClient.close() // NOTE(review): runs on every hasNext call after exhaustion - confirm close() tolerates repeated calls
+ }
+ result
+ }
+ }
+ // End of ucx shuffle logic
+
+ val serializerInstance = dep.serializer.newInstance()
+
+ // Create a key/value iterator for each stream
+ val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
+ // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+ // NextIterator. The NextIterator makes sure that close() is called on the
+ // underlying InputStream when all records have been read.
+ serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+ }
+
+ // Update the context task metrics for each record read.
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+ recordIter.map { record =>
+ readMetrics.incRecordsRead(1)
+ record
+ },
+ context.taskMetrics().mergeShuffleReadMetrics())
+
+ // An interruptible iterator must be used here in order to support task cancellation
+ val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
+
+ val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
+ if (dep.mapSideCombine) {
+ // We are reading values that are already combined
+ val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+ dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
+ } else {
+ // We don't know the value type, but also don't care -- the dependency *should*
+ // have made sure its compatible w/ this aggregator, which will convert the value
+ // type to the combined type C
+ val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+ dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
+ }
+ } else {
+ interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
+ }
+
+ // Sort the output if there is a sort ordering defined.
+ val resultIter = dep.keyOrdering match {
+ case Some(keyOrd: Ordering[K]) =>
+ // Create an ExternalSorter to sort the data.
+ val sorter =
+ new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
+ sorter.insertAll(aggregatedIter)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+ // Use completion callback to stop sorter if task was finished/cancelled.
+ context.addTaskCompletionListener[Unit](_ => {
+ sorter.stop()
+ })
+ CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
+ case None =>
+ aggregatedIter
+ }
+
+ resultIter match {
+ case _: InterruptibleIterator[Product2[K, C]] => resultIter
+ case _ =>
+ // Use another interruptible iterator here to support task cancellation as aggregator
+ // or(and) sorter may have consumed previous interruptible iterator.
+ new InterruptibleIterator[Product2[K, C]](context, resultIter)
+ }
+ }
+ /** True when batch fetch was requested, the serializer is relocatable, compression is off or the codec supports concatenation, and the old fetch protocol is not used; warns when requested but unavailable. */
+ private def fetchContinuousBlocksInBatch: Boolean = {
+ val conf = SparkEnv.get.conf
+ val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
+ val compressed = conf.get(config.SHUFFLE_COMPRESS)
+ val codecConcatenation = if (compressed) {
+ CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf))
+ } else {
+ true
+ }
+ val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)
+
+ val doBatchFetch = shouldBatchFetch && serializerRelocatable &&
+ (!compressed || codecConcatenation) && !useOldFetchProtocol
+ if (shouldBatchFetch && !doBatchFetch) {
+ logWarning("The feature tag of continuous shuffle block fetching is set to true, but " +
+ "we can not enable the feature because other conditions are not satisfied. " +
+ s"Shuffle compress: $compressed, serializer ${dep.serializer.getClass.getName} " +
+ s"relocatable: $serializerRelocatable, " +
+ s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " +
+ s"$useOldFetchProtocol.")
+ }
+ doBatchFetch
+ }
+
+}