diff --git a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala index 80f8e44e8816..4e17fcc56fc6 100644 --- a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala +++ b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala @@ -317,7 +317,7 @@ private class GlutenOptimizedWriterShuffleReader( case _ => SparkEnv.get.serializerManager } - val wrappedStreams = new ShuffleBlockFetcherIterator( + val wrappedStreams = new GlutenShuffleBlockFetcherIterator( context, SparkEnv.get.blockManager.blockStoreClient, SparkEnv.get.blockManager, diff --git a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala index 3b5fce63f8c2..099c8ea36707 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala @@ -239,7 +239,6 @@ private class ColumnarBatchSerializerInstanceImpl( } numOutputRows += numRowsTotal wrappedOut.close() - streamReader.close() if (cb != null) { cb.close() } diff --git a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala index 1e514cf9f1db..e59166aae649 100644 --- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala +++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala @@ -22,7 +22,7 @@ import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, GlutenShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator /** @@ -70,7 +70,7 @@ class ColumnarShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val wrappedStreams = new ShuffleBlockFetcherIterator( + val wrappedStreams = new GlutenShuffleBlockFetcherIterator( context, blockManager.blockStoreClient, blockManager, diff --git a/backends-velox/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/backends-velox/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala new file mode 100644 index 000000000000..d29fc48dd82b --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.MapOutputTracker +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.internal.Logging +import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER +import org.apache.spark.storage.GlutenShuffleBlockFetcherIterator._ + +import org.roaringbitmap.RoaringBitmap + +import java.util.concurrent.TimeUnit + +import scala.collection +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Success} + +/** + * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based + * functionality to fetch push-merged block meta and shuffle chunks. A push-merged block contains + * multiple shuffle chunks where each shuffle chunk contains multiple shuffle blocks that belong to + * the common reduce partition and were merged by the external shuffle service to that chunk. + */ +private class GlutenPushBasedFetchHelper( + private val iterator: GlutenShuffleBlockFetcherIterator, + private val shuffleClient: BlockStoreClient, + private val blockManager: BlockManager, + private val mapOutputTracker: MapOutputTracker, + private val shuffleMetrics: ShuffleReadMetricsReporter) + extends Logging { + + private[this] val startTimeNs = System.nanoTime() + + private[storage] val localShuffleMergerBlockMgrId = BlockManagerId( + SHUFFLE_MERGER_IDENTIFIER, + blockManager.blockManagerId.host, + blockManager.blockManagerId.port, + blockManager.blockManagerId.topologyInfo) + + /** A map for storing shuffle chunk bitmap. */ + private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() + + /** Returns true if the address is for a push-merged block. */ + def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { + SHUFFLE_MERGER_IDENTIFIER == address.executorId + } + + /** Returns true if the address is of a remote push-merged block. false otherwise. */ + def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host + } + + /** Returns true if the address is of a push-merged-local block. false otherwise. */ + def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def removeChunk(blockId: ShuffleBlockChunkId): Unit = { + chunksMetaMap.remove(blockId) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = { + chunksMetaMap(blockId) = chunkMeta + } + + /** + * Get the RoaringBitMap for a specific ShuffleBlockChunkId + * + * @param blockId + * shuffle chunk id. + */ + def getRoaringBitMap(blockId: ShuffleBlockChunkId): Option[RoaringBitmap] = { + chunksMetaMap.get(blockId) + } + + /** + * Get the number of map blocks in a ShuffleBlockChunk + * @param blockId + * @return + */ + def getShuffleChunkCardinality(blockId: ShuffleBlockChunkId): Int = { + getRoaringBitMap(blockId).map(_.getCardinality).getOrElse(0) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]]. + * + * @param shuffleId + * shuffle id. + * @param reduceId + * reduce id. + * @param blockSize + * size of the push-merged block. + * @param bitmaps + * chunk bitmaps, where each bitmap contains all the mapIds that were merged to that chunk. + * @return + * shuffle chunks to fetch. + */ + def createChunkBlockInfosFromMetaResponse( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { + val approxChunkSize = blockSize / bitmaps.length + val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]() + for (i <- bitmaps.indices) { + val blockChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, i) + chunksMetaMap.put(blockChunkId, bitmaps(i)) + logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") + blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID)) + } + blocksToFetch + } + + /** + * This is executed by the task thread when the iterator is initialized and only if it has + * push-merged blocks for which it needs to fetch the metadata. + * + * @param req + * [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch metadata of + * push-merged blocks. + */ + def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { + val sizeMap = req.blocks.map { + case FetchBlockInfo(blockId, size, _) => + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size) + }.toMap + val address = req.address + val mergedBlocksMetaListener = new MergedBlocksMetaListener { + override def onSuccess( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + meta: MergedBlockMeta): Unit = { + logDebug( + s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," + + s" $reduceId) from ${req.address.host}:${req.address.port}") + try { + iterator.addToResultsQueue( + PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + sizeMap((shuffleId, reduceId)), + meta.readChunkBitmaps(), + address)) + } catch { + case exception: Exception => + logError( + s"Failed to parse the meta of push-merged block for ($shuffleId, " + + s"$shuffleMergeId, $reduceId) from" + + s" ${req.address.host}:${req.address.port}", + exception + ) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + + override def onFailure( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + exception: Throwable): Unit = { + logError( + s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}", + exception) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + req.blocks.foreach { + block => + val shuffleBlockId = block.blockId.asInstanceOf[ShuffleMergedBlockId] + shuffleClient.getMergedBlockMeta( + address.host, + address.port, + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + mergedBlocksMetaListener) + } + } + + /** + * This is executed by the task thread when the iterator is initialized. It fetches all the + * outstanding push-merged local blocks. + * @param pushMergedLocalBlocks + * set of identified merged local blocks and their sizes. + */ + def fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + if (pushMergedLocalBlocks.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks)) + } + } + + /** + * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged + * local blocks. + */ + private def fetchPushMergedLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + val cachedPushedMergedDirs = + hostLocalDirManager.getCachedHostLocalDirsFor(SHUFFLE_MERGER_IDENTIFIER) + if (cachedPushedMergedDirs.isDefined) { + logDebug( + s"Fetch the push-merged-local blocks with cached merged dirs: " + + s"${cachedPushedMergedDirs.get.mkString(", ")}") + pushMergedLocalBlocks.foreach { + blockId => + fetchPushMergedLocalBlock( + blockId, + cachedPushedMergedDirs.get, + localShuffleMergerBlockMgrId) + } + } else { + // Push-based shuffle is only enabled when the external shuffle service is enabled. If the + // external shuffle service is not enabled, then there will not be any push-merged blocks + // for the iterator to fetch. + logDebug( + s"Asynchronous fetch the push-merged-local blocks without cached merged " + + s"dirs from the external shuffle service") + hostLocalDirManager.getHostLocalDirs( + blockManager.blockManagerId.host, + blockManager.externalShuffleServicePort, + Array(SHUFFLE_MERGER_IDENTIFIER)) { + case Success(dirs) => + logDebug( + s"Fetched merged dirs in " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + pushMergedLocalBlocks.foreach { + blockId => + logDebug( + s"Successfully fetched local dirs: " + + s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") + fetchPushMergedLocalBlock( + blockId, + dirs(SHUFFLE_MERGER_IDENTIFIER), + localShuffleMergerBlockMgrId) + } + case Failure(throwable) => + // If we see an exception with getting the local dirs for push-merged-local blocks, + // we fallback to fetch the original blocks. We do not report block fetch failure. + logWarning( + s"Error while fetching the merged dirs for push-merged-local " + + s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", + throwable + ) + pushMergedLocalBlocks.foreach { + blockId => + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult( + blockId, + localShuffleMergerBlockMgrId, + 0, + isNetworkReqDone = false)) + } + } + } + } + + /** + * Fetch a single push-merged-local block generated. This can also be executed by the task thread + * as well as the netty thread. + * @param blockId + * ShuffleBlockId to be fetched + * @param localDirs + * Local directories where the push-merged shuffle files are stored + * @param blockManagerId + * BlockManagerId + */ + private[this] def fetchPushMergedLocalBlock( + blockId: BlockId, + localDirs: Array[String], + blockManagerId: BlockManagerId): Unit = { + try { + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) + iterator.addToResultsQueue( + PushMergedLocalMetaFetchResult( + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + chunksMeta.readChunkBitmaps(), + localDirs)) + } catch { + case e: Exception => + // If we see an exception with reading a push-merged-local meta, we fallback to + // fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while fetching push-merged-local meta, " + + s"prepare to fetch the original blocks", + e) + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) + } + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type: 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]] 2) + * [[ShuffleBlockFetcherIterator.FallbackOnPushMergedFailureResult]] 3) + * [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFailedFetchResult]] + * + * This initiates fetching fallback blocks for a push-merged block or a shuffle chunk that failed + * to fetch. It makes a call to the map output tracker to get the list of original blocks for the + * given push-merged block/shuffle chunk, split them into remote and local blocks, and process + * them accordingly. It also updates the numberOfBlocksToFetch in the iterator as it processes + * failed response and finds more push-merged requests to remote and again updates it with + * additional requests for original blocks. The fallback happens when: + * 1. There is an exception while creating shuffle chunks from push-merged-local shuffle block. + * See fetchLocalBlock. 2. There is a failure when fetching remote shuffle chunks. 3. There + * is a failure when processing SuccessFetchResult which is for a shuffle chunk (local or + * remote). 4. There is a zero-size buffer when processing SuccessFetchResult for a shuffle + * chunk (local or remote). + */ + def initiateFallbackFetchForPushMergedBlock(blockId: BlockId, address: BlockManagerId): Unit = { + assert(blockId.isInstanceOf[ShuffleMergedBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId]) + logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId") + shuffleMetrics.incMergedFetchFallbackCount(1) + // Increase the blocks processed since we will process another block in the next iteration of + // the while loop in ShuffleBlockFetcherIterator.next(). + val fallbackBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = + blockId match { + case shuffleBlockId: ShuffleMergedBlockId => + iterator.decreaseNumBlocksToFetch(1) + mapOutputTracker.getMapSizesForMergeResult( + shuffleBlockId.shuffleId, + shuffleBlockId.reduceId) + case _ => + val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] + val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).get + var blocksProcessed = 1 + // When there is a failure to fetch a remote shuffle chunk, then we try to + // fallback not only for that particular remote shuffle chunk but also for all the + // pending chunks that belong to the same host. The reason for doing so is that it + // is very likely that the subsequent requests for shuffle chunks from this host will + // fail as well. Since, push-based shuffle is best effort and we try not to increase the + // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the + // fetchRequests queue. + if (isRemotePushMergedBlockAddress(address)) { + // Fallback for all the pending fetch requests + val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) + pendingShuffleChunks.foreach { + pendingBlockId => + logInfo(s"Falling back immediately for shuffle chunk $pendingBlockId") + shuffleMetrics.incMergedFetchFallbackCount(1) + val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).get + chunkBitmap.or(bitmapOfPendingChunk) + } + // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed + blocksProcessed += pendingShuffleChunks.size + } + iterator.decreaseNumBlocksToFetch(blocksProcessed) + mapOutputTracker.getMapSizesForMergeResult( + shuffleChunkId.shuffleId, + shuffleChunkId.reduceId, + chunkBitmap) + } + iterator.fallbackFetch(fallbackBlocksByAddr) + } +} diff --git a/backends-velox/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/backends-velox/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala new file mode 100644 index 000000000000..6ca88ec1fe25 --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -0,0 +1,1862 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.{MapOutputTracker, SparkException, TaskContext} +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} +import org.apache.spark.network.util.{NettyUtils, TransportConf} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.util.{Clock, CompletionIterator, SystemClock, TaskCompletionListener, Utils} + +import io.netty.util.internal.OutOfDirectMemoryError +import org.apache.commons.io.IOUtils +import org.roaringbitmap.RoaringBitmap + +import javax.annotation.concurrent.GuardedBy + +import java.io.{InputStream, IOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean +import java.util.zip.CheckedInputStream + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.util.{Failure, Success} + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context + * [[TaskContext]], used for metrics update + * @param shuffleClient + * [[BlockStoreClient]] for fetching remote blocks + * @param blockManager + * [[BlockManager]] for reading local blocks + * @param blocksByAddress + * list of blocks to fetch grouped by the [[BlockManagerId]]. For each block we also require two + * info: 1. the size (in bytes as a long field) in order to throttle the memory usage; 2. the + * mapIndex for this block, which indicate the index in the map stage. Note that zero-sized blocks + * are already excluded, which happened in + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. + * @param mapOutputTracker + * [[MapOutputTracker]] for falling back to fetching the original blocks if we fail to fetch + * shuffle chunks when push based shuffle is enabled. + * @param streamWrapper + * A function to wrap the returned input stream. + * @param maxBytesInFlight + * max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight + * max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress + * max number of shuffle blocks being fetched at any given point for a given remote host:port. + * @param maxReqSizeShuffleToMem + * max size (in bytes) of a request that can be shuffled to memory. + * @param maxAttemptsOnNettyOOM + * The max number of a block could retry due to Netty OOM before throwing the fetch failure. + * @param detectCorrupt + * whether to detect any corruption in fetched blocks. + * @param checksumEnabled + * whether the shuffle checksum is enabled. When enabled, Spark will try to diagnose the cause of + * the block corruption. + * @param checksumAlgorithm + * the checksum algorithm that is used when calculating the checksum value for the block data. + * @param shuffleMetrics + * used to report shuffle metrics. + * @param doBatchFetch + * fetch continuous shuffle blocks from same executor in batch if the server side supports. + */ +final private[spark] class GlutenShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: BlockStoreClient, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + val maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, + detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean, + clock: Clock = new SystemClock()) + extends Iterator[(BlockId, InputStream)] + with DownloadFileManager + with Logging { + + import GlutenShuffleBlockFetcherIterator._ + + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L) + + /** Total number of blocks to fetch. */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTimeNs = System.nanoTime() + + /** Host local blocks to fetch, excluding zero-sized blocks. */ + private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + /** + * Current [[FetchResult]] being processed. We track this so we can release the current buffer in + * case of a runtime exception when processing the current buffer. + */ + private[this] val currentResults: ConcurrentHashMap[Long, SuccessFetchResult] = + new ConcurrentHashMap[Long, SuccessFetchResult]() + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that the + * number of bytes in flight is limited to maxBytesInFlight. + */ + private[this] val fetchRequests = new Queue[FetchRequest] + + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + + /** Current bytes in flight from our requests */ + private[this] var bytesInFlight = 0L + + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + + /** + * Count the retry times for the blocks due to Netty OOM. The block will stop retry if retry times + * has exceeded the [[maxAttemptsOnNettyOOM]]. + */ + private[this] val blockOOMRetryCounts = new HashMap[String, Int] + + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry at + * most once for those corrupted blocks. + */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @GuardedBy("this") + private[this] var isZombie = false + + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]() + + private[this] val onCompleteCallback = new GlutenShuffleFetchCompletionListener(this) + + private[this] val pushBasedFetchHelper = + new GlutenPushBasedFetchHelper( + this, + shuffleClient, + blockManager, + mapOutputTracker, + shuffleMetrics) + + initialize() + + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { + val threadId = Thread.currentThread().getId + // Release the current buffer if necessary + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + + override def createTempFile(transportConf: TransportConf): DownloadFile = { + // we never need to do any encryption or decryption here, regardless of configs, because that + // is handled at another layer in the code. When encryption is enabled, shuffle data is written + // to disk encrypted in the first place, and sent over the network still encrypted. + new SimpleDownloadFile(blockManager.diskBlockManager.createTempLocalBlock()._2, transportConf) + } + + override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + + /** Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ + private[storage] def cleanup(): Unit = { + synchronized { + isZombie = true + } + // Release all current result buffers from all threads + val threadIds = currentResults.keys() + while (threadIds.hasMoreElements) { + val threadId = threadIds.nextElement() + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) => + if (address != blockManager.blockManagerId) { + if ( + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) || + hostLocalBlocks.contains(blockId -> mapIndex) + ) { + shuffleMetricsUpdate(blockId, buf, local = true) + } else { + shuffleMetricsUpdate(blockId, buf, local = false) + } + } + buf.release() + case _ => + } + } + shuffleFilesSet.foreach { + file => + if (!file.delete()) { + logWarning("Failed to cleanup shuffle fetch temp file " + file.path()) + } + } + } + + private[this] def sendRequest(req: FetchRequest): Unit = { + logDebug( + "Sending request for %d blocks (%s) from %s" + .format(req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + bytesInFlight += req.size + reqsInFlight += 1 + + // so we can look up the block info of each blockID + val infoMap = req.blocks.map { + case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) + }.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val deferredBlocks = new ArrayBuffer[String]() + val blockIds = req.blocks.map(_.blockId.toString) + val address = req.address + val requestStartTime = clock.nanoTime() + + @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { + if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { + val blocks = deferredBlocks.map { + blockId => + val (size, mapIndex) = infoMap(blockId) + FetchBlockInfo(BlockId(blockId), size, mapIndex) + } + results.put(DeferFetchRequestResult(FetchRequest(address, blocks))) + deferredBlocks.clear() + } + } + + @inline def updateMergedReqsDuration(wasReqForMergedChunks: Boolean = false): Unit = { + if (remainingBlocks.isEmpty) { + val durationMs = TimeUnit.NANOSECONDS.toMillis(clock.nanoTime() - requestStartTime) + if (wasReqForMergedChunks) { + shuffleMetrics.incRemoteMergedReqsDuration(durationMs) + } + shuffleMetrics.incRemoteReqsDuration(durationMs) + } + } + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + GlutenShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + blockOOMRetryCounts.remove(blockId) + updateMergedReqsDuration(BlockId(blockId).isShuffleChunk) + results.put( + SuccessFetchResult( + BlockId(blockId), + infoMap(blockId)._2, + address, + infoMap(blockId)._1, + buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + enqueueDeferredFetchRequestIfNecessary() + } + } + logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") + } + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + GlutenShuffleBlockFetcherIterator.this.synchronized { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + e match { + // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among + // tasks) to true as early as possible. The pending fetch requests won't be sent + // afterwards until the flag is set to false on: + // 1) the Netty free memory >= maxReqSizeShuffleToMem + // - we'll check this whenever there's a fetch request succeeds. + // 2) the number of in-flight requests becomes 0 + // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. + // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag + // only takes effect for the shuffle due to the implementation simplicity concern. + // And we'll buffer the consecutive block failures caused by the OOM error until there's + // no remaining blocks in the current request. Then, we'll package these blocks into + // a same fetch request for the retry later. In this way, instead of creating the fetch + // request per block, it would help reduce the concurrent connections and data loads + // pressure at remote server. + // Note that catching OOM and do something based on it is only a workaround for + // handling the Netty OOM issue, which is not the best way towards memory management. + // We can get rid of it when we find a way to manage Netty's memory precisely. + case _: OutOfDirectMemoryError + if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => + if (!isZombie) { + val failureTimes = blockOOMRetryCounts(blockId) + blockOOMRetryCounts(blockId) += 1 + if (isNettyOOMOnShuffle.compareAndSet(false, true)) { + // The fetcher can fail remaining blocks in batch for the same error. So we only + // log the warning once to avoid flooding the logs. + logInfo( + s"Block $blockId has failed $failureTimes times " + + s"due to Netty OOM, will retry") + } + remainingBlocks -= blockId + deferredBlocks += blockId + enqueueDeferredFetchRequestIfNecessary() + } + + case _ => + val block = BlockId(blockId) + if (block.isShuffleChunk) { + remainingBlocks -= blockId + updateMergedReqsDuration(wasReqForMergedChunks = true) + results.put( + FallbackOnPushMergedFailureResult( + block, + address, + infoMap(blockId)._1, + remainingBlocks.isEmpty)) + } else { + results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) + } + } + } + } + } + + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. + if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this) + } else { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + null) + } + } + + /** + * This is called from initialize and also from the fallback which is triggered from + * [[PushBasedFetchHelper]]. + */ + private[this] def partitionBlocksByFetchMode( + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + localBlocks: mutable.LinkedHashSet[(BlockId, Int)], + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]], + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + logDebug( + s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") + + // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote) + // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight + // in order to limit the amount of data in flight + val collectedRemoteRequests = new ArrayBuffer[FetchRequest] + var localBlockBytes = 0L + var hostLocalBlockBytes = 0L + var numHostLocalBlocks = 0 + var pushMergedLocalBlockBytes = 0L + val prevNumBlocksToFetch = numBlocksToFetch + + val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + for ((address, blockInfos) <- blocksByAddress) { + checkBlockSizes(blockInfos) + if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { + // These are push-merged blocks or shuffle chunks of these blocks. + if (address.host == blockManager.blockManagerId.host) { + numBlocksToFetch += blockInfos.size + pushMergedLocalBlocks ++= blockInfos.map(_._1) + pushMergedLocalBlockBytes += blockInfos.map(_._2).sum + } else { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + } else if (localExecIds.contains(address.executorId)) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if ( + blockManager.hostLocalDirManager.isDefined && + address.host == blockManager.blockManagerId.host + ) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + val blocksForAddress = + mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) + hostLocalBlocksByExecutor += address -> blocksForAddress + numHostLocalBlocks += blocksForAddress.size + hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + val (_, timeCost) = Utils.timeTakenMs[Unit] { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + logDebug(s"Collected remote fetch requests for $address in $timeCost ms") + } + } + val (remoteBlockBytes, numRemoteBlocks) = + collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) + val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + + pushMergedLocalBlockBytes + val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch + assert( + blocksToFetchCurrentIteration == localBlocks.size + + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, + s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " + + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of host-local blocks $numHostLocalBlocks " + + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + + s"+ the number of remote blocks $numRemoteBlocks " + ) + logInfo( + s"Getting $blocksToFetchCurrentIteration " + + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + + s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"host-local and ${pushMergedLocalBlocks.size} " + + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + + s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"remote blocks") + this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values + .flatMap(infos => infos.map(info => (info._1, info._3))) + collectedRemoteRequests + } + + private def createFetchRequest( + blocks: collection.Seq[FetchBlockInfo], + address: BlockManagerId, + forMergedMetas: Boolean): FetchRequest = { + logDebug( + s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + + s"with ${blocks.size} blocks") + FetchRequest(address, blocks, forMergedMetas) + } + + private def createFetchRequests( + curBlocks: collection.Seq[FetchBlockInfo], + address: BlockManagerId, + isLast: Boolean, + collectedRemoteRequests: ArrayBuffer[FetchRequest], + enableBatchFetch: Boolean, + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) + numBlocksToFetch += mergedBlocks.size + val retBlocks = new ArrayBuffer[FetchBlockInfo] + if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) + } else { + mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { + blocks => + if (blocks.length == maxBlocksInFlightPerAddress || isLast) { + collectedRemoteRequests += createFetchRequest(blocks, address, forMergedMetas) + } else { + // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back + // to `curBlocks`. + retBlocks ++= blocks + numBlocksToFetch -= blocks.size + } + } + } + retBlocks + } + + private def collectFetchRequests( + address: BlockManagerId, + blockInfos: collection.Seq[(BlockId, Long, Int)], + collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[FetchBlockInfo]() + + while (iterator.hasNext) { + val (blockId, size, mapIndex) = iterator.next() + curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curRequestSize += size + blockId match { + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. + // Based on these types, we decide to do batch fetch and create FetchRequests with + // forMergedMetas set. + case ShuffleBlockChunkId(_, _, _, _) => + if ( + curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress + ) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false) + curRequestSize = curBlocks.map(_.size).sum + } + case ShuffleMergedBlockId(_, _, _) => + if (curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false, + forMergedMetas = true) + } + case _ => + // For batch fetch, the actual block in flight should count for merged block. + val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + val (enableBatchFetch, forMergedMetas) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _, _) => (false, false) + case ShuffleMergedBlockId(_, _, _) => (false, true) + case _ => (doBatchFetch, false) + } + } + createFetchRequests( + curBlocks, + address, + isLast = true, + collectedRemoteRequests, + enableBatchFetch = enableBatchFetch, + forMergedMetas = forMergedMetas) + } + } + + private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = { + if (blockSize < 0) { + throw BlockException(blockId, "Negative block size " + size) + } else if (blockSize == 0) { + throw BlockException(blockId, "Zero-sized blocks should be excluded.") + } + } + + private def checkBlockSizes(blockInfos: collection.Seq[(BlockId, Long, Int)]): Unit = { + blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) } + } + + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchLocalBlocks(localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") + val iter = localBlocks.iterator + while (iter.hasNext) { + val (blockId, mapIndex) = iter.next() + try { + val buf = blockManager.getLocalBlockData(blockId) + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put( + SuccessFetchResult( + blockId, + mapIndex, + blockManager.blockManagerId, + buf.size(), + buf, + false)) + } catch { + // If we see an exception, stop immediately. + case e: Exception => + e match { + // ClosedByInterruptException is an excepted exception when kill task, + // don't log the exception stack trace to avoid confusing users. + // See: SPARK-28340 + case ce: ClosedByInterruptException => + logError("Error occurred while fetching local blocks, " + ce.getMessage) + case ex: Exception => logError("Error occurred while fetching local blocks", ex) + } + results.put(FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) + return + } + } + } + + private[this] def fetchHostLocalBlock( + blockId: BlockId, + mapIndex: Int, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val buf = blockManager.getHostLocalShuffleData(blockId, localDirs) + buf.retain() + results.put( + SuccessFetchResult( + blockId, + mapIndex, + blockManagerId, + buf.size(), + buf, + isNetworkReqDone = false)) + true + } catch { + case e: Exception => + // If we see an exception, stop immediately. + logError(s"Error occurred while fetching local blocks", e) + results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e)) + false + } + } + + /** + * Fetch the host-local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchHostLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]]): Unit = { + val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs + val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { + val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { + case (hostLocalBmId, _) => + cachedDirsByExec.contains(hostLocalBmId.executorId) + } + (hasCache.toMap, noCache.toMap) + } + + if (hostLocalBlocksWithMissingDirs.nonEmpty) { + logDebug( + s"Asynchronous fetching host-local blocks without cached executors' dir: " + + s"${hostLocalBlocksWithMissingDirs.mkString(", ")}") + + // If the external shuffle service is enabled, we'll fetch the local directories for + // multiple executors from the external shuffle service, which located at the same host + // with the executors, in once. Otherwise, we'll fetch the local directories from those + // executors directly one by one. The fetch requests won't be too much since one host is + // almost impossible to have many executors at the same time practically. + val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) { + val host = blockManager.blockManagerId.host + val port = blockManager.externalShuffleServicePort + Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray)) + } else { + hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, Array(bmId))).toSeq + } + + dirFetchRequests.foreach { + case (host, port, bmIds) => + hostLocalDirManager.getHostLocalDirs(host, port, bmIds.map(_.executorId)) { + case Success(dirsByExecId) => + fetchMultipleHostLocalBlocks( + hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains).toMap, + dirsByExecId, + cached = false) + + case Failure(throwable) => + logError("Error occurred while fetching host local blocks", throwable) + val bmId = bmIds.head + val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId) + val (blockId, _, mapIndex) = blockInfoSeq.head + results.put(FailureFetchResult(blockId, mapIndex, bmId, throwable)) + } + } + } + + if (hostLocalBlocksWithCachedDirs.nonEmpty) { + logDebug( + s"Synchronous fetching host-local blocks with cached executors' dir: " + + s"${hostLocalBlocksWithCachedDirs.mkString(", ")}") + fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, cachedDirsByExec, cached = true) + } + } + + private def fetchMultipleHostLocalBlocks( + bmIdToBlocks: Map[BlockManagerId, collection.Seq[(BlockId, Long, Int)]], + localDirsByExecId: Map[String, Array[String]], + cached: Boolean): Unit = { + // We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put + // a `FailureFetchResult` immediately to the `results`. So there's no reason to fetch the + // remaining blocks. + val allFetchSucceeded = bmIdToBlocks.forall { + case (bmId, blockInfos) => + blockInfos.forall { + case (blockId, _, mapIndex) => + fetchHostLocalBlock(blockId, mapIndex, localDirsByExecId(bmId.executorId), bmId) + } + } + if (allFetchSucceeded) { + logDebug( + s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", ")} " + + s"(${if (cached) "with" else "without"} cached executors' dir) " + + s"in ${Utils.getUsedTimeNs(startTimeNs)}") + } + } + + private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(onCompleteCallback) + // Local blocks to fetch, excluding zero-sized blocks. + val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val hostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() + val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + // Partition blocks by the different fetch modes: local, host-local, push-merged-local and + // remote blocks. + val remoteRequests = partitionBlocksByFetchMode( + blocksByAddress, + localBlocks, + hostLocalBlocksByExecutor, + pushMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + assert( + (0 == reqsInFlight) == (0 == bytesInFlight), + "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight + ) + + // Send out initial requests for blocks, up to our maxBytesInFlight + fetchUpToMaxBytes() + + val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum + val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest + logInfo( + s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" + + (if (numDeferredRequest > 0) s", deferred $numDeferredRequest requests" else "")) + + // Get Local Blocks + fetchLocalBlocks(localBlocks) + logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + // Get host local blocks if any + fetchAllHostLocalBlocks(hostLocalBlocksByExecutor) + pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) + } + + private def fetchAllHostLocalBlocks( + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]]): Unit = { + if (hostLocalBlocksByExecutor.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor)) + } + } + + private def shuffleMetricsUpdate(blockId: BlockId, buf: ManagedBuffer, local: Boolean): Unit = { + if (local) { + shuffleLocalMetricsUpdate(blockId, buf) + } else { + shuffleRemoteMetricsUpdate(blockId, buf) + } + } + + private def shuffleLocalMetricsUpdate(blockId: BlockId, buf: ManagedBuffer): Unit = { + blockId match { + case chunkId: ShuffleBlockChunkId => + val chunkCardinality = pushBasedFetchHelper.getShuffleChunkCardinality(chunkId) + shuffleMetrics.incLocalMergedChunksFetched(1) + shuffleMetrics.incLocalMergedBlocksFetched(chunkCardinality) + shuffleMetrics.incLocalMergedBytesRead(buf.size) + shuffleMetrics.incLocalBlocksFetched(chunkCardinality) + case _ => + shuffleMetrics.incLocalBlocksFetched(1) + } + shuffleMetrics.incLocalBytesRead(buf.size) + } + + private def shuffleRemoteMetricsUpdate(blockId: BlockId, buf: ManagedBuffer): Unit = { + blockId match { + case chunkId: ShuffleBlockChunkId => + val chunkCardinality = pushBasedFetchHelper.getShuffleChunkCardinality(chunkId) + shuffleMetrics.incRemoteMergedChunksFetched(1) + shuffleMetrics.incRemoteMergedBlocksFetched(chunkCardinality) + shuffleMetrics.incRemoteMergedBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(chunkCardinality) + case _ => + shuffleMetrics.incRemoteBlocksFetched(1) + } + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + /** + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers underlying each + * InputStream will be freed by the cleanup() method registered with the TaskCompletionListener. + * However, callers should close() these InputStreams as soon as they are no longer needed, in + * order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. + */ + override def next(): (BlockId, InputStream) = { + if (!hasNext) { + throw SparkCoreErrors.noSuchElementError() + } + + numBlocksProcessed += 1 + + var result: FetchResult = null + var input: InputStream = null + // This's only initialized when shuffle checksum is enabled. + var checkedIn: CheckedInputStream = null + var streamCompressedOrEncrypted: Boolean = false + // Take the next fetched result and try to decompress it to detect data corruption, + // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch + // is also corrupt, so the previous stage could be retried. + // For local shuffle block, throw FailureFetchResult for the first IOException. + while (result == null) { + val startFetchWait = System.nanoTime() + result = results.take() + val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait) + shuffleMetrics.incFetchWaitTime(fetchWaitTime) + + result match { + case SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => + if (address != blockManager.blockManagerId) { + if ( + hostLocalBlocks.contains(blockId -> mapIndex) || + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) + ) { + // It is a host local block or a local shuffle chunk + shuffleMetricsUpdate(blockId, buf, local = true) + } else { + numBlocksInFlightPerAddress(address) -= 1 + shuffleMetricsUpdate(blockId, buf, local = false) + bytesInFlight -= size + } + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem) + logDebug("Number of requests in flight " + reqsInFlight) + } + + val in = if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + if (blockId.isShuffleChunk) { + // Zero-size block may come from nodes with hardware failures, For shuffle chunks, + // the original shuffle blocks that belong to that zero-size shuffle chunk is + // available and we can opt to fallback immediately. + logWarning(msg) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + shuffleMetrics.incCorruptMergedBlockChunks(1) + // Set result to null to trigger another iteration of the while loop to get either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) + } + } else { + try { + val bufIn = buf.createInputStream() + if (checksumEnabled) { + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + checkedIn = new CheckedInputStream(bufIn, checksum) + checkedIn + } else { + bufIn + } + } catch { + // The exception could only be throwed by local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + e match { + case ce: ClosedByInterruptException => + logError( + "Failed to create input stream from local block, " + + ce.getMessage) + case e: IOException => + logError("Failed to create input stream from local block", e) + } + buf.release() + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get + // either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } + } + + if (in != null) { + try { + input = streamWrapper(blockId, in) + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { + // TODO: manage the memory used here, and spill it into disk in case of OOM. + input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + } + } catch { + case e: IOException => + // When shuffle checksum is enabled, for a block that is corrupted twice, + // we'd calculate the checksum of the block by consuming the remaining data + // in the buf. So, we should release the buf later. + if (!(checksumEnabled && corruptedBlocks.contains(blockId))) { + buf.release() + } + + if (blockId.isShuffleChunk) { + shuffleMetrics.incCorruptMergedBlockChunks(1) + // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle + // Retrying a corrupt block may result again in a corrupt block. For shuffle + // chunks, we opt to fallback on the original shuffle blocks that belong to that + // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt + // chunk. This also makes the code simpler because the chunkMeta corresponding to + // a shuffle chunk is always removed from chunksMetaMap whenever a shuffle chunk + // gets processed. If we try to re-fetch a corrupt shuffle chunk, then it has to + // be added back to the chunksMetaMap. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop. + result = null + } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else if (corruptedBlocks.contains(blockId)) { + // It's the second time this block is detected corrupted + if (checksumEnabled) { + // Diagnose the cause of data corruption if shuffle checksum is enabled + val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId) + buf.release() + logError(diagnosisResponse) + throwFetchFailedException( + blockId, + mapIndex, + address, + e, + Some(diagnosisResponse)) + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } else { + // It's the first time this block is detected corrupted + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, + Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null + } + } finally { + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId]) + } + // TODO: release the buf here to free memory earlier + if (input == null) { + // Close the underlying stream if there was an issue in wrapping the stream using + // streamWrapper + in.close() + } + } + } + + case FailureFetchResult(blockId, mapIndex, address, e) => + var errorMsg: String = null + if (e.isInstanceOf[OutOfDirectMemoryError]) { + errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " + + s"retries due to Netty OOM" + logError(errorMsg) + } + throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) + + case DeferFetchRequestResult(request) => + val address = request.address + numBlocksInFlightPerAddress(address) -= request.blocks.size + bytesInFlight -= request.size + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + val defReqQueue = + deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + result = null + + case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) => + // We get this result in 3 cases: + // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the + // blockId is a ShuffleBlockChunkId. + // 2. Failure to read the push-merged-local meta. In this case, the blockId is + // ShuffleBlockId. + // 3. Failure to get the push-merged-local directories from the external shuffle service. + // In this case, the blockId is ShuffleBlockId. + if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) { + numBlocksInFlightPerAddress(address) -= 1 + bytesInFlight -= size + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + } + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get either + // a SuccessFetchResult or a FailureFetchResult. + result = null + + case PushMergedLocalMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + bitmaps, + localDirs) => + // Fetch push-merged-local shuffle block data as multiple shuffle chunks + val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId) + try { + val bufs: Seq[ManagedBuffer] = + blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs) + // Since the request for local block meta completed successfully, numBlocksToFetch + // is decremented. + numBlocksToFetch -= 1 + // Update total number of blocks to fetch, reflecting the multiple local shuffle + // chunks. + numBlocksToFetch += bufs.size + bufs.zipWithIndex.foreach { + case (buf, chunkId) => + buf.retain() + val shuffleChunkId = + ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, chunkId) + pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) + results.put( + SuccessFetchResult( + shuffleChunkId, + SHUFFLE_PUSH_MAP_ID, + pushBasedFetchHelper.localShuffleMergerBlockMgrId, + buf.size(), + buf, + isNetworkReqDone = false)) + } + } catch { + case e: Exception => + // If we see an exception with reading push-merged-local index file, we fallback + // to fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while reading push-merged-local index, " + + s"prepare to fetch the original blocks", + e) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + shuffleBlockId, + pushBasedFetchHelper.localShuffleMergerBlockMgrId) + } + result = null + + case PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps, + address) => + // The original meta request is processed so we decrease numBlocksToFetch and + // numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the + // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. + numBlocksInFlightPerAddress(address) -= 1 + numBlocksToFetch -= 1 + val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps) + val additionalRemoteReqs = new ArrayBuffer[FetchRequest] + collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs) + fetchRequests ++= additionalRemoteReqs + // Set result to null to force another iteration. + result = null + + case PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address) => + // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1. + numBlocksInFlightPerAddress(address) -= 1 + // If we fail to fetch the meta of a push-merged block, we fall back to fetching the + // original blocks. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), + address) + // Set result to null to force another iteration. + result = null + } + + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() + } + + val successResult = result.asInstanceOf[SuccessFetchResult] + val threadId = Thread.currentThread().getId + currentResults.put(threadId, successResult) + ( + successResult.blockId, + new GlutenBufferReleasingInputStream( + input, + this, + successResult.blockId, + successResult.mapIndex, + successResult.address, + detectCorrupt && streamCompressedOrEncrypted, + successResult.isNetworkReqDone, + Option(checkedIn) + )) + } + + /** + * Get the suspect corruption cause for the corrupted block. It should be only invoked when + * checksum is enabled and corruption was detected at least once. + * + * This will firstly consume the rest of stream of the corrupted block to calculate the checksum + * of the block. Then, it will raise a synchronized RPC call along with the checksum to ask the + * server(where the corrupted block is fetched from) to diagnose the cause of corruption and + * return it. + * + * Any exception raised during the process will result in the [[Cause.UNKNOWN_ISSUE]] of the + * corruption cause since corruption diagnosis is only a best effort. + * + * @param checkedIn + * the [[CheckedInputStream]] which is used to calculate the checksum. + * @param address + * the address where the corrupted block is fetched from. + * @param blockId + * the blockId of the corrupted block. + * @return + * The corruption diagnosis response for different causes. + */ + private[storage] def diagnoseCorruption( + checkedIn: CheckedInputStream, + address: BlockManagerId, + blockId: BlockId): String = { + logInfo("Start corruption diagnosis.") + blockId match { + case shuffleBlock: ShuffleBlockId => + val startTimeNs = System.nanoTime() + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + // consume the remaining data to calculate the checksum + var cause: Cause = null + try { + while (checkedIn.read(buffer) != -1) {} + val checksum = checkedIn.getChecksum.getValue + cause = shuffleClient.diagnoseCorruption( + address.host, + address.port, + address.executorId, + shuffleBlock.shuffleId, + shuffleBlock.mapId, + shuffleBlock.reduceId, + checksum, + checksumAlgorithm) + } catch { + case e: Exception => + logWarning("Unable to diagnose the corruption cause of the corrupted block", e) + cause = Cause.UNKNOWN_ISSUE + } + val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + val diagnosisResponse = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: $checksumAlgorithm" + + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" + + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" + + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } + logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + diagnosisResponse + case shuffleBlockChunk: ShuffleBlockChunkId => + // TODO SPARK-36284 Add shuffle checksum support for push-based shuffle + val diagnosisResponse = s"BlockChunk $shuffleBlockChunk is corrupted but corruption " + + s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle." + logWarning(diagnosisResponse) + diagnosisResponse + case shuffleBlockBatch: ShuffleBlockBatchId => + val diagnosisResponse = s"BlockBatch $shuffleBlockBatch is corrupted " + + s"but corruption diagnosis is skipped due to lack of shuffle checksum support for " + + s"ShuffleBlockBatchId" + logWarning(diagnosisResponse) + diagnosisResponse + case unexpected: BlockId => + throw SparkException.internalError( + s"Unexpected type of BlockId, $unexpected", + category = "STORAGE") + } + } + + def toCompletionIterator: Iterator[(BlockId, InputStream)] = { + CompletionIterator[(BlockId, InputStream), this.type]( + this, + onCompleteCallback.onComplete(context)) + } + + private def fetchUpToMaxBytes(): Unit = { + if (isNettyOOMOnShuffle.get()) { + if (reqsInFlight > 0) { + // Return immediately if Netty is still OOMed and there're ongoing fetch requests + return + } else { + resetNettyOOMFlagIfPossible(0) + } + } + + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while ( + isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front) + ) { + val request = defReqQueue.dequeue() + logDebug( + s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. + while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + if (request.forMergedMetas) { + pushBasedFetchHelper.sendFetchMergedStatusRequest(request) + } else { + sendRequest(request) + } + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. + def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress + } + } + + private[storage] def throwFetchFailedException( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable, + message: Option[String] = None) = { + val msg = message.getOrElse(e.getMessage) + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw SparkCoreErrors.fetchFailedError(address, shufId, mapId, mapIndex, reduceId, msg, e) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + mapId, + mapIndex, + startReduceId, + msg, + e) + case ShuffleBlockChunkId(shuffleId, _, reduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + SHUFFLE_PUSH_MAP_ID.toLong, + SHUFFLE_PUSH_MAP_ID, + reduceId, + msg, + e) + case _ => throw SparkCoreErrors.failToGetNonShuffleBlockError(blockId, e) + } + } + + /** All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator */ + private[storage] def addToResultsQueue(result: FetchResult): Unit = { + results.put(result) + } + + private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = { + numBlocksToFetch -= blocksFetched + } + + /** + * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch + * failure related to a push-merged block or shuffle chunk. This is executed by the task thread + * when the `iterator.next()` is invoked and if that initiates fallback. + */ + private[storage] def fallbackFetch( + originalBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]) + : Unit = { + val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val originalHostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() + val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + val originalRemoteReqs = partitionBlocksByFetchMode( + originalBlocksByAddr, + originalLocalBlocks, + originalHostLocalBlocksByExecutor, + originalMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(originalRemoteReqs) + logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged") + // fetch all the fallback blocks that are local. + fetchLocalBlocks(originalLocalBlocks) + // Merged local blocks should be empty during fallback + assert( + originalMergedLocalBlocks.isEmpty, + "There should be zero push-merged blocks during fallback") + // Some of the fallback local blocks could be host local blocks + fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor) + } + + /** + * Removes all the pending shuffle chunks that are on the same host and have the same reduceId as + * the current chunk that had a fetch failure. This is executed by the task thread when the + * `iterator.next()` is invoked and if that initiates fallback. + * + * @return + * set of all the removed shuffle chunk Ids. + */ + private[storage] def removePendingChunks( + failedBlockId: ShuffleBlockChunkId, + address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = { + val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]() + + def sameShuffleReducePartition(block: BlockId): Boolean = { + val chunkId = block.asInstanceOf[ShuffleBlockChunkId] + chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId + } + + def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = { + val fetchRequestsToRemove = new mutable.Queue[FetchRequest]() + fetchRequestsToRemove ++= queue.dequeueAll { + req => + val firstBlock = req.blocks.head + firstBlock.blockId.isShuffleChunk && req.address.equals(address) && + sameShuffleReducePartition(firstBlock.blockId) + } + fetchRequestsToRemove.foreach { + _ => + removedChunkIds ++= + fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])) + } + } + + filterRequests(fetchRequests) + deferredFetchRequests.get(address).foreach { + defRequests => + filterRequests(defRequests) + if (defRequests.isEmpty) deferredFetchRequests.remove(address) + } + removedChunkIds + } +} + +/** + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and also detects + * stream corruption if streamCompressedOrEncrypted is true + */ +private class GlutenBufferReleasingInputStream( + // This is visible for testing + private[storage] val delegate: InputStream, + private val iterator: GlutenShuffleBlockFetcherIterator, + private val blockId: BlockId, + private val mapIndex: Int, + private val address: BlockManagerId, + private val detectCorruption: Boolean, + private val isNetworkReqDone: Boolean, + private val checkedInOpt: Option[CheckedInputStream]) + extends InputStream { + private[this] var closed = false + + override def read(): Int = + tryOrFetchFailedException(delegate.read()) + + override def close(): Unit = { + if (!closed) { + try { + delegate.close() + iterator.releaseCurrentResultBuffer() + } finally { + // Unset the flag when a remote request finished and free memory is fairly enough. + if (isNetworkReqDone) { + GlutenShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( + iterator.maxReqSizeShuffleToMem) + } + closed = true + } + } + } + + override def available(): Int = + tryOrFetchFailedException(delegate.available()) + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = + tryOrFetchFailedException(delegate.skip(n)) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = + tryOrFetchFailedException(delegate.read(b)) + + override def read(b: Array[Byte], off: Int, len: Int): Int = + tryOrFetchFailedException(delegate.read(b, off, len)) + + override def reset(): Unit = tryOrFetchFailedException(delegate.reset()) + + /** + * Execute a block of code that returns a value, close this stream quietly and re-throwing + * IOException as FetchFailedException when detectCorruption is true. This method is only used by + * the `available`, `read` and `skip` methods inside `BufferReleasingInputStream` currently. + */ + private def tryOrFetchFailedException[T](block: => T): T = { + try { + block + } catch { + case e: IOException if detectCorruption => + val diagnosisResponse = checkedInOpt.map { + checkedIn => iterator.diagnoseCorruption(checkedIn, address, blockId) + } + IOUtils.closeQuietly(this) + // We'd never retry the block whatever the cause is since the block has been + // partially consumed by downstream RDDs. + iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse) + } + } +} + +/** + * A listener to be called at the completion of the ShuffleBlockFetcherIterator + * @param data + * the ShuffleBlockFetcherIterator to process + */ +private class GlutenShuffleFetchCompletionListener(var data: GlutenShuffleBlockFetcherIterator) + extends TaskCompletionListener { + + override def onTaskCompletion(context: TaskContext): Unit = { + if (data != null) { + data.cleanup() + // Null out the referent here to make sure we don't keep a reference to this + // ShuffleBlockFetcherIterator, after we're done reading from it, to let it be + // collected during GC. Otherwise we can hold metadata on block locations(blocksByAddress) + data = null + } + } + + // Just an alias for onTaskCompletion to avoid confusing + def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) +} + +private[storage] object GlutenShuffleBlockFetcherIterator { + + /** + * A flag which indicates whether the Netty OOM error has raised during shuffle. If true, unless + * there's no in-flight fetch requests, all the pending shuffle fetch requests will be deferred + * until the flag is unset (whenever there's a complete fetch request). + */ + val isNettyOOMOnShuffle = new AtomicBoolean(false) + + def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { + if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { + isNettyOOMOnShuffle.compareAndSet(true, false) + } + } + + /** + * This function is used to merged blocks when doBatchFetch is true. Blocks which have the same + * `mapId` can be merged into one block batch. The block batch is specified by a range of + * reduceId, which implies the continuous shuffle blocks that we can fetch in a batch. For + * example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be merged into + * (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2, shuffle_0_0_2, + * shuffle_0_0_3) can be merged into (shuffle_0_0_0_4). + * + * @param blocks + * blocks to be merged if possible. May contains already merged blocks. + * @param doBatchFetch + * whether to merge blocks. + * @return + * the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true. + */ + def mergeContinuousShuffleBlockIdsIfNeeded( + blocks: collection.Seq[FetchBlockInfo], + doBatchFetch: Boolean): collection.Seq[FetchBlockInfo] = { + val result = if (doBatchFetch) { + val curBlocks = new ArrayBuffer[FetchBlockInfo] + val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] + + def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { + val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] + + // The last merged block may comes from the input, and we can merge more blocks + // into it, if the map id is the same. + def shouldMergeIntoPreviousBatchBlockId = + mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId + + val (startReduceId, size) = + if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) { + // Remove the previous batch block id as we will add a new one to replace it. + val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1) + ( + removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId, + removed.size + toBeMerged.map(_.size).sum) + } else { + (startBlockId.reduceId, toBeMerged.map(_.size).sum) + } + + FetchBlockInfo( + ShuffleBlockBatchId( + startBlockId.shuffleId, + startBlockId.mapId, + startReduceId, + toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), + size, + toBeMerged.head.mapIndex + ) + } + + val iter = blocks.iterator + while (iter.hasNext) { + val info = iter.next() + // It's possible that the input block id is already a batch ID. For example, we merge some + // blocks, and then make fetch requests with the merged blocks according to "max blocks per + // request". The last fetch request may be too small, and we give up and put the remaining + // merged blocks back to the input list. + if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) { + mergedBlockInfo += info + } else { + if (curBlocks.isEmpty) { + curBlocks += info + } else { + val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] + val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId + if (curBlockId.mapId != currentMapId) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + curBlocks.clear() + } + curBlocks += info + } + } + } + if (curBlocks.nonEmpty) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + } + mergedBlockInfo + } else { + blocks + } + result + } + + /** + * The block information to fetch used in FetchRequest. + * @param blockId + * block id + * @param size + * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is + * used to calculate bytesInFlight. + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage. + */ + private[storage] case class FetchBlockInfo(blockId: BlockId, size: Long, mapIndex: Int) + + /** + * A request to fetch blocks from a remote BlockManager. + * @param address + * remote BlockManager to fetch from. + * @param blocks + * Sequence of the information for blocks to fetch from the same address. + * @param forMergedMetas + * true if this request is for requesting push-merged meta information; false if it is for + * regular or shuffle chunks. + */ + case class FetchRequest( + address: BlockManagerId, + blocks: collection.Seq[FetchBlockInfo], + forMergedMetas: Boolean = false) { + val size = blocks.map(_.size).sum + } + + /** Result of a fetch from a remote block. */ + sealed private[storage] trait FetchResult + + /** + * Result of a fetch from a remote block successfully. + * @param blockId + * block id + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage. + * @param address + * BlockManager that the block was fetched from. + * @param size + * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is + * used to calculate bytesInFlight. + * @param buf + * `ManagedBuffer` for the content. + * @param isNetworkReqDone + * Is this the last network request for this host in this fetch request. + */ + private[storage] case class SuccessFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer, + isNetworkReqDone: Boolean) + extends FetchResult { + require(buf != null) + require(size >= 0) + } + + /** + * Result of a fetch from a remote block unsuccessfully. + * @param blockId + * block id + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage + * @param address + * BlockManager that the block was attempted to be fetched from + * @param e + * the failure exception + */ + private[storage] case class FailureFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable) + extends FetchResult + + /** Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM */ + private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) + extends FetchResult + + /** + * Result of an un-successful fetch of either of these: 1) Remote shuffle chunk. 2) Local + * push-merged block. + * + * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks. + * + * @param blockId + * block id + * @param address + * BlockManager that the push-merged block was attempted to be fetched from + * @param size + * size of the block, used to update bytesInFlight. + * @param isNetworkReqDone + * Is this the last network request for this host in this fetch request. Used to update + * reqsInFlight. + */ + private[storage] case class FallbackOnPushMergedFailureResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + isNetworkReqDone: Boolean) + extends FetchResult + + /** + * Result of a successful fetch of meta information for a remote push-merged block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param blockSize + * size of each push-merged block. + * @param bitmaps + * bitmaps for every chunk. + * @param address + * BlockManager that the meta was fetched from. + */ + private[storage] case class PushMergedRemoteMetaFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap], + address: BlockManagerId) + extends FetchResult + + /** + * Result of a failure while fetching the meta information for a remote push-merged block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param address + * BlockManager that the meta was fetched from. + */ + private[storage] case class PushMergedRemoteMetaFailedFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + address: BlockManagerId) + extends FetchResult + + /** + * Result of a successful fetch of meta information for a push-merged-local block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param bitmaps + * bitmaps for every chunk. + * @param localDirs + * local directories where the push-merged shuffle files are storedl + */ + private[storage] case class PushMergedLocalMetaFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + bitmaps: Array[RoaringBitmap], + localDirs: Array[String]) + extends FetchResult +} diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc index 742d0cba7075..62fdfb1c7275 100644 --- a/cpp/velox/compute/VeloxBackend.cc +++ b/cpp/velox/compute/VeloxBackend.cc @@ -39,6 +39,7 @@ #include "jni/JniFileSystem.h" #include "memory/GlutenBufferedInputBuilder.h" #include "operators/functions/SparkExprToSubfieldFilterParser.h" +#include "operators/plannodes/RowVectorStream.h" #include "shuffle/ArrowShuffleDictionaryWriter.h" #include "udf/UdfLoader.h" #include "utils/Exception.h" @@ -47,7 +48,6 @@ #include "velox/connectors/hive/BufferedInputBuilder.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveDataSource.h" -#include "operators/plannodes/RowVectorStream.h" #include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" // @manual #include "velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h" // @manual #include "velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h" @@ -177,7 +177,8 @@ void VeloxBackend::init( {velox::cudf_velox::CudfConfig::kCudfMemoryResource, backendConf_->get(kCudfMemoryResource, kCudfMemoryResourceDefault)}, {velox::cudf_velox::CudfConfig::kCudfMemoryPercent, - backendConf_->get(kCudfMemoryPercent, kCudfMemoryPercentDefault)}}; + backendConf_->get(kCudfMemoryPercent, kCudfMemoryPercentDefault)}, + {velox::cudf_velox::CudfConfig::kCudfFunctionEngine, "spark"}}; auto& cudfConfig = velox::cudf_velox::CudfConfig::getInstance(); cudfConfig.initialize(std::move(options)); velox::cudf_velox::registerCudf(); @@ -317,13 +318,13 @@ void VeloxBackend::initConnector(const std::shared_ptr(kHiveConnectorId, hiveConf, ioExecutor_.get())); - + // Register value-stream connector for runtime iterator-based inputs auto valueStreamDynamicFilterEnabled = backendConf_->get(kValueStreamDynamicFilterEnabled, kValueStreamDynamicFilterEnabledDefault); velox::connector::registerConnector( std::make_shared(kIteratorConnectorId, hiveConf, valueStreamDynamicFilterEnabled)); - + #ifdef GLUTEN_ENABLE_GPU if (backendConf_->get(kCudfEnableTableScan, kCudfEnableTableScanDefault) && backendConf_->get(kCudfEnabled, kCudfEnabledDefault)) { diff --git a/cpp/velox/shuffle/VeloxGpuShuffleReader.cc b/cpp/velox/shuffle/VeloxGpuShuffleReader.cc index 441cb0d06679..fe2c682036a8 100644 --- a/cpp/velox/shuffle/VeloxGpuShuffleReader.cc +++ b/cpp/velox/shuffle/VeloxGpuShuffleReader.cc @@ -20,19 +20,21 @@ #include #include +#include "memory/GpuBufferColumnarBatch.h" #include "memory/VeloxColumnarBatch.h" #include "shuffle/Payload.h" #include "shuffle/Utils.h" +#include "utils/CachedBatchQueue.h" #include "utils/Common.h" #include "utils/Macros.h" #include "utils/Timer.h" -#include "memory/GpuBufferColumnarBatch.h" #include using namespace facebook::velox; namespace gluten { + namespace { arrow::Result readBlockType(arrow::io::InputStream* inputStream) { @@ -63,62 +65,99 @@ VeloxGpuHashShuffleReaderDeserializer::VeloxGpuHashShuffleReaderDeserializer( readerBufferSize_(readerBufferSize), memoryManager_(memoryManager), deserializeTime_(deserializeTime), - decompressTime_(decompressTime) {} - -bool VeloxGpuHashShuffleReaderDeserializer::resolveNextBlockType() { - GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(in_.get())); - switch (blockType) { - case BlockType::kEndOfStream: - return false; - case BlockType::kPlainPayload: - return true; - default: - throw GlutenException(fmt::format("Unsupported block type: {}", static_cast(blockType))); - } - return true; -} + decompressTime_(decompressTime) { + batchQueue_ = std::make_unique>(1L << 30); -void VeloxGpuHashShuffleReaderDeserializer::loadNextStream() { - if (reachedEos_) { - return; - } + const size_t numThreads = std::max(1u, std::thread::hardware_concurrency()); + activeReaders_.store(numThreads); + LOG(WARNING) << "Using " << numThreads << " threads for deserialization"; - auto in = streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool()); - if (in == nullptr) { - reachedEos_ = true; - return; + // Create multiple reader threads + readerThreads_.reserve(numThreads); + for (size_t i = 0; i < numThreads; ++i) { + readerThreads_.emplace_back([this]() { read(); }); } - - GLUTEN_ASSIGN_OR_THROW( - in_, - arrow::io::BufferedInputStream::Create( - readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), std::move(in))); } -std::shared_ptr VeloxGpuHashShuffleReaderDeserializer::next() { - if (in_ == nullptr) { - loadNextStream(); +VeloxGpuHashShuffleReaderDeserializer::~VeloxGpuHashShuffleReaderDeserializer() { + decompressTime_ += decompressTimeCounter_.load(std::memory_order_relaxed); + deserializeTime_ += deserializeTimeCounter_.load(std::memory_order_relaxed); + stopReaders_.store(true, std::memory_order_release); - if (reachedEos_) { - return nullptr; + for (auto& thread : readerThreads_) { + if (thread.joinable()) { + thread.join(); } } +} - while (!resolveNextBlockType()) { - loadNextStream(); +void VeloxGpuHashShuffleReaderDeserializer::read() { + std::shared_ptr inputStream = nullptr; + + while (!stopReaders_.load(std::memory_order_acquire)) { + if (inputStream == nullptr) { + std::lock_guard lockGuard(mtx_); + auto rawStream = streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool()); + if (rawStream == nullptr) { + // No more streams available. + break; + } + + GLUTEN_ASSIGN_OR_THROW( + inputStream, + arrow::io::BufferedInputStream::Create( + readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), std::move(rawStream))); + } + + GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(inputStream.get())); - if (reachedEos_) { - return nullptr; + if (blockType == BlockType::kEndOfStream) { + GLUTEN_THROW_NOT_OK(inputStream->Close()); + inputStream = nullptr; + continue; } + + if (blockType != BlockType::kPlainPayload) { + throw GlutenException(fmt::format("Unsupported block type: {}", static_cast(blockType))); + } + + uint32_t numRows = 0; + int64_t localDeserializeTime = 0; + int64_t localDecompressTime = 0; + + GLUTEN_ASSIGN_OR_THROW( + auto arrowBuffers, + BlockPayload::deserialize( + inputStream.get(), + codec_, + memoryManager_->defaultArrowMemoryPool(), + numRows, + localDeserializeTime, + localDecompressTime)); + + deserializeTimeCounter_.fetch_add(localDeserializeTime, std::memory_order_relaxed); + decompressTimeCounter_.fetch_add(localDecompressTime, std::memory_order_relaxed); + + auto batch = + std::make_shared(rowType_, std::move(arrowBuffers), static_cast(numRows)); + + // Put batch into queue. + batchQueue_->put(batch); + } + + // Close input stream if it's still open. + if (inputStream != nullptr) { + GLUTEN_THROW_NOT_OK(inputStream->Close()); } - uint32_t numRows = 0; - GLUTEN_ASSIGN_OR_THROW( - auto arrowBuffers, - BlockPayload::deserialize( - in_.get(), codec_, memoryManager_->defaultArrowMemoryPool(), numRows, deserializeTime_, decompressTime_)); + // Decrement active reader count. + if (activeReaders_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + batchQueue_->noMoreBatches(); + } +} - return std::make_shared(rowType_, std::move(arrowBuffers), static_cast(numRows)); +std::shared_ptr VeloxGpuHashShuffleReaderDeserializer::next() { + return batchQueue_->get(); } } // namespace gluten diff --git a/cpp/velox/shuffle/VeloxGpuShuffleReader.h b/cpp/velox/shuffle/VeloxGpuShuffleReader.h index 498b28b62259..07dacc10ca53 100644 --- a/cpp/velox/shuffle/VeloxGpuShuffleReader.h +++ b/cpp/velox/shuffle/VeloxGpuShuffleReader.h @@ -17,19 +17,24 @@ #pragma once +#include "memory/GpuBufferColumnarBatch.h" #include "memory/VeloxMemoryManager.h" #include "shuffle/Payload.h" #include "shuffle/ShuffleReader.h" +#include "utils/CachedBatchQueue.h" -#include "velox/serializers/PrestoSerializer.h" #include "velox/type/Type.h" #include "velox/vector/ComplexVector.h" +#include +#include +#include +#include + namespace gluten { /// Convert the buffers to cudf table. -/// Add a lock after reader produces the Vector, relase the lock after the thread processes all the batches. -/// After move the shuffle read operation to gpu, move the lock to start read. +/// Multi-threaded deserializer that uses producer threads to pre-fetch and deserialize batches. class VeloxGpuHashShuffleReaderDeserializer final : public ColumnarBatchIterator { public: VeloxGpuHashShuffleReaderDeserializer( @@ -42,12 +47,13 @@ class VeloxGpuHashShuffleReaderDeserializer final : public ColumnarBatchIterator int64_t& deserializeTime, int64_t& decompressTime); + ~VeloxGpuHashShuffleReaderDeserializer() override; + std::shared_ptr next() override; private: - bool resolveNextBlockType(); - - void loadNextStream(); + // Reader thread function that deserializes batches. + void read(); std::shared_ptr streamReader_; std::shared_ptr schema_; @@ -59,9 +65,14 @@ class VeloxGpuHashShuffleReaderDeserializer final : public ColumnarBatchIterator int64_t& deserializeTime_; int64_t& decompressTime_; - std::shared_ptr in_{nullptr}; + std::atomic deserializeTimeCounter_{0}; + std::atomic decompressTimeCounter_{0}; + + std::vector readerThreads_; + std::unique_ptr> batchQueue_; + std::atomic stopReaders_{false}; + std::atomic activeReaders_{0}; - bool reachedEos_{false}; - bool blockTypeResolved_{false}; + std::mutex mtx_; }; } // namespace gluten diff --git a/cpp/velox/shuffle/VeloxShuffleReader.cc b/cpp/velox/shuffle/VeloxShuffleReader.cc index e5435aee52dd..686c172d9171 100644 --- a/cpp/velox/shuffle/VeloxShuffleReader.cc +++ b/cpp/velox/shuffle/VeloxShuffleReader.cc @@ -21,10 +21,12 @@ #include #include +#include "memory/ColumnarBatch.h" #include "memory/VeloxColumnarBatch.h" #include "shuffle/GlutenByteStream.h" #include "shuffle/Payload.h" #include "shuffle/Utils.h" +#include "utils/CachedBatchQueue.h" #include "utils/Common.h" #include "utils/Macros.h" #include "utils/Timer.h" @@ -44,6 +46,7 @@ using namespace facebook::velox; namespace gluten { + namespace { arrow::Result readBlockType(arrow::io::InputStream* inputStream) { @@ -455,91 +458,134 @@ VeloxHashShuffleReaderDeserializer::VeloxHashShuffleReaderDeserializer( readerBufferSize_(readerBufferSize), memoryManager_(memoryManager), deserializeTime_(deserializeTime), - decompressTime_(decompressTime) {} - -bool VeloxHashShuffleReaderDeserializer::resolveNextBlockType() { - GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(in_.get())); - switch (blockType) { - case BlockType::kEndOfStream: - return false; - case BlockType::kDictionary: { + decompressTime_(decompressTime) { + batchQueue_ = std::make_unique>(1L << 30); + + const size_t numThreads = std::max(1u, std::thread::hardware_concurrency()); + activeReaders_.store(numThreads); + LOG(WARNING) << "Using " << numThreads << " threads for deserialization"; + + // Create multiple producer threads + readerThreads_.reserve(numThreads); + for (size_t i = 0; i < numThreads; ++i) { + readerThreads_.emplace_back([this]() { read(); }); + } +} + +VeloxHashShuffleReaderDeserializer::~VeloxHashShuffleReaderDeserializer() { + decompressTime_ += decompressTimeCounter_.load(std::memory_order_relaxed); + deserializeTime_ += deserializeTimeCounter_.load(std::memory_order_relaxed); + stopReaders_.store(true, std::memory_order_release); + + for (auto& thread : readerThreads_) { + if (thread.joinable()) { + thread.join(); + } + } +} + +void VeloxHashShuffleReaderDeserializer::read() { + // Each thread has its own stream and dictionary state + std::shared_ptr inputStream = nullptr; + std::vector dictionaryFields; + std::vector dictionaries; + + while (!stopReaders_.load(std::memory_order_acquire)) { + // Load next stream for this thread + if (inputStream == nullptr) { + std::lock_guard lockGuard(mtx_); + auto rawStream = streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool()); + if (rawStream == nullptr) { + // No more streams available + break; + } + + if (readerBufferSize_ > 0) { + GLUTEN_ASSIGN_OR_THROW( + inputStream, + arrow::io::BufferedInputStream::Create( + readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), std::move(rawStream))); + } else { + inputStream = std::move(rawStream); + } + } + + GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(inputStream.get())); + + if (blockType == BlockType::kEndOfStream) { + GLUTEN_THROW_NOT_OK(inputStream->Close()); + inputStream = nullptr; + continue; + } + + // Handle dictionary blocks. + if (blockType == BlockType::kDictionary) { VeloxDictionaryReader reader(rowType_, memoryManager_->getLeafMemoryPool().get(), codec_.get()); - GLUTEN_ASSIGN_OR_THROW(dictionaryFields_, reader.readFields(in_.get())); - GLUTEN_ASSIGN_OR_THROW(dictionaries_, reader.readDictionaries(in_.get(), dictionaryFields_)); + GLUTEN_ASSIGN_OR_THROW(dictionaryFields, reader.readFields(inputStream.get())); + GLUTEN_ASSIGN_OR_THROW(dictionaries, reader.readDictionaries(inputStream.get(), dictionaryFields)); - GLUTEN_ASSIGN_OR_THROW(blockType, readBlockType(in_.get())); + GLUTEN_ASSIGN_OR_THROW(blockType, readBlockType(inputStream.get())); GLUTEN_CHECK(blockType == BlockType::kDictionaryPayload, "Invalid block type for dictionary payload"); - } break; - case BlockType::kDictionaryPayload: { + } else if (blockType == BlockType::kDictionaryPayload) { GLUTEN_CHECK( - !dictionaryFields_.empty() && !dictionaries_.empty(), + !dictionaryFields.empty() && !dictionaries.empty(), "Dictionaries cannot be empty when reading dictionary payload"); - } break; - case BlockType::kPlainPayload: { - if (!dictionaryFields_.empty()) { + } else if (blockType == BlockType::kPlainPayload) { + if (!dictionaryFields.empty()) { // Clear previous dictionaries if the next block is a plain payload. - dictionaryFields_.clear(); - dictionaries_.clear(); + dictionaryFields.clear(); + dictionaries.clear(); } - } break; - default: + } else { throw GlutenException(fmt::format("Unsupported block type: {}", static_cast(blockType))); - } - return true; -} + } -void VeloxHashShuffleReaderDeserializer::loadNextStream() { - if (reachedEos_) { - return; - } + // Deserialize batch. + uint32_t numRows = 0; + int64_t localDeserializeTime = 0; + int64_t localDecompressTime = 0; - auto in = streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool()); - if (in == nullptr) { - reachedEos_ = true; - return; - } - - if (readerBufferSize_ > 0) { GLUTEN_ASSIGN_OR_THROW( - in_, - arrow::io::BufferedInputStream::Create( - readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), std::move(in))); - } else { - in_ = std::move(in); - } -} + auto arrowBuffers, + BlockPayload::deserialize( + inputStream.get(), + codec_, + memoryManager_->defaultArrowMemoryPool(), + numRows, + localDeserializeTime, + localDecompressTime)); -std::shared_ptr VeloxHashShuffleReaderDeserializer::next() { - if (in_ == nullptr) { - loadNextStream(); + // Update atomic timing statistics. + deserializeTimeCounter_.fetch_add(localDeserializeTime, std::memory_order_relaxed); + decompressTimeCounter_.fetch_add(localDecompressTime, std::memory_order_relaxed); - if (reachedEos_) { - return nullptr; - } - } + auto batch = makeColumnarBatch( + rowType_, + numRows, + std::move(arrowBuffers), + dictionaryFields, + dictionaries, + memoryManager_->getLeafMemoryPool().get(), + localDeserializeTime); - while (!resolveNextBlockType()) { - loadNextStream(); + // Put batch into queue. + batchQueue_->put(batch); + } - if (reachedEos_) { - return nullptr; - } + // Close input stream if it's still open. + if (inputStream != nullptr) { + GLUTEN_THROW_NOT_OK(inputStream->Close()); } - uint32_t numRows = 0; - GLUTEN_ASSIGN_OR_THROW( - auto arrowBuffers, - BlockPayload::deserialize( - in_.get(), codec_, memoryManager_->defaultArrowMemoryPool(), numRows, deserializeTime_, decompressTime_)); + // Decrement active producers count. + if (activeReaders_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + // Last producer to finish. + batchQueue_->noMoreBatches(); + } +} - return makeColumnarBatch( - rowType_, - numRows, - std::move(arrowBuffers), - dictionaryFields_, - dictionaries_, - memoryManager_->getLeafMemoryPool().get(), - deserializeTime_); +std::shared_ptr VeloxHashShuffleReaderDeserializer::next() { + return batchQueue_->get(); } VeloxSortShuffleReaderDeserializer::VeloxSortShuffleReaderDeserializer( @@ -565,8 +611,13 @@ VeloxSortShuffleReaderDeserializer::VeloxSortShuffleReaderDeserializer( memoryManager_(memoryManager) {} VeloxSortShuffleReaderDeserializer::~VeloxSortShuffleReaderDeserializer() { - if (auto in = std::dynamic_pointer_cast(in_)) { - decompressTime_ += in->decompressTime(); + if (in_ != nullptr) { + if (auto in = std::dynamic_pointer_cast(in_)) { + decompressTime_ += in->decompressTime(); + } + if (auto status = in_->Close(); !status.ok()) { + LOG(WARNING) << "Input stream is not closed properly. Error: " << status.message(); + } } } @@ -596,6 +647,7 @@ std::shared_ptr VeloxSortShuffleReaderDeserializer::next() { while (cachedRows_ < batchSize_) { GLUTEN_ASSIGN_OR_THROW(auto bytes, in_->Read(sizeof(RowSizeType), &lastRowSize_)); while (bytes == 0) { + GLUTEN_THROW_NOT_OK(in_->Close()); // Current stream has no more data. Try to load the next stream. loadNextStream(); if (reachedEos_) { @@ -737,6 +789,14 @@ VeloxRssSortShuffleReaderDeserializer::VeloxRssSortShuffleReaderDeserializer( serdeOptions_ = {false, veloxCompressionType_}; } +VeloxRssSortShuffleReaderDeserializer::~VeloxRssSortShuffleReaderDeserializer() { + if (arrowIn_ != nullptr) { + if (auto status = arrowIn_->Close(); !status.ok()) { + LOG(WARNING) << "Input stream is not closed properly. Error: " << status.message(); + } + } +} + std::shared_ptr VeloxRssSortShuffleReaderDeserializer::next() { if (in_ == nullptr || !in_->hasNext()) { do { @@ -772,6 +832,9 @@ void VeloxRssSortShuffleReaderDeserializer::loadNextStream() { return; } + if (arrowIn_ != nullptr) { + GLUTEN_THROW_NOT_OK(arrowIn_->Close()); + } arrowIn_ = streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool()); if (arrowIn_ == nullptr) { diff --git a/cpp/velox/shuffle/VeloxShuffleReader.h b/cpp/velox/shuffle/VeloxShuffleReader.h index f30595dde456..9f48299aa611 100644 --- a/cpp/velox/shuffle/VeloxShuffleReader.h +++ b/cpp/velox/shuffle/VeloxShuffleReader.h @@ -20,11 +20,16 @@ #include "shuffle/Payload.h" #include "shuffle/ShuffleReader.h" #include "shuffle/VeloxSortShuffleWriter.h" +#include "utils/CachedBatchQueue.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/type/Type.h" #include "velox/vector/ComplexVector.h" +#include +#include +#include + namespace gluten { class VeloxHashShuffleReaderDeserializer final : public ColumnarBatchIterator { @@ -39,12 +44,13 @@ class VeloxHashShuffleReaderDeserializer final : public ColumnarBatchIterator { int64_t& deserializeTime, int64_t& decompressTime); + ~VeloxHashShuffleReaderDeserializer() override; + std::shared_ptr next() override; private: - bool resolveNextBlockType(); - - void loadNextStream(); + // Reader thread function that deserializes batches. + void read(); std::shared_ptr streamReader_; std::shared_ptr schema_; @@ -56,12 +62,15 @@ class VeloxHashShuffleReaderDeserializer final : public ColumnarBatchIterator { int64_t& deserializeTime_; int64_t& decompressTime_; - std::shared_ptr in_{nullptr}; + std::atomic deserializeTimeCounter_{0}; + std::atomic decompressTimeCounter_{0}; - bool reachedEos_{false}; + std::vector readerThreads_; + std::unique_ptr> batchQueue_; + std::atomic stopReaders_{false}; + std::atomic activeReaders_{0}; - std::vector dictionaryFields_{}; - std::vector dictionaries_{}; + std::mutex mtx_; }; class VeloxSortShuffleReaderDeserializer final : public ColumnarBatchIterator { @@ -128,7 +137,9 @@ class VeloxRssSortShuffleReaderDeserializer : public ColumnarBatchIterator { facebook::velox::common::CompressionKind veloxCompressionType, int64_t& deserializeTime); - std::shared_ptr next(); + ~VeloxRssSortShuffleReaderDeserializer() override; + + std::shared_ptr next() override; private: class VeloxInputStream; diff --git a/cpp/velox/utils/CachedBatchQueue.h b/cpp/velox/utils/CachedBatchQueue.h new file mode 100644 index 000000000000..c3be8c557e4a --- /dev/null +++ b/cpp/velox/utils/CachedBatchQueue.h @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace gluten { + +template +class CachedBatchQueue { + public: + explicit CachedBatchQueue(const int64_t capacity) : capacity_(capacity) {} + + void put(std::shared_ptr batch) { + std::unique_lock lock(mtx_); + const auto batchSize = batch->numBytes(); + + VELOX_CHECK_LE(batchSize, capacity_, "Batch size exceeds queue capacity"); + + notFull_.wait(lock, [&]() { return totalSize_ + batchSize <= capacity_; }); + + queue_.push(std::move(batch)); + totalSize_ += batchSize; + + notEmpty_.notify_one(); + } + + std::shared_ptr get() { + std::unique_lock lock(mtx_); + notEmpty_.wait(lock, [&]() { return noMoreBatches_ || !queue_.empty(); }); + + if (queue_.empty()) { + return nullptr; + } + auto batch = std::move(queue_.front()); + LOG(WARNING) << "Trying to get from cached buffer queue. Queue length: " << queue_.size() + << ", total size in queue: " << totalSize_ << ", current batch size: " << batch->numBytes() + << std::endl; + + queue_.pop(); + totalSize_ -= batch->numBytes(); + + notFull_.notify_one(); + return batch; + } + + void noMoreBatches() { + std::unique_lock lock(mtx_); + noMoreBatches_ = true; + notFull_.notify_all(); + notEmpty_.notify_all(); + } + + int64_t size() const { + return totalSize_; + } + + bool empty() const { + return queue_.empty(); + } + + private: + int64_t capacity_; + int64_t totalSize_{0}; + bool noMoreBatches_{false}; + + std::queue> queue_; + + std::mutex mtx_; + std::condition_variable notEmpty_; + std::condition_variable notFull_; +}; + +} // namespace gluten diff --git a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala index 59a9f9e146cc..6c1b190696bf 100644 --- a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala +++ b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala @@ -23,30 +23,15 @@ import java.io.InputStream case class ShuffleStreamReader(streams: Iterator[(BlockId, InputStream)]) { private val jniStreams = streams.map { case (blockId, in) => - (blockId, JniByteInputStreams.create(in)) + JniByteInputStreams.create(in) } - private var currentStream: JniByteInputStream = _ - // Called from native side to get the next stream. def nextStream(): JniByteInputStream = { - if (currentStream != null) { - currentStream.close() - } - if (!jniStreams.hasNext) { - currentStream = null + if (jniStreams.hasNext) { + jniStreams.next } else { - currentStream = jniStreams.next._2 - } - currentStream - } - - def close(): Unit = { - // The reader may not attempt to read all streams from `nextStream`, so we need to close the - // current stream if it's not null. - if (currentStream != null) { - currentStream.close() - currentStream = null + null } } }