From be8ca4e78e81136322f95e1a6501716b8007961f Mon Sep 17 00:00:00 2001 From: summaryzb Date: Fri, 4 Apr 2025 06:49:46 +0800 Subject: [PATCH] reoptimize --- .../spark/shuffle/writer/AddBlockEvent.java | 27 +++- .../spark/shuffle/writer/DataPusher.java | 19 ++- .../shuffle/writer/WriteBufferManager.java | 45 ++++-- .../manager/RssShuffleManagerBase.java | 3 +- .../spark/shuffle/writer/DataPusherTest.java | 4 +- .../writer/WriteBufferManagerTest.java | 7 +- .../shuffle/writer/RssShuffleWriter.java | 9 +- .../shuffle/writer/RssShuffleWriter.java | 57 ++++--- .../client/impl/ShuffleWriteClientImpl.java | 149 +++++++++--------- .../uniffle/client/util/ClientUtils.java | 47 ++++-- .../uniffle/client/ClientUtilsTest.java | 21 +-- .../common/netty/client/TransportClient.java | 3 + .../impl/grpc/ShuffleServerGrpcClient.java | 12 +- .../grpc/ShuffleServerGrpcNettyClient.java | 11 +- 14 files changed, 248 insertions(+), 166 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java index f989fdb0b1..cd99e26ac4 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java @@ -19,22 +19,31 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.Future; +import java.util.function.Consumer; import org.apache.uniffle.common.ShuffleBlockInfo; public class AddBlockEvent { + private Long eventId; private String taskId; private int stageAttemptNumber; private List shuffleDataInfoList; private List processedCallbackChain; + private Consumer prepare; + public AddBlockEvent(String taskId, List shuffleDataInfoList) { - this(taskId, 0, shuffleDataInfoList); + this(-1L, taskId, 0, shuffleDataInfoList); } public AddBlockEvent( - String taskId, int stageAttemptNumber, List shuffleDataInfoList) { + Long eventId, + String taskId, + int stageAttemptNumber, + List shuffleDataInfoList) { + this.eventId = eventId; this.taskId = taskId; this.stageAttemptNumber = stageAttemptNumber; this.shuffleDataInfoList = shuffleDataInfoList; @@ -46,10 +55,24 @@ public void addCallback(Runnable callback) { processedCallbackChain.add(callback); } + public void addPrepare(Consumer prepare) { + this.prepare = prepare; + } + + public void doPrepare(Future future) { + if (prepare != null) { + prepare.accept(future); + } + } + public String getTaskId() { return taskId; } + public Long getEventId() { + return eventId; + } + public int getStageAttemptNumber() { return stageAttemptNumber; } diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java index c55216d261..ba525c8d47 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java @@ -24,8 +24,9 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -80,11 +81,12 @@ public DataPusher( ThreadUtils.getThreadFactory(this.getClass().getName())); } - public CompletableFuture send(AddBlockEvent event) { + public Future send(AddBlockEvent event) { if (rssAppId == null) { throw new RssException("RssAppId should be set."); } - return CompletableFuture.supplyAsync( + FutureTask future = + new FutureTask( () -> { String taskId = event.getTaskId(); List shuffleBlockInfoList = event.getShuffleDataInfoList(); @@ -116,14 +118,11 @@ public CompletableFuture send(AddBlockEvent event) { .filter(x -> succeedBlockIds.contains(x.getBlockId())) .map(x -> x.getFreeMemory()) .reduce((a, b) -> a + b) - .get(); - }, - executorService) - .exceptionally( - ex -> { - LOGGER.error("Unexpected exceptions occurred while sending shuffle data", ex); - return null; + .orElseGet(() -> 0L); }); + event.doPrepare(future); + executorService.submit(future); + return future; } private Set getSucceedBlockIds(SendShuffleDataResult result) { diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java index cf4b4bc51c..c7aa68f60a 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java @@ -23,7 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; @@ -70,6 +70,8 @@ public class WriteBufferManager extends MemoryConsumer { private AtomicLong recordCounter = new AtomicLong(0); /** An atomic counter used to keep track of the number of blocks */ private AtomicLong blockCounter = new AtomicLong(0); + + private AtomicLong eventIdGenerator = new AtomicLong(0); // it's part of blockId private Map partitionToSeqNo = Maps.newHashMap(); private long askExecutorMemory; @@ -96,7 +98,7 @@ public class WriteBufferManager extends MemoryConsumer { private long requireMemoryInterval; private int requireMemoryRetryMax; private Optional codec; - private Function, List>> spillFunc; + private Function, List>> spillFunc; private long sendSizeLimit; private boolean memorySpillEnabled; private int memorySpillTimeoutSec; @@ -138,7 +140,7 @@ public WriteBufferManager( TaskMemoryManager taskMemoryManager, ShuffleWriteMetrics shuffleWriteMetrics, RssConf rssConf, - Function, List>> spillFunc, + Function, List>> spillFunc, Function> partitionAssignmentRetrieveFunc) { this( shuffleId, @@ -163,7 +165,7 @@ public WriteBufferManager( TaskMemoryManager taskMemoryManager, ShuffleWriteMetrics shuffleWriteMetrics, RssConf rssConf, - Function, List>> spillFunc, + Function, List>> spillFunc, Function> partitionAssignmentRetrieveFunc, int stageAttemptNumber) { super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP); @@ -212,7 +214,7 @@ public WriteBufferManager( TaskMemoryManager taskMemoryManager, ShuffleWriteMetrics shuffleWriteMetrics, RssConf rssConf, - Function, List>> spillFunc, + Function, List>> spillFunc, int stageAttemptNumber) { this( shuffleId, @@ -528,7 +530,12 @@ public List buildBlockEvents(List shuffleBlockI + totalSize + " bytes"); } - events.add(new AddBlockEvent(taskId, stageAttemptNumber, shuffleBlockInfosPerEvent)); + events.add( + new AddBlockEvent( + eventIdGenerator.incrementAndGet(), + taskId, + stageAttemptNumber, + shuffleBlockInfosPerEvent)); shuffleBlockInfosPerEvent = Lists.newArrayList(); totalSize = 0; } @@ -543,7 +550,12 @@ public List buildBlockEvents(List shuffleBlockI + " bytes"); } // Use final temporary variables for closures - events.add(new AddBlockEvent(taskId, stageAttemptNumber, shuffleBlockInfosPerEvent)); + events.add( + new AddBlockEvent( + eventIdGenerator.incrementAndGet(), + taskId, + stageAttemptNumber, + shuffleBlockInfosPerEvent)); } return events; } @@ -555,15 +567,19 @@ public long spill(long size, MemoryConsumer trigger) { return 0L; } - List> futures = spillFunc.apply(clear(bufferSpillRatio)); - CompletableFuture allOfFutures = - CompletableFuture.allOf(futures.toArray(new CompletableFuture[futures.size()])); + List> futures = spillFunc.apply(clear(bufferSpillRatio)); + long end = System.currentTimeMillis() + memorySpillTimeoutSec * 1000; try { - allOfFutures.get(memorySpillTimeoutSec, TimeUnit.SECONDS); + for (Future f : futures) { + f.get(end - System.currentTimeMillis(), TimeUnit.MILLISECONDS); + } } catch (TimeoutException timeoutException) { // A best effort strategy to wait. // If timeout exception occurs, the underlying tasks won't be cancelled. LOG.warn("[taskId: {}] Spill tasks timeout after {} seconds", taskId, memorySpillTimeoutSec); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("[taskId: {}] Spill interrupted due to kill", taskId); } catch (Exception e) { LOG.warn("[taskId: {}] Failed to spill buffers due to ", taskId, e); } finally { @@ -608,6 +624,10 @@ public long getBlockCount() { return blockCounter.get(); } + public Long getLastEventId() { + return eventIdGenerator.get(); + } + public void freeAllocatedMemory(long freeMemory) { freeMemory(freeMemory); allocatedBytes.addAndGet(-freeMemory); @@ -671,8 +691,7 @@ public void setTaskId(String taskId) { } @VisibleForTesting - public void setSpillFunc( - Function, List>> spillFunc) { + public void setSpillFunc(Function, List>> spillFunc) { this.spillFunc = spillFunc; } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 5bf5331eff..e69e8bf58d 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -30,6 +30,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -1587,7 +1588,7 @@ public Map getTaskToFailedBlockSendTracker() { return taskToFailedBlockSendTracker; } - public CompletableFuture sendData(AddBlockEvent event) { + public Future sendData(AddBlockEvent event) { if (dataPusher != null && event != null) { return dataPusher.send(event); } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java index 080ba1e33f..8d0a200f03 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java @@ -22,8 +22,8 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import java.util.function.Supplier; import com.google.common.collect.Maps; @@ -119,7 +119,7 @@ public void testSendData() throws ExecutionException, InterruptedException { new ShuffleBlockInfo(1, 1, 1, 1, 1, new byte[1], null, 1, 100, 1); AddBlockEvent event = new AddBlockEvent("taskId", Arrays.asList(shuffleBlockInfo)); // sync send - CompletableFuture future = dataPusher.send(event); + Future future = dataPusher.send(event); long memoryFree = future.get(); assertEquals(100, memoryFree); assertTrue(taskToSuccessBlockIds.get("taskId").contains(1L)); diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java index 501b57e444..800a118469 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.stream.Stream; @@ -370,7 +371,7 @@ public void spillByOwnTest() { null, 0); - Function, List>> spillFunc = + Function, List>> spillFunc = blocks -> { long sum = 0L; List events = wbm.buildBlockEvents(blocks); @@ -481,7 +482,7 @@ public void spillPartial() { null, 0); - Function, List>> spillFunc = + Function, List>> spillFunc = blocks -> { long sum = 0L; List events = wbm.buildBlockEvents(blocks); @@ -579,7 +580,7 @@ public void spillByOwnWithSparkTaskMemoryManagerTest() { List blockList = new ArrayList<>(); - Function, List>> spillFunc = + Function, List>> spillFunc = blocks -> { blockList.addAll(blocks); long sum = 0L; diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 1cd8113c0b..057f65d846 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -24,7 +24,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -361,8 +360,7 @@ private void checkSentBlockCount() { * * @param shuffleBlockInfoList */ - private List> processShuffleBlockInfos( - List shuffleBlockInfoList) { + private List> processShuffleBlockInfos(List shuffleBlockInfoList) { if (shuffleBlockInfoList != null && !shuffleBlockInfoList.isEmpty()) { shuffleBlockInfoList.stream() .forEach( @@ -390,9 +388,8 @@ private List> processShuffleBlockInfos( // don't send huge block to shuffle server, or there will be OOM if shuffle sever receives data // more than expected - protected List> postBlockEvent( - List shuffleBlockInfoList) { - List> futures = new ArrayList<>(); + protected List> postBlockEvent(List shuffleBlockInfoList) { + List> futures = new ArrayList<>(); for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) { futures.add(shuffleManager.sendData(event)); } diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index a379e516b3..39a94dc6e2 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -26,13 +26,13 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -84,6 +84,7 @@ import org.apache.uniffle.common.exception.RssSendFailedException; import org.apache.uniffle.common.exception.RssWaitFailedException; import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.storage.util.StorageType; import static org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED; @@ -128,8 +129,7 @@ public class RssShuffleWriter extends ShuffleWriter { protected final long taskAttemptId; protected final ShuffleWriteMetrics shuffleWriteMetrics; - - private final BlockingQueue finishEventQueue = new LinkedBlockingQueue<>(); + private final Map inFlightEvent = JavaUtils.newConcurrentMap(); // Will be updated when the reassignment is triggered. private TaskAttemptAssignment taskAttemptAssignment; @@ -415,7 +415,7 @@ public long[] getPartitionLengths() { } @VisibleForTesting - protected List> processShuffleBlockInfos( + protected List> processShuffleBlockInfos( List shuffleBlockInfoList) { if (shuffleBlockInfoList != null && !shuffleBlockInfoList.isEmpty()) { shuffleBlockInfoList.forEach( @@ -440,9 +440,8 @@ protected List> processShuffleBlockInfos( return Collections.emptyList(); } - protected List> postBlockEvent( - List shuffleBlockInfoList) { - List> futures = new ArrayList<>(); + protected List> postBlockEvent(List shuffleBlockInfoList) { + List> futures = new ArrayList<>(); for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) { if (blockFailSentRetryEnabled) { // do nothing if failed. @@ -455,13 +454,8 @@ protected List> postBlockEvent( }); } } - event.addCallback( - () -> { - boolean ret = finishEventQueue.add(new Object()); - if (!ret) { - LOG.error("Add event " + event + " to finishEventQueue fail"); - } - }); + event.addCallback(() -> inFlightEvent.remove(event.getEventId())); + event.addPrepare(f -> inFlightEvent.put(event.getEventId(), f)); futures.add(shuffleManager.sendData(event)); } return futures; @@ -475,29 +469,48 @@ protected void internalCheckBlockSendResult() { @VisibleForTesting protected void checkBlockSendResult(Set blockIds) { boolean interrupted = false; + boolean hurryUp = false; try { long remainingMs = sendCheckTimeout; long end = System.currentTimeMillis() + remainingMs; - while (true) { + while (!interrupted) { try { - finishEventQueue.clear(); + LOG.warn("checkBlockSendResult," + blockIds.size() + "," + inFlightEvent.size()); checkDataIfAnyFailure(); Set successBlockIds = shuffleManager.getSuccessBlockIds(taskId); blockIds.removeAll(successBlockIds); if (blockIds.isEmpty()) { break; } - if (finishEventQueue.isEmpty()) { - remainingMs = Math.max(end - System.currentTimeMillis(), 0); - Object event = finishEventQueue.poll(remainingMs, TimeUnit.MILLISECONDS); - if (event == null) { - break; + if (!inFlightEvent.isEmpty()) { + if (!hurryUp) { + Future maybeLast = inFlightEvent.get(bufferManager.getLastEventId()); + if (maybeLast != null) { + maybeLast.get(remainingMs, TimeUnit.MILLISECONDS); + } + hurryUp = true; } + remainingMs = end - System.currentTimeMillis(); + inFlightEvent.values().stream() + .filter(f -> !f.isDone()) + .findAny() + .orElseGet(() -> CompletableFuture.completedFuture(0L)) + .get(remainingMs, TimeUnit.MILLISECONDS); + } else { + LOG.warn("blockSize:" + blockIds.size() + ",inflightEvent:" + inFlightEvent.size()); + // it seems never reach here, since `blockIds.isEmpty()` will break the loop first + break; } } catch (InterruptedException e) { + LOG.warn("Ignore the InterruptedException which should be caused by internal killed"); interrupted = true; + inFlightEvent.values().stream().forEach(f -> f.cancel(true)); + Thread.currentThread().interrupt(); + } catch (ExecutionException | TimeoutException e) { + LOG.error("check err", e); + break; } } if (!blockIds.isEmpty()) { diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index 3d3bb925df..5f2d5914dc 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -25,9 +25,10 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -179,87 +180,76 @@ private boolean sendShuffleDataAsync( } // If one or more servers is failed, the sending is not totally successful. - List> futures = new ArrayList<>(); + List> futures = new ArrayList<>(); for (Map.Entry>>> entry : serverToBlocks.entrySet()) { - CompletableFuture future = - CompletableFuture.supplyAsync( - () -> { - if (needCancelRequest.get()) { - LOG.info("The upstream task has been failed. Abort this data send."); - return true; + FutureTask future = + new FutureTask( + () -> { + if (needCancelRequest.get()) { + LOG.info("The upstream task has been failed. Abort this data send."); + return true; + } + ShuffleServerInfo ssi = entry.getKey(); + try { + Map>> shuffleIdToBlocks = + entry.getValue(); + // todo: compact unnecessary blocks that reach replicaWrite + RssSendShuffleDataRequest request = + new RssSendShuffleDataRequest( + appId, stageAttemptNumber, retryMax, retryIntervalMax, shuffleIdToBlocks); + long s = System.currentTimeMillis(); + RssSendShuffleDataResponse response = + getShuffleServerClient(ssi).sendShuffleData(request); + + String logMsg = + String.format( + "ShuffleWriteClientImpl sendShuffleData with %s blocks to %s cost: %s(ms)", + serverToBlockIds.get(ssi).size(), + ssi.getId(), + System.currentTimeMillis() - s); + + if (response.getStatusCode() == StatusCode.SUCCESS) { + // mark a replica of block that has been sent + serverToBlockIds + .get(ssi) + .forEach( + blockId -> blockIdsSendSuccessTracker.get(blockId).incrementAndGet()); + recordNeedSplitPartition( + failedBlockSendTracker, ssi, response.getNeedSplitPartitionIds()); + if (defectiveServers != null) { + defectiveServers.remove(ssi); + } + if (LOG.isDebugEnabled()) { + LOG.debug("{} successfully.", logMsg); } - ShuffleServerInfo ssi = entry.getKey(); - try { - Map>> shuffleIdToBlocks = - entry.getValue(); - // todo: compact unnecessary blocks that reach replicaWrite - RssSendShuffleDataRequest request = - new RssSendShuffleDataRequest( - appId, - stageAttemptNumber, - retryMax, - retryIntervalMax, - shuffleIdToBlocks); - long s = System.currentTimeMillis(); - RssSendShuffleDataResponse response = - getShuffleServerClient(ssi).sendShuffleData(request); - - String logMsg = - String.format( - "ShuffleWriteClientImpl sendShuffleData with %s blocks to %s cost: %s(ms)", - serverToBlockIds.get(ssi).size(), - ssi.getId(), - System.currentTimeMillis() - s); - - if (response.getStatusCode() == StatusCode.SUCCESS) { - // mark a replica of block that has been sent - serverToBlockIds - .get(ssi) - .forEach( - blockId -> - blockIdsSendSuccessTracker.get(blockId).incrementAndGet()); - recordNeedSplitPartition( - failedBlockSendTracker, ssi, response.getNeedSplitPartitionIds()); - if (defectiveServers != null) { - defectiveServers.remove(ssi); - } - if (LOG.isDebugEnabled()) { - LOG.debug("{} successfully.", logMsg); - } - } else { - recordFailedBlocks( - failedBlockSendTracker, serverToBlocks, ssi, response.getStatusCode()); - if (defectiveServers != null) { - defectiveServers.add(ssi); - } - LOG.warn( - "{}, it failed wth statusCode[{}]", logMsg, response.getStatusCode()); - return false; - } - } catch (Exception e) { - recordFailedBlocks( - failedBlockSendTracker, serverToBlocks, ssi, StatusCode.INTERNAL_ERROR); - if (defectiveServers != null) { - defectiveServers.add(ssi); - } - LOG.warn( - "Send: " - + serverToBlockIds.get(ssi).size() - + " blocks to [" - + ssi.getId() - + "] failed.", - e); - return false; + } else { + recordFailedBlocks( + failedBlockSendTracker, serverToBlocks, ssi, response.getStatusCode()); + if (defectiveServers != null) { + defectiveServers.add(ssi); } - return true; - }, - dataTransferPool) - .exceptionally( - ex -> { - LOG.error("Unexpected exceptions occurred while sending shuffle data", ex); + LOG.warn("{}, it failed wth statusCode[{}]", logMsg, response.getStatusCode()); return false; - }); + } + } catch (Exception e) { + recordFailedBlocks( + failedBlockSendTracker, serverToBlocks, ssi, StatusCode.INTERNAL_ERROR); + if (defectiveServers != null) { + defectiveServers.add(ssi); + } + LOG.warn( + "Send: " + + serverToBlockIds.get(ssi).size() + + " blocks to [" + + ssi.getId() + + "] failed.", + e); + return false; + } + return true; + }); + dataTransferPool.submit(future); futures.add(future); } @@ -443,7 +433,10 @@ public SendShuffleDataResult sendShuffleData( // Even though the secondary round may send blocks more than replicaWrite replicas, // we do not apply complicated skipping logic, because server crash is rare in production // environment. - if (!isAllSuccess && !secondaryServerToBlocks.isEmpty() && !needCancelRequest.get()) { + if (!isAllSuccess + && !secondaryServerToBlocks.isEmpty() + && !needCancelRequest.get() + && !Thread.currentThread().isInterrupted()) { LOG.info("The sending of primary round is failed partially, so start the secondary round"); sendShuffleDataAsync( appId, diff --git a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java index 29fc4b241c..1f855be0eb 100644 --- a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java +++ b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java @@ -17,21 +17,26 @@ package org.apache.uniffle.client.util; -import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; +import java.util.LinkedList; import java.util.List; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.storage.util.StorageType; public class ClientUtils { + private static final Logger LOG = LoggerFactory.getLogger(ClientUtils.class); public static RemoteStorageInfo fetchRemoteStorage( String appId, @@ -54,42 +59,50 @@ public static RemoteStorageInfo fetchRemoteStorage( } @SuppressWarnings("rawtypes") - public static boolean waitUntilDoneOrFail( - List> futures, boolean allowFastFail) { - int expected = futures.size(); + public static boolean waitUntilDoneOrFail(List> list, boolean allowFastFail) { + int expected = list.size(); int failed = 0; + int finished = 0; + List> futures = new LinkedList<>(list); - CompletableFuture allFutures = - CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); - - List finished = new ArrayList<>(); while (true) { - for (Future future : futures) { - if (future.isDone() && !finished.contains(future)) { - finished.add(future); + Iterator> iterator = futures.iterator(); + while (iterator.hasNext()) { + Future future = iterator.next(); + if (future.isDone()) { + iterator.remove(); try { if (!future.get()) { failed++; + } else { + finished++; } } catch (Exception e) { + // cancel or execution exception failed++; } } } - if (expected == finished.size()) { + if (expected == finished || futures.isEmpty()) { return failed <= 0; } if (failed > 0 && allowFastFail) { - futures.stream().filter(x -> !x.isDone()).forEach(x -> x.cancel(true)); + futures.forEach(x -> x.cancel(true)); return false; } - try { - allFutures.get(10, TimeUnit.MILLISECONDS); - } catch (Exception e) { + futures.get(0).get(10, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + futures.forEach(x -> x.cancel(true)); + Thread.currentThread().interrupt(); + return false; + } catch (TimeoutException e) { // ignore + } catch (Exception e) { + LOG.warn("Exception in waitUntilDoneOrFail", e); + // ignore timeout or execution err } } } diff --git a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java index 6bedd5f334..88bb7e35c7 100644 --- a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java +++ b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java @@ -19,9 +19,10 @@ import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; import java.util.concurrent.TimeUnit; import org.awaitility.Awaitility; @@ -70,12 +71,12 @@ public void testGenerateTaskIdBitMap() { } } - private List> getFutures(boolean fail) { - List> futures = new ArrayList<>(); + private List> getFutures(boolean fail) { + List> futures = new ArrayList<>(); for (int i = 0; i < 3; i++) { final int index = i; - CompletableFuture future = - CompletableFuture.supplyAsync( + FutureTask future = + new FutureTask( () -> { if (index == 2) { try { @@ -88,8 +89,8 @@ private List> getFutures(boolean fail) { return true; } return !fail || index != 1; - }, - executorService); + }); + executorService.submit(future); futures.add(future); } return futures; @@ -98,13 +99,13 @@ private List> getFutures(boolean fail) { @Test public void testWaitUntilDoneOrFail() { // case1: enable fail fast - List> futures1 = getFutures(true); + List> futures1 = getFutures(true); Awaitility.await() .timeout(2, TimeUnit.SECONDS) .until(() -> !waitUntilDoneOrFail(futures1, true)); // case2: disable fail fast - List> futures2 = getFutures(true); + List> futures2 = getFutures(true); try { Awaitility.await() .timeout(2, TimeUnit.SECONDS) @@ -115,7 +116,7 @@ public void testWaitUntilDoneOrFail() { } // case3: all succeed - List> futures3 = getFutures(false); + List> futures3 = getFutures(false); Awaitility.await() .timeout(4, TimeUnit.SECONDS) .until(() -> waitUntilDoneOrFail(futures3, true)); diff --git a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java index acd79a60c4..4310ca02fc 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java @@ -92,6 +92,9 @@ public void onFailure(Throwable e) { sendRpc(message, callback); try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RssException(e); } catch (Exception e) { throw new RssException(e); } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 82993890ef..6d129e12cb 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -650,7 +650,17 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ null, request.getRetryIntervalMax(), maxRetryAttempts, - t -> !(t instanceof OutOfMemoryError) && !(t instanceof NotRetryException)); + t -> + !(t instanceof OutOfMemoryError) + && !(t instanceof NotRetryException) + && !(t instanceof InterruptedException) + && !Thread.currentThread().isInterrupted()); + + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("cancel send shuffle data, since interrupted"); + isSuccessful = false; + break; } catch (Throwable throwable) { LOG.warn("Failed to send shuffle data due to ", throwable); isSuccessful = false; diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java index e48233ddb6..7f9630ad1b 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java @@ -235,7 +235,16 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ null, request.getRetryIntervalMax(), maxRetryAttempts, - t -> !(t instanceof OutOfMemoryError) && !(t instanceof NotRetryException)); + t -> + !(t instanceof OutOfMemoryError) + && !(t instanceof NotRetryException) + && !(t instanceof InterruptedException) + && !Thread.currentThread().isInterrupted()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("cancel send shuffle data, since interrupted"); + isSuccessful = false; + break; } catch (Throwable throwable) { LOG.warn("Failed to send shuffle data due to ", throwable); isSuccessful = false;