Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,31 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Future;
import java.util.function.Consumer;

import org.apache.uniffle.common.ShuffleBlockInfo;

public class AddBlockEvent {

private Long eventId;
private String taskId;
private int stageAttemptNumber;
private List<ShuffleBlockInfo> shuffleDataInfoList;
private List<Runnable> processedCallbackChain;

private Consumer<Future> prepare;

public AddBlockEvent(String taskId, List<ShuffleBlockInfo> shuffleDataInfoList) {
this(taskId, 0, shuffleDataInfoList);
this(-1L, taskId, 0, shuffleDataInfoList);
}

public AddBlockEvent(
String taskId, int stageAttemptNumber, List<ShuffleBlockInfo> shuffleDataInfoList) {
Long eventId,
String taskId,
int stageAttemptNumber,
List<ShuffleBlockInfo> shuffleDataInfoList) {
this.eventId = eventId;
this.taskId = taskId;
this.stageAttemptNumber = stageAttemptNumber;
this.shuffleDataInfoList = shuffleDataInfoList;
Expand All @@ -46,10 +55,24 @@ public void addCallback(Runnable callback) {
processedCallbackChain.add(callback);
}

public void addPrepare(Consumer<Future> prepare) {
this.prepare = prepare;
}

public void doPrepare(Future future) {
if (prepare != null) {
prepare.accept(future);
}
}

public String getTaskId() {
return taskId;
}

public Long getEventId() {
return eventId;
}

public int getStageAttemptNumber() {
return stageAttemptNumber;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -80,11 +81,12 @@ public DataPusher(
ThreadUtils.getThreadFactory(this.getClass().getName()));
}

public CompletableFuture<Long> send(AddBlockEvent event) {
public Future<Long> send(AddBlockEvent event) {
if (rssAppId == null) {
throw new RssException("RssAppId should be set.");
}
return CompletableFuture.supplyAsync(
FutureTask<Long> future =
new FutureTask(
() -> {
String taskId = event.getTaskId();
List<ShuffleBlockInfo> shuffleBlockInfoList = event.getShuffleDataInfoList();
Expand Down Expand Up @@ -116,14 +118,11 @@ public CompletableFuture<Long> send(AddBlockEvent event) {
.filter(x -> succeedBlockIds.contains(x.getBlockId()))
.map(x -> x.getFreeMemory())
.reduce((a, b) -> a + b)
.get();
},
executorService)
.exceptionally(
ex -> {
LOGGER.error("Unexpected exceptions occurred while sending shuffle data", ex);
return null;
.orElseGet(() -> 0L);
});
event.doPrepare(future);
executorService.submit(future);
return future;
}

private Set<Long> getSucceedBlockIds(SendShuffleDataResult result) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -70,6 +70,8 @@ public class WriteBufferManager extends MemoryConsumer {
private AtomicLong recordCounter = new AtomicLong(0);
/** An atomic counter used to keep track of the number of blocks */
private AtomicLong blockCounter = new AtomicLong(0);

private AtomicLong eventIdGenerator = new AtomicLong(0);
// it's part of blockId
private Map<Integer, AtomicInteger> partitionToSeqNo = Maps.newHashMap();
private long askExecutorMemory;
Expand All @@ -96,7 +98,7 @@ public class WriteBufferManager extends MemoryConsumer {
private long requireMemoryInterval;
private int requireMemoryRetryMax;
private Optional<Codec> codec;
private Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc;
private Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc;
private long sendSizeLimit;
private boolean memorySpillEnabled;
private int memorySpillTimeoutSec;
Expand Down Expand Up @@ -138,7 +140,7 @@ public WriteBufferManager(
TaskMemoryManager taskMemoryManager,
ShuffleWriteMetrics shuffleWriteMetrics,
RssConf rssConf,
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc,
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc,
Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc) {
this(
shuffleId,
Expand All @@ -163,7 +165,7 @@ public WriteBufferManager(
TaskMemoryManager taskMemoryManager,
ShuffleWriteMetrics shuffleWriteMetrics,
RssConf rssConf,
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc,
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc,
Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc,
int stageAttemptNumber) {
super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
Expand Down Expand Up @@ -212,7 +214,7 @@ public WriteBufferManager(
TaskMemoryManager taskMemoryManager,
ShuffleWriteMetrics shuffleWriteMetrics,
RssConf rssConf,
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc,
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc,
int stageAttemptNumber) {
this(
shuffleId,
Expand Down Expand Up @@ -528,7 +530,12 @@ public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockI
+ totalSize
+ " bytes");
}
events.add(new AddBlockEvent(taskId, stageAttemptNumber, shuffleBlockInfosPerEvent));
events.add(
new AddBlockEvent(
eventIdGenerator.incrementAndGet(),
taskId,
stageAttemptNumber,
shuffleBlockInfosPerEvent));
shuffleBlockInfosPerEvent = Lists.newArrayList();
totalSize = 0;
}
Expand All @@ -543,7 +550,12 @@ public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockI
+ " bytes");
}
// Use final temporary variables for closures
events.add(new AddBlockEvent(taskId, stageAttemptNumber, shuffleBlockInfosPerEvent));
events.add(
new AddBlockEvent(
eventIdGenerator.incrementAndGet(),
taskId,
stageAttemptNumber,
shuffleBlockInfosPerEvent));
}
return events;
}
Expand All @@ -555,15 +567,19 @@ public long spill(long size, MemoryConsumer trigger) {
return 0L;
}

List<CompletableFuture<Long>> futures = spillFunc.apply(clear(bufferSpillRatio));
CompletableFuture<Void> allOfFutures =
CompletableFuture.allOf(futures.toArray(new CompletableFuture[futures.size()]));
List<Future<Long>> futures = spillFunc.apply(clear(bufferSpillRatio));
long end = System.currentTimeMillis() + memorySpillTimeoutSec * 1000;
try {
allOfFutures.get(memorySpillTimeoutSec, TimeUnit.SECONDS);
for (Future f : futures) {
f.get(end - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
}
} catch (TimeoutException timeoutException) {
// A best effort strategy to wait.
// If timeout exception occurs, the underlying tasks won't be cancelled.
LOG.warn("[taskId: {}] Spill tasks timeout after {} seconds", taskId, memorySpillTimeoutSec);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
LOG.warn("[taskId: {}] Spill interrupted due to kill", taskId);
} catch (Exception e) {
LOG.warn("[taskId: {}] Failed to spill buffers due to ", taskId, e);
} finally {
Expand Down Expand Up @@ -608,6 +624,10 @@ public long getBlockCount() {
return blockCounter.get();
}

public Long getLastEventId() {
return eventIdGenerator.get();
}

public void freeAllocatedMemory(long freeMemory) {
freeMemory(freeMemory);
allocatedBytes.addAndGet(-freeMemory);
Expand Down Expand Up @@ -671,8 +691,7 @@ public void setTaskId(String taskId) {
}

@VisibleForTesting
public void setSpillFunc(
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc) {
public void setSpillFunc(Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc) {
this.spillFunc = spillFunc;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -1587,7 +1588,7 @@ public Map<String, FailedBlockSendTracker> getTaskToFailedBlockSendTracker() {
return taskToFailedBlockSendTracker;
}

public CompletableFuture<Long> sendData(AddBlockEvent event) {
public Future<Long> sendData(AddBlockEvent event) {
if (dataPusher != null && event != null) {
return dataPusher.send(event);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.function.Supplier;

import com.google.common.collect.Maps;
Expand Down Expand Up @@ -119,7 +119,7 @@ public void testSendData() throws ExecutionException, InterruptedException {
new ShuffleBlockInfo(1, 1, 1, 1, 1, new byte[1], null, 1, 100, 1);
AddBlockEvent event = new AddBlockEvent("taskId", Arrays.asList(shuffleBlockInfo));
// sync send
CompletableFuture<Long> future = dataPusher.send(event);
Future<Long> future = dataPusher.send(event);
long memoryFree = future.get();
assertEquals(100, memoryFree);
assertTrue(taskToSuccessBlockIds.get("taskId").contains(1L));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Stream;
Expand Down Expand Up @@ -370,7 +371,7 @@ public void spillByOwnTest() {
null,
0);

Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc =
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc =
blocks -> {
long sum = 0L;
List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
Expand Down Expand Up @@ -481,7 +482,7 @@ public void spillPartial() {
null,
0);

Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc =
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc =
blocks -> {
long sum = 0L;
List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
Expand Down Expand Up @@ -579,7 +580,7 @@ public void spillByOwnWithSparkTaskMemoryManagerTest() {

List<ShuffleBlockInfo> blockList = new ArrayList<>();

Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc =
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc =
blocks -> {
blockList.addAll(blocks);
long sum = 0L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
Expand Down Expand Up @@ -361,8 +360,7 @@ private void checkSentBlockCount() {
*
* @param shuffleBlockInfoList
*/
private List<CompletableFuture<Long>> processShuffleBlockInfos(
List<ShuffleBlockInfo> shuffleBlockInfoList) {
private List<Future<Long>> processShuffleBlockInfos(List<ShuffleBlockInfo> shuffleBlockInfoList) {
if (shuffleBlockInfoList != null && !shuffleBlockInfoList.isEmpty()) {
shuffleBlockInfoList.stream()
.forEach(
Expand Down Expand Up @@ -390,9 +388,8 @@ private List<CompletableFuture<Long>> processShuffleBlockInfos(

// don't send huge block to shuffle server, or there will be OOM if shuffle sever receives data
// more than expected
protected List<CompletableFuture<Long>> postBlockEvent(
List<ShuffleBlockInfo> shuffleBlockInfoList) {
List<CompletableFuture<Long>> futures = new ArrayList<>();
protected List<Future<Long>> postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
List<Future<Long>> futures = new ArrayList<>();
for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
futures.add(shuffleManager.sendData(event));
}
Expand Down
Loading
Loading