diff --git a/client-spark/common/pom.xml b/client-spark/common/pom.xml
index 639864b196..99564b255b 100644
--- a/client-spark/common/pom.xml
+++ b/client-spark/common/pom.xml
@@ -89,6 +89,12 @@
<groupId>net.jpountz.lz4</groupId>
<artifactId>lz4</artifactId>
+ <dependency>
+ <groupId>org.apache.uniffle</groupId>
+ <artifactId>rss-common</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 4a7f653db6..e8998ed431 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -447,6 +447,27 @@ public class RssSparkConfig {
+ "sequence number, the partition id and the task attempt id."))
.createWithDefault(1048576);
+ public static final ConfigEntry<Boolean> RSS_REMOTE_MERGE_ENABLE =
+ createBooleanBuilder(
+ new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_REMOTE_MERGE_ENABLE)
+ .doc("Whether to enable remote merge"))
+ .createWithDefault(false);
+
+ public static final ConfigEntry<Integer> RSS_MERGED_BLOCK_SZIE =
+ createIntegerBuilder(
+ new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_MERGED_BLOCK_SZIE)
+ .internal()
+ .doc("The merged block size."))
+ .createWithDefault(RssClientConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT);
+
+ public static final ConfigEntry<String> RSS_REMOTE_MERGE_CLASS_LOADER =
+ createStringBuilder(
+ new ConfigBuilder(
+ SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_REMOTE_MERGE_CLASS_LOADER)
+ .internal()
+ .doc("The class loader label for remote merge"))
+ .createWithDefault(null);
+
// spark2 doesn't have this key defined
public static final String SPARK_SHUFFLE_COMPRESS_KEY = "spark.shuffle.compress";
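The three entries above are ordinary RssSparkConfig options, so a job opts in through its SparkConf before the RssShuffleManager is created. A minimal, hedged sketch of doing so (the full key strings live in RssClientConfig, which is outside this diff, so the keys are resolved through the ConfigEntry objects instead of being hard-coded; the block size value is purely illustrative):

    import org.apache.spark.SparkConf;
    import org.apache.spark.shuffle.RssSparkConfig;

    SparkConf conf = new SparkConf();
    // turn remote merge on for this application
    conf.set(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE.key(), "true");
    // RSS_MERGED_BLOCK_SZIE is an internal option; shown here only to illustrate the key() accessor
    conf.set(RssSparkConfig.RSS_MERGED_BLOCK_SZIE.key(), String.valueOf(4 * 1024 * 1024));
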
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/SparkCombiner.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/SparkCombiner.java
new file mode 100644
index 0000000000..f292fc59d0
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/SparkCombiner.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.spark.Aggregator;
+
+import org.apache.uniffle.client.record.Record;
+import org.apache.uniffle.client.record.writer.Combiner;
+
+public class SparkCombiner<K, V, C> extends Combiner {
+
+ private Aggregator<K, V, C> agg;
+
+ public SparkCombiner(Aggregator<K, V, C> agg) {
+ this.agg = agg;
+ }
+
+ @Override
+ public List<Record> combineValues(Iterator<Map.Entry<Object, List<Record>>> recordIterator) {
+ List<Record> ret = new ArrayList<>();
+ while (recordIterator.hasNext()) {
+ Map.Entry<Object, List<Record>> entry = recordIterator.next();
+ List<Record> records = entry.getValue();
+ Record current = null;
+ for (Record record : records) {
+ if (current == null) {
+ // Handle new Key
+ current = Record.create(record.getKey(), agg.createCombiner().apply(record.getValue()));
+ } else {
+ // Combine the values
+ C newValue = agg.mergeValue().apply(current.getValue(), record.getValue());
+ current.setValue(newValue);
+ }
+ }
+ ret.add(current);
+ }
+ return ret;
+ }
+
+ @Override
+ public List<Record> combineCombiners(Iterator<Map.Entry<Object, List<Record>>> recordIterator) {
+ List<Record> ret = new ArrayList<>();
+ while (recordIterator.hasNext()) {
+ Map.Entry<Object, List<Record>> entry = recordIterator.next();
+ List<Record> records = entry.getValue();
+ Record current = null;
+ for (Record record : records) {
+ if (current == null) {
+ // Handle new Key
+ current = record;
+ } else {
+ // Combine the values
+ C newValue = agg.mergeCombiners().apply(current.getValue(), record.getValue());
+ current.setValue(newValue);
+ }
+ }
+ ret.add(current);
+ }
+ return ret;
+ }
+}
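combineValues() is the map-side path: the first record of a key is seeded with Spark's createCombiner and every further raw value is folded in with mergeValue; combineCombiners() is the reduce-side path, where every value is already a combined value and only mergeCombiners applies. A hedged sketch of wiring a sum-style Aggregator into this class (SparkCombinerTest later in this patch exercises the same pattern; the concrete type parameters here are illustrative only):

    import scala.runtime.AbstractFunction1;
    import scala.runtime.AbstractFunction2;

    import org.apache.spark.Aggregator;

    // K = String key, V = Integer raw value, C = Integer partial sum
    Aggregator<String, Integer, Integer> sumAggregator =
        new Aggregator<>(
            new AbstractFunction1<Integer, Integer>() {
              @Override
              public Integer apply(Integer v) {
                return v; // createCombiner: the first value becomes the initial partial sum
              }
            },
            new AbstractFunction2<Integer, Integer, Integer>() {
              @Override
              public Integer apply(Integer c, Integer v) {
                return c + v; // mergeValue: fold one more raw value into the partial sum
              }
            },
            new AbstractFunction2<Integer, Integer, Integer>() {
              @Override
              public Integer apply(Integer c1, Integer c2) {
                return c1 + c2; // mergeCombiners: fold two partial sums together
              }
            });
    SparkCombiner<String, Integer, Integer> combiner = new SparkCombiner<>(sumAggregator);
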
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleDataIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleDataIterator.java
new file mode 100644
index 0000000000..dc8a429f3e
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleDataIterator.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.reader;
+
+import java.io.IOException;
+
+import scala.collection.AbstractIterator;
+import scala.runtime.BoxedUnit;
+
+import org.apache.uniffle.client.record.Record;
+import org.apache.uniffle.client.record.reader.KeyValueReader;
+import org.apache.uniffle.client.record.reader.RMRecordsReader;
+import org.apache.uniffle.common.exception.RssException;
+
+public class RMRssShuffleDataIterator<K, C> extends AbstractIterator<Record<K, C>> {
+
+ private RMRecordsReader reader;
+ private KeyValueReader keyValueReader;
+
+ public RMRssShuffleDataIterator(RMRecordsReader reader) {
+ this.reader = reader;
+ this.keyValueReader = reader.keyValueReader();
+ }
+
+ @Override
+ public boolean hasNext() {
+ try {
+ return this.keyValueReader.hasNext();
+ } catch (IOException e) {
+ throw new RssException(e);
+ }
+ }
+
+ @Override
+ public Record next() {
+ try {
+ return this.keyValueReader.next();
+ } catch (IOException e) {
+ throw new RssException(e);
+ }
+ }
+
+ public BoxedUnit cleanup() {
+ reader.close();
+ return BoxedUnit.UNIT;
+ }
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/RMWriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/RMWriteBufferManager.java
new file mode 100644
index 0000000000..7a3c874097
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/RMWriteBufferManager.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Output;
+import com.esotericsoftware.kryo.io.UnsafeOutput;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.Serializer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.serializer.kryo.KryoSerializerInstance;
+import org.apache.uniffle.common.util.ChecksumUtils;
+
+public class RMWriteBufferManager extends WriteBufferManager {
+
+ private static final Logger LOG = LoggerFactory.getLogger(RMWriteBufferManager.class);
+
+ private Comparator comparator;
+ private KryoSerializerInstance instance;
+ private Kryo serKryo;
+ private Output serOutput;
+
+ public RMWriteBufferManager(
+ int shuffleId,
+ String taskId,
+ long taskAttemptId,
+ BufferManagerOptions bufferManagerOptions,
+ Serializer serializer,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleWriteMetrics shuffleWriteMetrics,
+ RssConf rssConf,
+ Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc,
+ Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc,
+ int stageAttemptNumber,
+ Comparator comparator) {
+ super(
+ shuffleId,
+ taskId,
+ taskAttemptId,
+ bufferManagerOptions,
+ serializer,
+ taskMemoryManager,
+ shuffleWriteMetrics,
+ rssConf,
+ spillFunc,
+ partitionAssignmentRetrieveFunc,
+ stageAttemptNumber);
+ this.comparator = comparator;
+ this.instance = new KryoSerializerInstance(this.rssConf);
+ this.serKryo = this.instance.borrowKryo();
+ this.serOutput = new UnsafeOutput(this.arrayOutputStream);
+ }
+
+ @Override
+ public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
+ try {
+ final long start = System.currentTimeMillis();
+ arrayOutputStream.reset();
+ this.serKryo.writeClassAndObject(this.serOutput, key);
+ this.serOutput.flush();
+ final int keyLength = arrayOutputStream.size();
+ this.serKryo.writeClassAndObject(this.serOutput, value);
+ this.serOutput.flush();
+ serializeTime += System.currentTimeMillis() - start;
+ byte[] serializedData = arrayOutputStream.getBuf();
+ int serializedDataLength = arrayOutputStream.size();
+ if (serializedDataLength == 0) {
+ return null;
+ }
+ List shuffleBlockInfos =
+ addPartitionData(
+ partitionId, serializedData, keyLength, serializedDataLength - keyLength, start);
+ // Record counting here is row-based; in columnar shuffle the record count should be taken
+ // from the ColumnarBatch, which is handled by the rss shuffle writer implementation.
+ if (isRowBased) {
+ shuffleWriteMetrics.incRecordsWritten(1L);
+ }
+ return shuffleBlockInfos;
+ } catch (Exception e) {
+ throw new RssException(e);
+ }
+ }
+
+ private List<ShuffleBlockInfo> addPartitionData(
+ int partitionId, byte[] serializedData, int keyLength, int valueLength, long start) {
+ List singleOrEmptySendingBlocks =
+ insertIntoBuffer(partitionId, serializedData, keyLength, valueLength);
+
+ // check buffer size > spill threshold
+ if (usedBytes.get() - inSendListBytes.get() > spillSize) {
+ LOG.info(
+ "ShuffleBufferManager spill for buffer size exceeding spill threshold, "
+ + "usedBytes[{}], inSendListBytes[{}], spill size threshold[{}]",
+ usedBytes.get(),
+ inSendListBytes.get(),
+ spillSize);
+ List multiSendingBlocks = clear(bufferSpillRatio);
+ multiSendingBlocks.addAll(singleOrEmptySendingBlocks);
+ writeTime += System.currentTimeMillis() - start;
+ return multiSendingBlocks;
+ }
+ writeTime += System.currentTimeMillis() - start;
+ return singleOrEmptySendingBlocks;
+ }
+
+ private List<ShuffleBlockInfo> insertIntoBuffer(
+ int partitionId, byte[] serializedData, int keyLength, int valueLength) {
+ int recordLength = keyLength + valueLength;
+ long required = Math.max(bufferSegmentSize, recordLength);
+ // Asking the task memory manager for memory for the existing writer buffer
+ // may trigger this WriteBufferManager's spill method, which discards the
+ // current write buffer. So we have to recheck the buffer existence afterwards.
+ boolean hasRequested = false;
+ WriterBuffer wb = buffers.get(partitionId);
+ if (wb != null) {
+ if (wb.askForMemory(recordLength)) {
+ requestMemory(required);
+ hasRequested = true;
+ }
+ }
+
+ // If hasRequested is false, the spill method was not triggered,
+ // so we don't have to recheck the buffer existence in this case.
+ if (hasRequested) {
+ wb = buffers.get(partitionId);
+ }
+
+ if (wb != null) {
+ if (hasRequested) {
+ usedBytes.addAndGet(required);
+ }
+ wb.addRecord(serializedData, keyLength, valueLength);
+ } else {
+ // If hasRequested is true, the former partition buffer has already been flushed by the
+ // spill operation triggered while asking for memory, so there is no need to re-request
+ // the memory.
+ if (!hasRequested) {
+ requestMemory(required);
+ }
+ usedBytes.addAndGet(required);
+ wb = new WriterBuffer(bufferSegmentSize);
+ wb.addRecord(serializedData, keyLength, valueLength);
+ buffers.put(partitionId, wb);
+ }
+
+ if (wb.getMemoryUsed() > bufferSize) {
+ List sentBlocks = new ArrayList<>(1);
+ sentBlocks.add(createShuffleBlock(partitionId, wb));
+ recordCounter.addAndGet(wb.getRecordCount());
+ copyTime += wb.getCopyTime();
+ sortRecordTime += wb.getSortTime();
+ buffers.remove(partitionId);
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(
+ "Single buffer is full for shuffleId["
+ + shuffleId
+ + "] partition["
+ + partitionId
+ + "] with memoryUsed["
+ + wb.getMemoryUsed()
+ + "], dataLength["
+ + wb.getDataLength()
+ + "]");
+ }
+ return sentBlocks;
+ }
+ return Collections.emptyList();
+ }
+
+ // transform the buffered records into a ShuffleBlockInfo
+ @Override
+ protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer wb) {
+ byte[] data = wb.getData(instance, comparator);
+ final int length = data.length;
+ final long crc32 = ChecksumUtils.getCrc32(data);
+ final long blockId =
+ blockIdLayout.getBlockId(getNextSeqNo(partitionId), partitionId, taskAttemptId);
+ blockCounter.incrementAndGet();
+ uncompressedDataLen += data.length;
+ shuffleWriteMetrics.incBytesWritten(data.length);
+ // add memory to indicate bytes which will be sent to shuffle server
+ inSendListBytes.addAndGet(wb.getMemoryUsed());
+ return new ShuffleBlockInfo(
+ shuffleId,
+ partitionId,
+ blockId,
+ length,
+ crc32,
+ data,
+ partitionAssignmentRetrieveFunc.apply(partitionId),
+ length,
+ wb.getMemoryUsed(),
+ taskAttemptId);
+ }
+
+ @Override
+ public void freeAllMemory() {
+ super.freeAllMemory();
+ if (this.instance != null && this.serKryo != null) {
+ this.instance.releaseKryo(this.serKryo);
+ }
+ if (this.serOutput != null) {
+ this.serOutput.close();
+ }
+ }
+}
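From a caller's point of view the only contract change relative to WriteBufferManager is the extra Comparator in the constructor and the key/value length split handed down to WriterBuffer; spill thresholds and block events behave as before. A hedged usage sketch, with the two callbacks named after the ones RssShuffleWriter passes later in this patch (testWriteRemoteMerge in WriteBufferManagerTest is the runnable version of the same flow):

    // assumes shuffleId, taskId, taskAttemptId, bufferOptions, serializer, taskMemoryManager,
    // writeMetrics, rssConf, stageAttemptNumber and keyComparator are in scope
    WriteBufferManager manager =
        new RMWriteBufferManager(
            shuffleId,
            taskId,
            taskAttemptId,
            bufferOptions,
            serializer,
            taskMemoryManager,
            writeMetrics,
            rssConf,
            this::processShuffleBlockInfos,    // spillFunc used by RssShuffleWriter
            this::getPartitionAssignedServers, // partition -> assigned shuffle servers
            stageAttemptNumber,
            keyComparator);                    // null skips key sorting inside each block

    // Records may arrive in any order; blocks returned by addRecord()/clear() are key-sorted.
    List<ShuffleBlockInfo> sent = manager.addRecord(partitionId, key, value); // usually empty
    List<ShuffleBlockInfo> rest = manager.clear();                            // flush the remainder
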
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index cf4b4bc51c..9143620dac 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -58,40 +58,42 @@
public class WriteBufferManager extends MemoryConsumer {
private static final Logger LOG = LoggerFactory.getLogger(WriteBufferManager.class);
- private int bufferSize;
- private long spillSize;
+ protected int bufferSize;
+ protected long spillSize;
// allocated bytes from executor memory
private AtomicLong allocatedBytes = new AtomicLong(0);
// bytes of shuffle data in memory
- private AtomicLong usedBytes = new AtomicLong(0);
+ protected final AtomicLong usedBytes = new AtomicLong(0);
// bytes of shuffle data which is in send list
- private AtomicLong inSendListBytes = new AtomicLong(0);
+ protected final AtomicLong inSendListBytes = new AtomicLong(0);
/** An atomic counter used to keep track of the number of records */
- private AtomicLong recordCounter = new AtomicLong(0);
+ protected final AtomicLong recordCounter = new AtomicLong(0);
/** An atomic counter used to keep track of the number of blocks */
- private AtomicLong blockCounter = new AtomicLong(0);
+ protected AtomicLong blockCounter = new AtomicLong(0);
// it's part of blockId
private Map partitionToSeqNo = Maps.newHashMap();
private long askExecutorMemory;
- private int shuffleId;
+ protected final int shuffleId;
private String taskId;
- private long taskAttemptId;
+ protected final long taskAttemptId;
private SerializerInstance instance;
- private ShuffleWriteMetrics shuffleWriteMetrics;
+ protected ShuffleWriteMetrics shuffleWriteMetrics;
+ protected final RssConf rssConf;
// cache partition -> records
- private Map buffers;
+ protected final Map buffers;
private int serializerBufferSize;
- private int bufferSegmentSize;
- private long copyTime = 0;
- private long serializeTime = 0;
+ protected final int bufferSegmentSize;
+ protected long copyTime = 0;
+ protected long sortRecordTime = 0;
+ protected long serializeTime = 0;
private long compressTime = 0;
private long sortTime = 0;
- private long writeTime = 0;
+ protected long writeTime = 0;
private long estimateTime = 0;
private long requireMemoryTime = 0;
private SerializationStream serializeStream;
- private WrappedByteArrayOutputStream arrayOutputStream;
- private long uncompressedDataLen = 0;
+ protected final WrappedByteArrayOutputStream arrayOutputStream;
+ protected long uncompressedDataLen = 0;
private long compressedDataLen = 0;
private long requireMemoryInterval;
private int requireMemoryRetryMax;
@@ -100,10 +102,10 @@ public class WriteBufferManager extends MemoryConsumer {
private long sendSizeLimit;
private boolean memorySpillEnabled;
private int memorySpillTimeoutSec;
- private boolean isRowBased;
- private BlockIdLayout blockIdLayout;
- private double bufferSpillRatio;
- private Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc;
+ protected boolean isRowBased;
+ protected BlockIdLayout blockIdLayout;
+ protected double bufferSpillRatio;
+ protected Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc;
private int stageAttemptNumber;
public WriteBufferManager(
@@ -174,6 +176,7 @@ public WriteBufferManager(
this.taskId = taskId;
this.taskAttemptId = taskAttemptId;
this.shuffleWriteMetrics = shuffleWriteMetrics;
+ this.rssConf = rssConf;
this.serializerBufferSize = bufferManagerOptions.getSerializerBufferSize();
this.bufferSegmentSize = bufferManagerOptions.getBufferSegmentSize();
this.askExecutorMemory = bufferManagerOptions.getPreAllocatedBufferSize();
@@ -310,6 +313,7 @@ private List<ShuffleBlockInfo> insertIntoBuffer(
sentBlocks.add(createShuffleBlock(partitionId, wb));
recordCounter.addAndGet(wb.getRecordCount());
copyTime += wb.getCopyTime();
+ sortRecordTime += wb.getSortTime();
buffers.remove(partitionId);
if (LOG.isDebugEnabled()) {
LOG.debug(
@@ -449,13 +453,13 @@ protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer wb)
}
// it's run in single thread, and is not thread safe
- private int getNextSeqNo(int partitionId) {
+ protected int getNextSeqNo(int partitionId) {
return partitionToSeqNo
.computeIfAbsent(partitionId, k -> new AtomicInteger(0))
.getAndIncrement();
}
- private void requestMemory(long requiredMem) {
+ protected void requestMemory(long requiredMem) {
final long start = System.currentTimeMillis();
if (allocatedBytes.get() - usedBytes.get() < requiredMem) {
requestExecutorMemory(requiredMem);
@@ -644,6 +648,8 @@ public long getWriteTime() {
public String getManagerCostInfo() {
return "WriteBufferManager cost copyTime["
+ copyTime
+ + "], sortRecordTime["
+ + sortRecordTime
+ "], writeTime["
+ writeTime
+ "], serializeTime["
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
index ac6ac9e271..2fe21b98f9 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
@@ -17,23 +17,34 @@
package org.apache.spark.shuffle.writer;
+import java.util.ArrayList;
+import java.util.Comparator;
import java.util.List;
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.UnsafeInput;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.serializer.kryo.KryoSerializerInstance;
+
public class WriterBuffer {
private static final Logger LOG = LoggerFactory.getLogger(WriterBuffer.class);
- private long copyTime = 0;
- private byte[] buffer;
- private int bufferSize;
- private int nextOffset = 0;
- private List buffers = Lists.newArrayList();
- private int dataLength = 0;
- private int memoryUsed = 0;
- private long recordCount = 0;
+ protected long copyTime = 0;
+ protected byte[] buffer;
+ protected int bufferSize;
+ protected int nextOffset = 0;
+ protected List buffers = Lists.newArrayList();
+ protected int dataLength = 0;
+ protected int memoryUsed = 0;
+ protected long recordCount = 0;
+
+ private long sortTime = 0;
+ private final List<Record> records = new ArrayList<>();
public WriterBuffer(int bufferSize) {
this.bufferSize = bufferSize;
@@ -70,6 +81,13 @@ public void addRecord(byte[] recordBuffer, int length) {
recordCount++;
}
+ public void addRecord(byte[] recordBuffer, int keyLength, int valueLength) {
+ this.addRecord(recordBuffer, keyLength + valueLength);
+ this.records.add(
+ new Record(
+ this.buffers.size(), nextOffset - keyLength - valueLength, keyLength, valueLength));
+ }
+
public boolean askForMemory(long length) {
return buffer == null || nextOffset + length > bufferSize;
}
@@ -88,6 +106,51 @@ public byte[] getData() {
return data;
}
+ public byte[] getData(KryoSerializerInstance instance, Comparator comparator) {
+ if (comparator != null) {
+ // deserialize keys for sorting
+ long start = System.currentTimeMillis();
+ Kryo derKryo = null;
+ try {
+ derKryo = instance.borrowKryo();
+ Input input = new UnsafeInput();
+ for (Record record : records) {
+ byte[] bytes =
+ record.index == this.buffers.size() ? buffer : this.buffers.get(record.index).buffer;
+ input.setBuffer(bytes, record.offset, record.keyLength);
+ record.key = derKryo.readClassAndObject(input);
+ }
+ } catch (Throwable e) {
+ throw new RssException(e);
+ } finally {
+ instance.releaseKryo(derKryo);
+ }
+
+ // sort by key
+ this.records.sort(
+ new Comparator<Record>() {
+ @Override
+ public int compare(Record r1, Record r2) {
+ return comparator.compare(r1.key, r2.key);
+ }
+ });
+ sortTime += System.currentTimeMillis() - start;
+ }
+
+ // write
+ long start = System.currentTimeMillis();
+ byte[] data = new byte[dataLength];
+ int offset = 0;
+ for (Record record : records) {
+ byte[] bytes =
+ record.index == buffers.size() ? buffer : buffers.get(record.index).getBuffer();
+ System.arraycopy(bytes, record.offset, data, offset, record.keyLength + record.valueLength);
+ offset += record.keyLength + record.valueLength;
+ }
+ copyTime += System.currentTimeMillis() - start;
+ return data;
+ }
+
public int getDataLength() {
return dataLength;
}
@@ -104,6 +167,10 @@ public long getRecordCount() {
return recordCount;
}
+ public long getSortTime() {
+ return sortTime;
+ }
+
private static final class WrappedBuffer {
byte[] buffer;
@@ -122,4 +189,19 @@ public int getSize() {
return size;
}
}
+
+ private static final class Record {
+ private final int index;
+ private final int offset;
+ private final int keyLength;
+ private final int valueLength;
+ private Object key = null;
+
+ Record(int keyIndex, int offset, int keyLength, int valueLength) {
+ this.index = keyIndex;
+ this.offset = offset;
+ this.keyLength = keyLength;
+ this.valueLength = valueLength;
+ }
+ }
}
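A subtle point in getData(KryoSerializerInstance, Comparator) above: each Record only stores which backing array it lives in plus the byte ranges of its key and value, so sorting deserializes keys alone and never touches the value bytes. The index field uses buffers.size() as a sentinel for "still in the active buffer", which the copy loop resolves as restated below (illustration only, not part of the patch):

    // index == buffers.size() -> the record sits in 'buffer', the array still being filled
    // index <  buffers.size() -> the record sits in an already-sealed WrappedBuffer
    byte[] backing =
        record.index == buffers.size()
            ? buffer
            : buffers.get(record.index).getBuffer();
    // key bytes:   [offset, offset + keyLength)
    // value bytes: [offset + keyLength, offset + keyLength + valueLength)
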
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index d869c64fe6..d7e95ae2e1 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -17,11 +17,14 @@
package org.apache.uniffle.shuffle.manager;
+import java.io.ByteArrayOutputStream;
import java.io.IOException;
+import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@@ -39,6 +42,8 @@
import java.util.stream.Collectors;
import scala.Tuple2;
+import scala.math.Ordering;
+import scala.runtime.AbstractFunction1;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
@@ -49,6 +54,7 @@
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.MapOutputTrackerMaster;
+import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkException;
@@ -100,6 +106,7 @@
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.common.util.ThreadUtils;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
import org.apache.uniffle.shuffle.BlockIdManager;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
@@ -169,6 +176,9 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac
protected GrpcServer shuffleManagerServer;
protected DataPusher dataPusher;
+ protected boolean remoteMergeEnable;
+ protected Map<Integer, ShuffleDependency> dependencies = new HashMap<>();
+
public RssShuffleManagerBase(SparkConf conf, boolean isDriver) {
LOG.info(
"Uniffle {} version: {}", this.getClass().getName(), Constants.VERSION_AND_REVISION_SHORT);
@@ -327,6 +337,7 @@ public RssShuffleManagerBase(SparkConf conf, boolean isDriver) {
rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
this.rssStageResubmitManager = new RssStageResubmitManager();
+ this.remoteMergeEnable = sparkConf.get(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE);
}
@VisibleForTesting
@@ -965,7 +976,8 @@ public boolean reassignOnStageResubmit(
rssStageResubmitManager.getServerIdBlackList(),
stageAttemptId,
stageAttemptNumber,
- false);
+ false,
+ buildMergeContext(dependencies.get(shuffleId)));
MutableShuffleHandleInfo shuffleHandleInfo =
new MutableShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo());
StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo =
@@ -1082,7 +1094,11 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure(
"Register the new partition->servers assignment on reassign. {}",
newServerToPartitions);
registerShuffleServers(
- getAppId(), shuffleId, newServerToPartitions, getRemoteStorageInfo());
+ getAppId(),
+ shuffleId,
+ newServerToPartitions,
+ getRemoteStorageInfo(),
+ buildMergeContext(dependencies.get(shuffleId)));
}
LOG.info(
@@ -1248,7 +1264,8 @@ private Set reassignServerForTask(
},
stageId,
stageAttemptNumber,
- reassign);
+ reassign,
+ buildMergeContext(dependencies.get(shuffleId)));
return replacementsRef.get();
}
@@ -1262,7 +1279,8 @@ private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
Function reassignmentHandler,
int stageId,
int stageAttemptNumber,
- boolean reassign) {
+ boolean reassign,
+ MergeContext mergeContext) {
Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
ClientUtils.validateClientType(clientType);
assignmentTags.add(clientType);
@@ -1290,7 +1308,11 @@ private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
response = reassignmentHandler.apply(response);
}
registerShuffleServers(
- getAppId(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo());
+ getAppId(),
+ shuffleId,
+ response.getServerToPartitionRanges(),
+ getRemoteStorageInfo(),
+ mergeContext);
return response.getPartitionToServers();
} catch (Throwable throwable) {
throw new RssException("registerShuffle failed!", throwable);
@@ -1306,7 +1328,8 @@ protected Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
Set faultyServerIds,
int stageId,
int stageAttemptNumber,
- boolean reassign) {
+ boolean reassign,
+ MergeContext mergeContext) {
Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
ClientUtils.validateClientType(clientType);
assignmentTags.add(clientType);
@@ -1339,7 +1362,8 @@ protected Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
shuffleId,
response.getServerToPartitionRanges(),
getRemoteStorageInfo(),
- stageAttemptNumber);
+ stageAttemptNumber,
+ mergeContext);
return response.getPartitionToServers();
},
retryInterval,
@@ -1349,32 +1373,13 @@ protected Map> requestShuffleAssignment(
}
}
- protected Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
- int shuffleId,
- int partitionNum,
- int partitionNumPerRange,
- int assignmentShuffleServerNumber,
- int estimateTaskConcurrency,
- Set faultyServerIds,
- int stageAttemptNumber) {
- return requestShuffleAssignment(
- shuffleId,
- partitionNum,
- partitionNumPerRange,
- assignmentShuffleServerNumber,
- estimateTaskConcurrency,
- faultyServerIds,
- -1,
- stageAttemptNumber,
- false);
- }
-
protected void registerShuffleServers(
String appId,
int shuffleId,
Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
RemoteStorageInfo remoteStorage,
- int stageAttemptNumber) {
+ int stageAttemptNumber,
+ MergeContext mergeContext) {
if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
return;
}
@@ -1393,7 +1398,7 @@ protected void registerShuffleServers(
ShuffleDataDistributionType.NORMAL,
maxConcurrencyPerPartitionToWrite,
stageAttemptNumber,
- null,
+ mergeContext,
sparkConfMap);
});
LOG.info(
@@ -1405,7 +1410,8 @@ protected void registerShuffleServers(
String appId,
int shuffleId,
Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
- RemoteStorageInfo remoteStorage) {
+ RemoteStorageInfo remoteStorage,
+ MergeContext mergeContext) {
if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
return;
}
@@ -1425,6 +1431,8 @@ protected void registerShuffleServers(
remoteStorage,
dataDistributionType,
maxConcurrencyPerPartitionToWrite,
+ 0,
+ mergeContext,
sparkConfMap);
});
LOG.info(
@@ -1538,6 +1546,48 @@ public boolean isValidTask(String taskId) {
return !failedTaskIds.contains(taskId);
}
+ protected String encode(Ordering obj) {
+ try {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ ObjectOutputStream oos = new ObjectOutputStream(baos);
+ oos.writeObject(obj);
+ oos.close();
+ return "#" + Base64.getEncoder().encodeToString(baos.toByteArray());
+ } catch (Exception e) {
+ throw new RssException(e);
+ }
+ }
+
+ protected <K> MergeContext buildMergeContext(ShuffleDependency<K, ?, ?> dependency) {
+ if (remoteMergeEnable) {
+ MergeContext mergeContext =
+ MergeContext.newBuilder()
+ .setKeyClass(dependency.keyClassName())
+ .setValueClass(
+ dependency.mapSideCombine()
+ ? dependency.combinerClassName().get()
+ : dependency.valueClassName())
+ .setComparatorClass(
+ dependency
+ .keyOrdering()
+ .map(
+ new AbstractFunction1<Ordering<K>, String>() {
+ @Override
+ public String apply(Ordering<K> o) {
+ return encode(o);
+ }
+ })
+ .get())
+ .setMergedBlockSize(sparkConf.get(RssSparkConfig.RSS_MERGED_BLOCK_SZIE))
+ .setMergeClassLoader(
+ sparkConf.get(RssSparkConfig.RSS_REMOTE_MERGE_CLASS_LOADER.key(), ""))
+ .build();
+ return mergeContext;
+ } else {
+ return null;
+ }
+ }
+
@VisibleForTesting
public void setDataPusher(DataPusher dataPusher) {
this.dataPusher = dataPusher;
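buildMergeContext() ships the key Ordering to the merge side by Java-serializing it and Base64-encoding the bytes behind a leading "#", presumably so the comparatorClass field can distinguish an inline serialized instance from a plain class name. A hedged sketch of the matching decode step (server-side decoding is outside this diff; this only illustrates the wire format produced by encode() above):

    import java.io.ByteArrayInputStream;
    import java.io.ObjectInputStream;
    import java.util.Base64;

    import scala.math.Ordering;

    import org.apache.uniffle.common.exception.RssException;

    // inverse of encode(): strip the "#" marker, Base64-decode, then Java-deserialize
    static Ordering decodeOrdering(String encoded) {
      try (ObjectInputStream ois =
          new ObjectInputStream(
              new ByteArrayInputStream(Base64.getDecoder().decode(encoded.substring(1))))) {
        return (Ordering) ois.readObject();
      } catch (Exception e) {
        throw new RssException(e);
      }
    }
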
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/SparkCombinerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/SparkCombinerTest.java
new file mode 100644
index 0000000000..24f739ce48
--- /dev/null
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/SparkCombinerTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle;
+
+import java.util.Iterator;
+
+import scala.runtime.AbstractFunction1;
+import scala.runtime.AbstractFunction2;
+
+import org.apache.spark.Aggregator;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.client.record.Record;
+import org.apache.uniffle.client.record.RecordBlob;
+import org.apache.uniffle.client.record.RecordBuffer;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class SparkCombinerTest {
+
+ @Test
+ public void testSparkCombiner() {
+ RecordBuffer recordBuffer = new RecordBuffer(-1);
+ for (int i = 0; i < 5; i++) {
+ for (int j = 0; j <= i; j++) {
+ String key = "key" + i;
+ recordBuffer.addRecord(key, j);
+ }
+ }
+ SparkCombiner<String, Integer, String> combiner =
+ new SparkCombiner<>(
+ new Aggregator<>(
+ new AbstractFunction1<Integer, String>() {
+ @Override
+ public String apply(Integer v1) {
+ return v1.toString();
+ }
+ },
+ new AbstractFunction2<String, Integer, String>() {
+ @Override
+ public String apply(String c, Integer v) {
+ return Integer.valueOf(Integer.parseInt(c) + v).toString();
+ }
+ },
+ new AbstractFunction2<String, String, String>() {
+ @Override
+ public String apply(String c1, String c2) {
+ return Integer.valueOf(Integer.parseInt(c1) + Integer.parseInt(c2)).toString();
+ }
+ }));
+ RecordBlob recordBlob = new RecordBlob(-1);
+ recordBlob.addRecords(recordBuffer);
+ recordBlob.combine(combiner, false);
+ Iterator<Record> newRecords = recordBlob.getResult().iterator();
+ int index = 0;
+ while (newRecords.hasNext()) {
+ Record record = newRecords.next();
+ int expectedValue = 0;
+ for (int i = 0; i <= index; i++) {
+ expectedValue += i;
+ }
+ assertEquals("Record{key=key" + index + ", value=" + expectedValue + "}", record.toString());
+ index++;
+ }
+ assertEquals(5, index);
+ }
+}
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 97639d3c44..8056278ac8 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -20,7 +20,10 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
@@ -28,6 +31,8 @@
import java.util.stream.Stream;
import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.executor.ShuffleWriteMetrics;
@@ -47,13 +52,19 @@
import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.records.RecordsReader;
+import org.apache.uniffle.common.serializer.SerInputStream;
+import org.apache.uniffle.common.serializer.SerializerUtils;
import org.apache.uniffle.common.util.BlockIdLayout;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
@@ -62,6 +73,8 @@
public class WriteBufferManagerTest {
+ private static final int RECORD_NUM = 1009;
+
private WriteBufferManager createManager(SparkConf conf) {
Serializer kryoSerializer = new KryoSerializer(conf);
TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
@@ -622,4 +635,68 @@ public void addFirstRecordWithLargeSizeTest() {
List shuffleBlockInfos2 = wbm.addRecord(1, testKey, testValue2);
assertEquals(0, shuffleBlockInfos2.size());
}
+
+ @Test
+ public void testWriteRemoteMerge() throws Exception {
+ final SparkConf conf = new SparkConf();
+ final RssConf rssConf = RssSparkConfig.toRssConf(conf);
+ BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+ TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+ doReturn(16L * 1024L * 1024L)
+ .when(mockTaskMemoryManager)
+ .acquireExecutionMemory(anyLong(), any());
+
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+ WriteBufferManager manager =
+ new RMWriteBufferManager(
+ 0,
+ "",
+ 0,
+ bufferOptions,
+ new KryoSerializer(conf),
+ mockTaskMemoryManager,
+ new ShuffleWriteMetrics(),
+ RssSparkConfig.toRssConf(conf),
+ null,
+ pid -> partitionToServers.get(pid),
+ 0,
+ SerializerUtils.getComparator(String.class));
+ List<ShuffleBlockInfo> shuffleBlockInfos;
+ List<Integer> indexes = new ArrayList<>();
+ for (int i = 0; i < RECORD_NUM; i++) {
+ indexes.add(i);
+ }
+ Collections.shuffle(indexes);
+ for (Integer index : indexes) {
+ shuffleBlockInfos =
+ manager.addRecord(
+ 0,
+ SerializerUtils.genData(String.class, index),
+ SerializerUtils.genData(Integer.class, index));
+ assertTrue(CollectionUtils.isEmpty(shuffleBlockInfos));
+ }
+
+ shuffleBlockInfos = manager.clear();
+ assertFalse(CollectionUtils.isEmpty(shuffleBlockInfos));
+
+ // check blocks
+ List<AddBlockEvent> events = manager.buildBlockEvents(shuffleBlockInfos);
+ assertEquals(1, events.size());
+ List<ShuffleBlockInfo> blocks = events.get(0).getShuffleDataInfoList();
+ assertEquals(1, blocks.size());
+
+ ByteBuf buf = blocks.get(0).getData();
+ RecordsReader<String, Integer> reader =
+ new RecordsReader<>(
+ rssConf, SerInputStream.newInputStream(buf), String.class, Integer.class, false, false);
+ reader.init();
+ int index = 0;
+ while (reader.next()) {
+ assertEquals(SerializerUtils.genData(String.class, index), reader.getCurrentKey());
+ assertEquals(SerializerUtils.genData(Integer.class, index), reader.getCurrentValue());
+ index++;
+ }
+ reader.close();
+ assertEquals(RECORD_NUM, index);
+ }
}
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java
index d4533efaf6..9fd797fc1e 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java
@@ -17,18 +17,38 @@
package org.apache.spark.shuffle.writer;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
import scala.reflect.ClassTag$;
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Output;
+import com.esotericsoftware.kryo.io.UnsafeOutput;
+import io.netty.buffer.Unpooled;
import org.apache.spark.SparkConf;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.Serializer;
import org.junit.jupiter.api.Test;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.serializer.DeserializationStream;
+import org.apache.uniffle.common.serializer.SerInputStream;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+import org.apache.uniffle.common.serializer.kryo.KryoSerializerInstance;
+
+import static org.apache.uniffle.common.serializer.SerializerUtils.genData;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
public class WriteBufferTest {
+ private static final int RECORDS_NUM = 200;
+
private SparkConf conf = new SparkConf(false);
private Serializer kryoSerializer = new KryoSerializer(conf);
private WrappedByteArrayOutputStream arrayOutputStream = new WrappedByteArrayOutputStream(32);
@@ -37,6 +57,10 @@ public class WriteBufferTest {
private byte[] serializedData;
private int serializedDataLength;
+ private int keyLength;
+ private int valueLength;
+ private byte[] sortSerializedData;
+
@Test
public void test() {
WriterBuffer wb = new WriterBuffer(32);
@@ -83,6 +107,50 @@ public void test() {
assertEquals(91, wb.getData().length);
}
+ @Test
+ public void testSortRecords() throws IOException {
+ arrayOutputStream.reset();
+ KryoSerializerInstance instance = new KryoSerializerInstance(new RssConf());
+ Kryo serKryo = instance.borrowKryo();
+ Output serOutput = new UnsafeOutput(this.arrayOutputStream);
+
+ WriterBuffer wb = new WriterBuffer(1024);
+ assertEquals(0, wb.getMemoryUsed());
+ assertEquals(0, wb.getDataLength());
+
+ List<Integer> arrays = new ArrayList<>();
+ for (int i = 0; i < RECORDS_NUM; i++) {
+ arrays.add(i);
+ }
+ Collections.shuffle(arrays);
+ for (int i : arrays) {
+ String key = (String) SerializerUtils.genData(String.class, i);
+ int value = (int) SerializerUtils.genData(int.class, i);
+ serializeData(key, value, serKryo, serOutput);
+ wb.addRecord(sortSerializedData, keyLength, valueLength);
+ }
+ assertEquals(RECORDS_NUM, wb.getRecordCount());
+ assertEquals(15 * RECORDS_NUM, wb.getDataLength());
+
+ byte[] data = wb.getData(instance, SerializerUtils.getComparator(String.class));
+ assertEquals(15 * RECORDS_NUM, data.length);
+ // deserialized
+ DeserializationStream dStream =
+ instance.deserializeStream(
+ SerInputStream.newInputStream(Unpooled.wrappedBuffer(data)),
+ String.class,
+ int.class,
+ false,
+ false);
+ dStream.init();
+ for (int i = 0; i < RECORDS_NUM; i++) {
+ assertTrue(dStream.nextRecord());
+ assertEquals(genData(String.class, i), dStream.getCurrentKey());
+ assertEquals(i, dStream.getCurrentValue());
+ }
+ assertFalse(dStream.nextRecord());
+ }
+
private void serializeData(Object key, Object value) {
arrayOutputStream.reset();
serializeStream.writeKey(key, ClassTag$.MODULE$.apply(key.getClass()));
@@ -91,4 +159,15 @@ private void serializeData(Object key, Object value) {
serializedData = arrayOutputStream.getBuf();
serializedDataLength = arrayOutputStream.size();
}
+
+ private void serializeData(Object key, Object value, Kryo serKryo, Output serOutput) {
+ arrayOutputStream.reset();
+ serKryo.writeClassAndObject(serOutput, key);
+ serOutput.flush();
+ keyLength = arrayOutputStream.size();
+ serKryo.writeClassAndObject(serOutput, value);
+ serOutput.flush();
+ valueLength = arrayOutputStream.size() - keyLength;
+ sortSerializedData = arrayOutputStream.getBuf();
+ }
}
diff --git a/client-spark/common/src/test/resources/log4j2.xml b/client-spark/common/src/test/resources/log4j2.xml
new file mode 100644
index 0000000000..8db107a47f
--- /dev/null
+++ b/client-spark/common/src/test/resources/log4j2.xml
@@ -0,0 +1,29 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 8e6b8dfca3..a669dd74c0 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -154,7 +154,10 @@ public ShuffleHandle registerShuffle(
requiredShuffleServerNumber,
estimateTaskConcurrency,
rssStageResubmitManager.getServerIdBlackList(),
- 0);
+ -1,
+ 0,
+ false,
+ null);
startHeartbeat();
@@ -379,17 +382,6 @@ private Roaring64NavigableMap getShuffleResult(
}
}
- private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) {
- Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
- faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
- Map<Integer, List<ShuffleServerInfo>> partitionToServers =
- requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds, 0);
- if (partitionToServers.get(0) != null && partitionToServers.get(0).size() == 1) {
- return partitionToServers.get(0).get(0);
- }
- return null;
- }
-
@Override
protected ShuffleWriteClient createShuffleWriteClient() {
int unregisterThreadPoolSize =
diff --git a/client-spark/spark3/pom.xml b/client-spark/spark3/pom.xml
index 94a124d0bc..f29b68598b 100644
--- a/client-spark/spark3/pom.xml
+++ b/client-spark/spark3/pom.xml
@@ -98,5 +98,17 @@
<version>${spark.version}</version>
<scope>provided</scope>
+ <dependency>
+ <groupId>org.apache.uniffle</groupId>
+ <artifactId>rss-common</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.uniffle</groupId>
+ <artifactId>rss-client</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 5e2a941029..b8cf0769bb 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -45,6 +45,7 @@
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
+import org.apache.spark.shuffle.reader.RMRssShuffleReader;
import org.apache.spark.shuffle.reader.RssShuffleReader;
import org.apache.spark.shuffle.writer.DataPusher;
import org.apache.spark.shuffle.writer.RssShuffleWriter;
@@ -150,6 +151,7 @@ public ShuffleHandle registerShuffle(
return new RssShuffleHandle<>(
shuffleId, id.get(), dependency.rdd().getNumPartitions(), dependency, hdlInfoBd);
}
+ dependencies.put(shuffleId, dependency);
String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
RemoteStorageInfo defaultRemoteStorage = getDefaultRemoteStorageInfo(sparkConf);
@@ -173,7 +175,10 @@ public ShuffleHandle registerShuffle(
requiredShuffleServerNumber,
estimateTaskConcurrency,
rssStageResubmitManager.getServerIdBlackList(),
- 0);
+ -1,
+ 0,
+ false,
+ buildMergeContext(dependency));
startHeartbeat();
shuffleIdToPartitionNum.putIfAbsent(shuffleId, dependency.partitioner().numPartitions());
shuffleIdToNumMapTasks.putIfAbsent(shuffleId, dependency.rdd().partitions().length);
@@ -396,6 +401,22 @@ public ShuffleReader getReaderImpl(
Configuration readerHadoopConf =
RssSparkShuffleUtils.getRemoteStorageHadoopConf(sparkConf, shuffleRemoteStorageInfo);
+ if (remoteMergeEnable) {
+ return new RMRssShuffleReader<>(
+ startPartition,
+ endPartition,
+ context,
+ rssShuffleHandle,
+ shuffleWriteClient,
+ shuffleHandleInfo.getAllPartitionServersForReader(),
+ RssUtils.generatePartitionToBitmap(
+ blockIdBitmap, startPartition, endPartition, blockIdLayout),
+ taskIdBitmap,
+ readMetrics,
+ managerClientSupplier,
+ RssSparkConfig.toRssConf(sparkConf),
+ clientType);
+ }
return new RssShuffleReader(
startPartition,
endPartition,
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleReader.java
new file mode 100644
index 0000000000..ae45b4e1a7
--- /dev/null
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleReader.java
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.reader;
+
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Supplier;
+
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.AbstractIterator;
+import scala.collection.Iterator;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.commons.lang3.ClassUtils;
+import org.apache.spark.Aggregator;
+import org.apache.spark.InterruptibleIterator;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleReadMetrics;
+import org.apache.spark.shuffle.RssShuffleHandle;
+import org.apache.spark.shuffle.ShuffleReader;
+import org.apache.spark.shuffle.SparkCombiner;
+import org.apache.spark.util.CompletionIterator;
+import org.apache.spark.util.CompletionIterator$;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.record.Record;
+import org.apache.uniffle.client.record.reader.RMRecordsReader;
+import org.apache.uniffle.client.record.writer.Combiner;
+import org.apache.uniffle.client.util.DefaultIdHelper;
+import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.BlockIdLayout;
+
+public class RMRssShuffleReader<K, C> implements ShuffleReader<K, C> {
+ private static final Logger LOG = LoggerFactory.getLogger(RMRssShuffleReader.class);
+ private final Map<Integer, List<ShuffleServerInfo>> partitionToShuffleServers;
+
+ private final String appId;
+ private final int shuffleId;
+ private final int startPartition;
+ private final int endPartition;
+ private final TaskContext context;
+ private final ShuffleDependency shuffleDependency;
+ private final int numMaps;
+ private final String taskId;
+ private final ShuffleWriteClient shuffleWriteClient;
+ private final Map<Integer, List<ShuffleServerInfo>> partitionToServerInfos;
+ private final Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks;
+ private final Roaring64NavigableMap taskIdBitmap;
+ private final ShuffleReadMetrics readMetrics;
+ private final Supplier<ShuffleManagerClient> managerClientSupplier;
+
+ private final RssConf rssConf;
+ private final String clientType;
+
+ private final Class keyClass;
+ private final Class valueClass;
+ private final boolean isMapCombine;
+ private Comparator comparator = null;
+ private Combiner combiner = null;
+ private DefaultIdHelper idHelper;
+
+ private Object metricsLock = new Object();
+
+ public RMRssShuffleReader(
+ int startPartition,
+ int endPartition,
+ TaskContext context,
+ RssShuffleHandle rssShuffleHandle,
+ ShuffleWriteClient shuffleWriteClient,
+ Map<Integer, List<ShuffleServerInfo>> partitionToServerInfos,
+ Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks,
+ Roaring64NavigableMap taskIdBitmap,
+ ShuffleReadMetrics readMetrics,
+ Supplier<ShuffleManagerClient> managerClientSupplier,
+ RssConf rssConf,
+ String clientType) {
+ this.appId = rssShuffleHandle.getAppId();
+ this.startPartition = startPartition;
+ this.endPartition = endPartition;
+ this.context = context;
+ this.numMaps = rssShuffleHandle.getNumMaps();
+ this.shuffleDependency = rssShuffleHandle.getDependency();
+ this.shuffleId = shuffleDependency.shuffleId();
+ this.taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
+ this.shuffleWriteClient = shuffleWriteClient;
+ this.partitionToServerInfos = partitionToServerInfos;
+ this.partitionToExpectBlocks = partitionToExpectBlocks;
+ this.taskIdBitmap = taskIdBitmap;
+ this.readMetrics = readMetrics;
+ this.managerClientSupplier = managerClientSupplier;
+ this.partitionToShuffleServers = rssShuffleHandle.getPartitionToServers();
+ this.rssConf = rssConf;
+ this.clientType = clientType;
+ this.idHelper = new DefaultIdHelper(BlockIdLayout.from(rssConf));
+ try {
+ this.keyClass = ClassUtils.getClass(rssShuffleHandle.getDependency().keyClassName());
+ this.isMapCombine = rssShuffleHandle.getDependency().mapSideCombine();
+ this.valueClass =
+ isMapCombine
+ ? ClassUtils.getClass(
+ rssShuffleHandle
+ .getDependency()
+ .combinerClassName()
+ .getOrElse(
+ () -> {
+ throw new RssException(
+ "Can not find combine class even though map combine is enabled!");
+ }))
+ : ClassUtils.getClass(rssShuffleHandle.getDependency().valueClassName());
+ comparator = rssShuffleHandle.getDependency().keyOrdering().getOrElse(() -> null);
+ Aggregator agg = rssShuffleHandle.getDependency().aggregator().getOrElse(() -> null);
+ if (agg != null) {
+ combiner = new SparkCombiner(agg);
+ }
+ } catch (ClassNotFoundException e) {
+ throw new RssException(e);
+ }
+ }
+
+ private void reportUniqueBlocks(Set<Integer> partitionIds) {
+ for (int partitionId : partitionIds) {
+ Roaring64NavigableMap blockIdBitmap = partitionToExpectBlocks.get(partitionId);
+ Roaring64NavigableMap uniqueBlockIdBitMap = Roaring64NavigableMap.bitmapOf();
+ blockIdBitmap.forEach(
+ blockId -> {
+ long taId = idHelper.getTaskAttemptId(blockId);
+ if (taskIdBitmap.contains(taId)) {
+ uniqueBlockIdBitMap.add(blockId);
+ }
+ });
+ shuffleWriteClient.startSortMerge(
+ new HashSet<>(partitionToServerInfos.get(partitionId)),
+ appId,
+ shuffleId,
+ partitionId,
+ uniqueBlockIdBitMap);
+ }
+ }
+
+ @Override
+ public Iterator<Product2<K, C>> read() {
+ LOG.info("Shuffle read started: " + getReadInfo());
+
+ Iterator<Product2<K, C>> resultIter = new MultiPartitionIterator();
+ resultIter = new InterruptibleIterator<Product2<K, C>>(context, resultIter);
+
+ // resubmit stage and shuffle manager server port are both set
+ if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
+ && rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0) {
+ resultIter =
+ RssFetchFailedIterator.newBuilder()
+ .appId(appId)
+ .shuffleId(shuffleId)
+ .partitionId(startPartition)
+ .stageAttemptId(context.stageAttemptNumber())
+ .managerClientSupplier(managerClientSupplier)
+ .build(resultIter);
+ }
+ return resultIter;
+ }
+
+ private String getReadInfo() {
+ return "appId="
+ + appId
+ + ", shuffleId="
+ + shuffleId
+ + ",taskId="
+ + taskId
+ + ", partitions: ["
+ + startPartition
+ + ", "
+ + endPartition
+ + ")";
+ }
+
+ class MultiPartitionIterator extends AbstractIterator<Product2<K, C>> {
+
+ CompletionIterator<Record<K, C>, RMRssShuffleDataIterator<K, C>> dataIterator;
+
+ MultiPartitionIterator() {
+ if (numMaps <= 0) {
+ return;
+ }
+ Set<Integer> partitionIds = new HashSet<>();
+ for (int partition = startPartition; partition < endPartition; partition++) {
+ if (partitionToExpectBlocks.get(partition).isEmpty()) {
+ LOG.info("{} partition is empty partition", partition);
+ continue;
+ }
+ partitionIds.add(partition);
+ }
+ if (partitionIds.size() == 0) {
+ return;
+ }
+ // report unique blockIds
+ reportUniqueBlocks(partitionIds);
+ RMRecordsReader reader = createRMRecordsReader(partitionIds, partitionToShuffleServers);
+ reader.start();
+ RMRssShuffleDataIterator<K, C> iter = new RMRssShuffleDataIterator<>(reader);
+ this.dataIterator =
+ CompletionIterator$.MODULE$.apply(
+ iter,
+ () -> {
+ context.taskMetrics().mergeShuffleReadMetrics();
+ return iter.cleanup();
+ });
+ context.addTaskCompletionListener(
+ (taskContext) -> {
+ if (dataIterator != null) {
+ dataIterator.completion();
+ }
+ });
+ }
+
+ @Override
+ public boolean hasNext() {
+ if (dataIterator == null) {
+ return false;
+ }
+ return dataIterator.hasNext();
+ }
+
+ @Override
+ public Product2<K, C> next() {
+ Record record = dataIterator.next();
+ Product2<K, C> result = new Tuple2(record.getKey(), record.getValue());
+ return result;
+ }
+ }
+
+ @VisibleForTesting
+ public RMRecordsReader createRMRecordsReader(
+ Set<Integer> partitionIds, Map<Integer, List<ShuffleServerInfo>> serverInfoMap) {
+ return new RMRecordsReader(
+ appId,
+ shuffleId,
+ partitionIds,
+ serverInfoMap,
+ rssConf,
+ keyClass,
+ valueClass,
+ comparator,
+ false,
+ combiner,
+ isMapCombine,
+ inc -> {
+ // ShuffleReadMetrics is not thread-safe; multiple fetcher threads may update it concurrently.
+ synchronized (metricsLock) {
+ readMetrics.incRecordsRead(inc);
+ }
+ },
+ clientType);
+ }
+}
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 09d88c1ca7..fd7dec6af7 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -88,6 +88,7 @@
import static org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_REMOTE_MERGE_ENABLE;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
public class RssShuffleWriter extends ShuffleWriter {
@@ -261,20 +262,38 @@ public RssShuffleWriter(
context);
this.taskAttemptAssignment = new TaskAttemptAssignment(taskAttemptId, shuffleHandleInfo);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
- final WriteBufferManager bufferManager =
- new WriteBufferManager(
- shuffleId,
- taskId,
- taskAttemptId,
- bufferOptions,
- rssHandle.getDependency().serializer(),
- context.taskMemoryManager(),
- shuffleWriteMetrics,
- RssSparkConfig.toRssConf(sparkConf),
- this::processShuffleBlockInfos,
- this::getPartitionAssignedServers,
- context.stageAttemptNumber());
- this.bufferManager = bufferManager;
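+ // When remote merge is enabled, the buffer manager also needs the dependency's key ordering,
+ // so use RMWriteBufferManager instead of the plain WriteBufferManager.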
+ if (sparkConf.get(RSS_REMOTE_MERGE_ENABLE)) {
+ final WriteBufferManager bufferManager =
+ new RMWriteBufferManager(
+ shuffleId,
+ taskId,
+ taskAttemptId,
+ bufferOptions,
+ rssHandle.getDependency().serializer(),
+ context.taskMemoryManager(),
+ shuffleWriteMetrics,
+ RssSparkConfig.toRssConf(sparkConf),
+ this::processShuffleBlockInfos,
+ this::getPartitionAssignedServers,
+ context.stageAttemptNumber(),
+ shuffleDependency.keyOrdering().getOrElse(() -> null));
+ this.bufferManager = bufferManager;
+ } else {
+ final WriteBufferManager bufferManager =
+ new WriteBufferManager(
+ shuffleId,
+ taskId,
+ taskAttemptId,
+ bufferOptions,
+ rssHandle.getDependency().serializer(),
+ context.taskMemoryManager(),
+ shuffleWriteMetrics,
+ RssSparkConfig.toRssConf(sparkConf),
+ this::processShuffleBlockInfos,
+ this::getPartitionAssignedServers,
+ context.stageAttemptNumber());
+ this.bufferManager = bufferManager;
+ }
}
@VisibleForTesting
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RMRssShuffleReaderTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RMRssShuffleReaderTest.java
new file mode 100644
index 0000000000..887eb3a5fd
--- /dev/null
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RMRssShuffleReaderTest.java
@@ -0,0 +1,536 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.reader;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.Iterator;
+import scala.math.Ordering;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import io.netty.buffer.ByteBuf;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleReadMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.shuffle.RssShuffleHandle;
+import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.client.api.ShuffleServerClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.MockedShuffleServerClient;
+import org.apache.uniffle.client.record.reader.RMRecordsReader;
+import org.apache.uniffle.client.record.writer.Combiner;
+import org.apache.uniffle.client.record.writer.SumByKeyCombiner;
+import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.merger.Merger;
+import org.apache.uniffle.common.merger.Segment;
+import org.apache.uniffle.common.serializer.DynBufferSerOutputStream;
+import org.apache.uniffle.common.serializer.SerOutputStream;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+import org.apache.uniffle.common.util.BlockIdLayout;
+
+import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyMap;
+import static org.mockito.ArgumentMatchers.anySet;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+public class RMRssShuffleReaderTest {
+
+ private static final int RECORDS_NUM = 1009;
+ private static final String APP_ID = "app1";
+ private static final int SHUFFLE_ID = 0;
+ private static final int PARTITION_ID = 0;
+
+ @Test
+ public void testReadShuffleWithoutCombine() throws Exception {
+ // 1 basic parameter
+ final Class keyClass = String.class;
+ final Class valueClass = Integer.class;
+ final Comparator comparator = Ordering.String$.MODULE$;
+ final List serverInfos =
+ Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+ final RssConf rssConf = new RssConf();
+ final int taskAttemptId = 0;
+ BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+ final long[] blockIds = new long[] {blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId)};
+ final Combiner combiner = null;
+
+ // 2 mock TaskContext, ShuffleDependency, RssShuffleHandle
+ // 2.1 mock TaskContext
+ TaskContext context = mock(TaskContext.class);
+ when(context.attemptNumber()).thenReturn(1);
+ when(context.taskAttemptId()).thenReturn(1L);
+ when(context.taskMetrics()).thenReturn(new TaskMetrics());
+ doNothing().when(context).killTaskIfInterrupted();
+ // 2.2 mock ShuffleDependency
+ ShuffleDependency dependency = mock(ShuffleDependency.class);
+ when(dependency.mapSideCombine()).thenReturn(false);
+ when(dependency.aggregator()).thenReturn(Option.empty());
+ when(dependency.keyClassName()).thenReturn(keyClass.getName());
+ when(dependency.valueClassName()).thenReturn(valueClass.getName());
+ when(dependency.keyOrdering()).thenReturn(Option.empty());
+ // 2.3 mock RssShuffleHandler
+ RssShuffleHandle handle = mock(RssShuffleHandle.class);
+ when(handle.getAppId()).thenReturn(APP_ID);
+ when(handle.getDependency()).thenReturn(dependency);
+ when(handle.getShuffleId()).thenReturn(SHUFFLE_ID);
+ when(handle.getNumMaps()).thenReturn(1);
+ when(handle.getPartitionToServers()).thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos));
+
+ // 3 construct remote merge records reader
+ ShuffleReadMetrics readMetrics = new ShuffleReadMetrics();
+ RMRecordsReader recordsReader =
+ new RMRecordsReader(
+ APP_ID,
+ SHUFFLE_ID,
+ Sets.newHashSet(PARTITION_ID),
+ ImmutableMap.of(PARTITION_ID, serverInfos),
+ rssConf,
+ keyClass,
+ valueClass,
+ comparator,
+ false,
+ combiner,
+ false,
+ inc -> readMetrics.incRecordsRead(inc));
+ final RMRecordsReader recordsReaderSpy = spy(recordsReader);
+ ByteBuf byteBuf = genSortedRecordBuffer(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1);
+ ByteBuf[][] buffers = new ByteBuf[][] {{byteBuf}};
+ ShuffleServerClient serverClient =
+ new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds);
+ doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+ // 4 construct spark shuffle reader
+ ShuffleWriteClient writeClient = mock(ShuffleWriteClient.class);
+ RMRssShuffleReader shuffleReader =
+ new RMRssShuffleReader(
+ PARTITION_ID,
+ PARTITION_ID + 1,
+ context,
+ handle,
+ writeClient,
+ ImmutableMap.of(PARTITION_ID, serverInfos),
+ ImmutableMap.of(PARTITION_ID, Roaring64NavigableMap.bitmapOf(blockIds)),
+ Roaring64NavigableMap.bitmapOf(taskAttemptId),
+ readMetrics,
+ null,
+ rssConf,
+ ClientType.GRPC.name());
+ RMRssShuffleReader shuffleReaderSpy = spy(shuffleReader);
+ doReturn(recordsReaderSpy).when(shuffleReaderSpy).createRMRecordsReader(anySet(), anyMap());
+
+ // 5 read and verify result
+ Iterator> iterator = shuffleReaderSpy.read();
+ int index = 0;
+ while (iterator.hasNext()) {
+ Product2 record = iterator.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record._1());
+ assertEquals(SerializerUtils.genData(valueClass, index), record._2());
+ index++;
+ }
+ assertEquals(RECORDS_NUM, index);
+ assertEquals(RECORDS_NUM, readMetrics._recordsRead().value());
+ byteBuf.release();
+ }
+
+ @Test
+ public void testReadShuffleWithCombine() throws Exception {
+ // 1 basic parameter
+ final Class keyClass = String.class;
+ final Class valueClass = Integer.class;
+ final Comparator comparator = Ordering.String$.MODULE$;
+ final List serverInfos =
+ Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+ final RssConf rssConf = new RssConf();
+ final int taskAttemptId = 0;
+ BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+ final long[] blockIds = new long[] {blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId)};
+ final Combiner combiner = new SumByKeyCombiner();
+
+ // 2 mock TaskContext, ShuffleDependency, RssShuffleHandle
+ // 2.1 mock TaskContext
+ TaskContext context = mock(TaskContext.class);
+ when(context.attemptNumber()).thenReturn(1);
+ when(context.taskAttemptId()).thenReturn(1L);
+ when(context.taskMetrics()).thenReturn(new TaskMetrics());
+ doNothing().when(context).killTaskIfInterrupted();
+ // 2.2 mock ShuffleDependency
+ ShuffleDependency dependency = mock(ShuffleDependency.class);
+ when(dependency.mapSideCombine()).thenReturn(false);
+ when(dependency.aggregator()).thenReturn(Option.empty());
+ when(dependency.keyClassName()).thenReturn(keyClass.getName());
+ when(dependency.valueClassName()).thenReturn(valueClass.getName());
+ when(dependency.keyOrdering()).thenReturn(Option.empty());
+ // 2.3 mock RssShuffleHandler
+ RssShuffleHandle handle = mock(RssShuffleHandle.class);
+ when(handle.getAppId()).thenReturn(APP_ID);
+ when(handle.getDependency()).thenReturn(dependency);
+ when(handle.getShuffleId()).thenReturn(SHUFFLE_ID);
+ when(handle.getNumMaps()).thenReturn(1);
+ when(handle.getPartitionToServers()).thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos));
+
+ // 3 construct remote merge records reader
+ ShuffleReadMetrics readMetrics = new ShuffleReadMetrics();
+ RMRecordsReader recordsReader =
+ new RMRecordsReader(
+ APP_ID,
+ SHUFFLE_ID,
+ Sets.newHashSet(PARTITION_ID),
+ ImmutableMap.of(PARTITION_ID, serverInfos),
+ rssConf,
+ keyClass,
+ valueClass,
+ comparator,
+ false,
+ combiner,
+ false,
+ inc -> readMetrics.incRecordsRead(inc));
+ final RMRecordsReader recordsReaderSpy = spy(recordsReader);
+ List segments = new ArrayList<>();
+ segments.add(
+ SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM));
+ segments.add(
+ SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM));
+ segments.add(
+ SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM));
+ segments.forEach(segment -> segment.init());
+ SerOutputStream output = new DynBufferSerOutputStream();
+ Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, false);
+ output.close();
+ ByteBuf byteBuf = output.toByteBuf();
+ ByteBuf[][] buffers = new ByteBuf[][] {{byteBuf}};
+ ShuffleServerClient serverClient =
+ new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds);
+ doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+ // 4 construct spark shuffle reader
+ ShuffleWriteClient writeClient = mock(ShuffleWriteClient.class);
+ RMRssShuffleReader shuffleReader =
+ new RMRssShuffleReader(
+ PARTITION_ID,
+ PARTITION_ID + 1,
+ context,
+ handle,
+ writeClient,
+ ImmutableMap.of(PARTITION_ID, serverInfos),
+ ImmutableMap.of(PARTITION_ID, Roaring64NavigableMap.bitmapOf(blockIds)),
+ Roaring64NavigableMap.bitmapOf(taskAttemptId),
+ readMetrics,
+ null,
+ rssConf,
+ ClientType.GRPC.name());
+ RMRssShuffleReader shuffleReaderSpy = spy(shuffleReader);
+ doReturn(recordsReaderSpy).when(shuffleReaderSpy).createRMRecordsReader(anySet(), anyMap());
+
+ // 5 read and verify result
+ Iterator> iterator = shuffleReaderSpy.read();
+ int index = 0;
+ while (iterator.hasNext()) {
+ Product2 record = iterator.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record._1());
+ Object value = SerializerUtils.genData(valueClass, index);
+ Object newValue = value;
+ if (index % 2 == 0) {
+ if (value instanceof IntWritable) {
+ newValue = new IntWritable(((IntWritable) value).get() * 2);
+ } else {
+ newValue = (int) value * 2;
+ }
+ }
+ assertEquals(newValue, record._2());
+ index++;
+ }
+ assertEquals(RECORDS_NUM * 2, index);
+ assertEquals(RECORDS_NUM * 3, readMetrics._recordsRead().value());
+ byteBuf.release();
+ }
+
+ @Test
+ public void testReadMultiPartitionShuffleWithoutCombine() throws Exception {
+ // 1 basic parameter
+ final Class keyClass = String.class;
+ final Class valueClass = Integer.class;
+ final Comparator comparator = Ordering.String$.MODULE$;
+ final List serverInfos =
+ Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+ final RssConf rssConf = new RssConf();
+ final int taskAttemptId = 0;
+ BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+ final long[] blockIds =
+ new long[] {
+ blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId),
+ blockIdLayout.getBlockId(1, PARTITION_ID, taskAttemptId),
+ blockIdLayout.getBlockId(0, PARTITION_ID + 1, taskAttemptId),
+ blockIdLayout.getBlockId(1, PARTITION_ID + 1, taskAttemptId),
+ blockIdLayout.getBlockId(0, PARTITION_ID + 2, taskAttemptId),
+ blockIdLayout.getBlockId(1, PARTITION_ID + 2, taskAttemptId)
+ };
+ final Combiner combiner = null;
+
+ // 2 mock TaskContext, ShuffleDependency, RssShuffleHandle
+ // 2.1 mock TaskContext
+ TaskContext context = mock(TaskContext.class);
+ when(context.attemptNumber()).thenReturn(1);
+ when(context.taskAttemptId()).thenReturn(1L);
+ when(context.taskMetrics()).thenReturn(new TaskMetrics());
+ doNothing().when(context).killTaskIfInterrupted();
+ // 2.2 mock ShuffleDependency
+ ShuffleDependency dependency = mock(ShuffleDependency.class);
+ when(dependency.mapSideCombine()).thenReturn(false);
+ when(dependency.aggregator()).thenReturn(Option.empty());
+ when(dependency.keyClassName()).thenReturn(keyClass.getName());
+ when(dependency.valueClassName()).thenReturn(valueClass.getName());
+ when(dependency.keyOrdering()).thenReturn(Option.empty());
+ // 2.3 mock RssShuffleHandler
+ RssShuffleHandle handle = mock(RssShuffleHandle.class);
+ when(handle.getAppId()).thenReturn(APP_ID);
+ when(handle.getDependency()).thenReturn(dependency);
+ when(handle.getShuffleId()).thenReturn(SHUFFLE_ID);
+ when(handle.getNumMaps()).thenReturn(1);
+ when(handle.getPartitionToServers()).thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos));
+
+ // 3 construct remote merge records reader
+ ShuffleReadMetrics readMetrics = new ShuffleReadMetrics();
+ RMRecordsReader recordsReader =
+ new RMRecordsReader(
+ APP_ID,
+ SHUFFLE_ID,
+ Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2),
+ ImmutableMap.of(
+ PARTITION_ID,
+ serverInfos,
+ PARTITION_ID + 1,
+ serverInfos,
+ PARTITION_ID + 2,
+ serverInfos),
+ rssConf,
+ keyClass,
+ valueClass,
+ comparator,
+ false,
+ combiner,
+ false,
+ inc -> {
+ synchronized (this) {
+ readMetrics.incRecordsRead(inc);
+ }
+ });
+ final RMRecordsReader recordsReaderSpy = spy(recordsReader);
+ ByteBuf[][] buffers = new ByteBuf[3][2];
+ for (int i = 0; i < 3; i++) {
+ buffers[i][0] = genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1);
+ buffers[i][1] =
+ genSortedRecordBuffer(
+ rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1);
+ }
+ ShuffleServerClient serverClient =
+ new MockedShuffleServerClient(
+ new int[] {PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2}, buffers, blockIds);
+ doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+ // 4 construct spark shuffle reader
+ ShuffleWriteClient writeClient = mock(ShuffleWriteClient.class);
+ RMRssShuffleReader shuffleReader =
+ new RMRssShuffleReader(
+ PARTITION_ID,
+ PARTITION_ID + 3,
+ context,
+ handle,
+ writeClient,
+ ImmutableMap.of(
+ PARTITION_ID,
+ serverInfos,
+ PARTITION_ID + 1,
+ serverInfos,
+ PARTITION_ID + 2,
+ serverInfos),
+ ImmutableMap.of(
+ PARTITION_ID,
+ Roaring64NavigableMap.bitmapOf(blockIds[0], blockIds[1]),
+ PARTITION_ID + 1,
+ Roaring64NavigableMap.bitmapOf(blockIds[2], blockIds[3]),
+ PARTITION_ID + 2,
+ Roaring64NavigableMap.bitmapOf(blockIds[4], blockIds[5])),
+ Roaring64NavigableMap.bitmapOf(taskAttemptId),
+ readMetrics,
+ null,
+ rssConf,
+ ClientType.GRPC.name());
+ RMRssShuffleReader shuffleReaderSpy = spy(shuffleReader);
+ doReturn(recordsReaderSpy).when(shuffleReaderSpy).createRMRecordsReader(anySet(), anyMap());
+
+ // 5 read and verify result
+ Iterator> iterator = shuffleReaderSpy.read();
+ int index = 0;
+ while (iterator.hasNext()) {
+ Product2 record = iterator.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record._1());
+ assertEquals(SerializerUtils.genData(valueClass, index), record._2());
+ index++;
+ }
+ assertEquals(RECORDS_NUM * 6, index);
+ assertEquals(RECORDS_NUM * 6, readMetrics._recordsRead().value());
+ Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release()));
+ }
+
+ @Test
+ public void testReadMultiPartitionShuffleWithCombine() throws Exception {
+ // 1 basic parameter
+ final Class keyClass = String.class;
+ final Class valueClass = Integer.class;
+ final Comparator comparator = Ordering.String$.MODULE$;
+ final List serverInfos =
+ Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+ final RssConf rssConf = new RssConf();
+ final int taskAttemptId = 0;
+ BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+ final long[] blockIds =
+ new long[] {
+ blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId),
+ blockIdLayout.getBlockId(1, PARTITION_ID, taskAttemptId),
+ blockIdLayout.getBlockId(0, PARTITION_ID + 1, taskAttemptId),
+ blockIdLayout.getBlockId(1, PARTITION_ID + 1, taskAttemptId),
+ blockIdLayout.getBlockId(0, PARTITION_ID + 2, taskAttemptId),
+ blockIdLayout.getBlockId(1, PARTITION_ID + 2, taskAttemptId)
+ };
+ final Combiner combiner = new SumByKeyCombiner();
+
+ // 2 mock TaskContext, ShuffleDependency, RssShuffleHandle
+ // 2.1 mock TaskContext
+ TaskContext context = mock(TaskContext.class);
+ when(context.attemptNumber()).thenReturn(1);
+ when(context.taskAttemptId()).thenReturn(1L);
+ when(context.taskMetrics()).thenReturn(new TaskMetrics());
+ doNothing().when(context).killTaskIfInterrupted();
+ // 2.2 mock ShuffleDependency
+ ShuffleDependency dependency = mock(ShuffleDependency.class);
+ when(dependency.mapSideCombine()).thenReturn(false);
+ when(dependency.aggregator()).thenReturn(Option.empty());
+ when(dependency.keyClassName()).thenReturn(keyClass.getName());
+ when(dependency.valueClassName()).thenReturn(valueClass.getName());
+ when(dependency.keyOrdering()).thenReturn(Option.empty());
+ // 2.3 mock RssShuffleHandler
+ RssShuffleHandle handle = mock(RssShuffleHandle.class);
+ when(handle.getAppId()).thenReturn(APP_ID);
+ when(handle.getDependency()).thenReturn(dependency);
+ when(handle.getShuffleId()).thenReturn(SHUFFLE_ID);
+ when(handle.getNumMaps()).thenReturn(1);
+ when(handle.getPartitionToServers()).thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos));
+
+ // 3 construct remote merge records reader
+ ShuffleReadMetrics readMetrics = new ShuffleReadMetrics();
+ RMRecordsReader recordsReader =
+ new RMRecordsReader(
+ APP_ID,
+ SHUFFLE_ID,
+ Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2),
+ ImmutableMap.of(
+ PARTITION_ID,
+ serverInfos,
+ PARTITION_ID + 1,
+ serverInfos,
+ PARTITION_ID + 2,
+ serverInfos),
+ rssConf,
+ keyClass,
+ valueClass,
+ comparator,
+ false,
+ combiner,
+ false,
+ inc -> {
+ synchronized (this) {
+ readMetrics.incRecordsRead(inc);
+ }
+ });
+ final RMRecordsReader recordsReaderSpy = spy(recordsReader);
+ ByteBuf[][] buffers = new ByteBuf[3][2];
+ for (int i = 0; i < 3; i++) {
+ buffers[i][0] = genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 2);
+ buffers[i][1] =
+ genSortedRecordBuffer(
+ rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 2);
+ }
+ ShuffleServerClient serverClient =
+ new MockedShuffleServerClient(
+ new int[] {PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2}, buffers, blockIds);
+ doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+ // 4 construct spark shuffle reader
+ ShuffleWriteClient writeClient = mock(ShuffleWriteClient.class);
+ RMRssShuffleReader shuffleReader =
+ new RMRssShuffleReader(
+ PARTITION_ID,
+ PARTITION_ID + 3,
+ context,
+ handle,
+ writeClient,
+ ImmutableMap.of(
+ PARTITION_ID,
+ serverInfos,
+ PARTITION_ID + 1,
+ serverInfos,
+ PARTITION_ID + 2,
+ serverInfos),
+ ImmutableMap.of(
+ PARTITION_ID,
+ Roaring64NavigableMap.bitmapOf(blockIds[0], blockIds[1]),
+ PARTITION_ID + 1,
+ Roaring64NavigableMap.bitmapOf(blockIds[2], blockIds[3]),
+ PARTITION_ID + 2,
+ Roaring64NavigableMap.bitmapOf(blockIds[4], blockIds[5])),
+ Roaring64NavigableMap.bitmapOf(taskAttemptId),
+ readMetrics,
+ null,
+ rssConf,
+ ClientType.GRPC.name());
+ RMRssShuffleReader shuffleReaderSpy = spy(shuffleReader);
+ doReturn(recordsReaderSpy).when(shuffleReaderSpy).createRMRecordsReader(anySet(), anyMap());
+
+ // 5 read and verify result
+ Iterator> iterator = shuffleReaderSpy.read();
+ int index = 0;
+ while (iterator.hasNext()) {
+ Product2 record = iterator.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record._1());
+ assertEquals(SerializerUtils.genData(valueClass, index * 2), record._2());
+ index++;
+ }
+ assertEquals(RECORDS_NUM * 6, index);
+ assertEquals(RECORDS_NUM * 12, readMetrics._recordsRead().value());
+ Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release()));
+ }
+}
diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index 33980c92f1..6084effb4e 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -77,28 +77,6 @@ default void registerShuffle(
Collections.emptyMap());
}
- default void registerShuffle(
- ShuffleServerInfo shuffleServerInfo,
- String appId,
- int shuffleId,
- List partitionRanges,
- RemoteStorageInfo remoteStorage,
- ShuffleDataDistributionType dataDistributionType,
- int maxConcurrencyPerPartitionToWrite,
- Map properties) {
- registerShuffle(
- shuffleServerInfo,
- appId,
- shuffleId,
- partitionRanges,
- remoteStorage,
- dataDistributionType,
- maxConcurrencyPerPartitionToWrite,
- 0,
- null,
- properties);
- }
-
default void registerShuffle(
ShuffleServerInfo shuffleServerInfo,
String appId,
diff --git a/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkOrderedWordCountTest.scala b/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkOrderedWordCountTest.scala
new file mode 100644
index 0000000000..3bcc5cb89b
--- /dev/null
+++ b/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkOrderedWordCountTest.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test
+
+import org.apache.spark.shuffle.RssSparkConfig
+import org.apache.spark.sql.SparkSession
+import org.apache.uniffle.common.rpc.ServerType
+import org.apache.uniffle.coordinator.CoordinatorConf
+import org.apache.uniffle.server.ShuffleServerConf
+import org.apache.uniffle.server.buffer.ShuffleBufferType
+import org.apache.uniffle.storage.util.StorageType
+import org.apache.uniffle.test.IntegrationTestBase._
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.{BeforeAll, Test}
+
+import java.io.{File, FileWriter, PrintWriter}
+import java.util
+import scala.collection.JavaConverters.mapAsJavaMap
+import scala.util.Random
+
+object RMSparkOrderedWordCountTest {
+
+ @BeforeAll
+ @throws[Exception]
+ def setupServers(): Unit = {
+ val coordinatorConf: CoordinatorConf = getCoordinatorConf
+ val dynamicConf: util.HashMap[String, String] = new util.HashMap[String, String]()
+ dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key, StorageType.MEMORY_LOCALFILE.name)
+ addDynamicConf(coordinatorConf, dynamicConf)
+ createCoordinatorServer(coordinatorConf)
+ val grpcShuffleServerConf = getShuffleServerConf(ServerType.GRPC)
+ val nettyShuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY)
+ grpcShuffleServerConf.setBoolean(ShuffleServerConf.SERVER_MERGE_ENABLE, true)
+ grpcShuffleServerConf.set(ShuffleServerConf.SERVER_SHUFFLE_BUFFER_TYPE, ShuffleBufferType.SKIP_LIST)
+ nettyShuffleServerConf.setBoolean(ShuffleServerConf.SERVER_MERGE_ENABLE, true)
+ nettyShuffleServerConf.set(ShuffleServerConf.SERVER_SHUFFLE_BUFFER_TYPE, ShuffleBufferType.SKIP_LIST)
+ createShuffleServer(grpcShuffleServerConf)
+ createShuffleServer(nettyShuffleServerConf)
+ startServers()
+ }
+}
+
+class RMSparkOrderedWordCountTest extends SparkIntegrationTestBase {
+
+ private[test] val inputPath: String = "word_count_input"
+ private[test] val wordTable: Array[String] = Array("apple",
+ "banana", "fruit", "tomato", "pineapple", "grape", "lemon", "orange", "peach", "mango")
+
+ @Test
+ @throws[Exception]
+ def orderedWordCountTest(): Unit = {
+ run()
+ }
+
+ @throws[Exception]
+ override def run(): Unit = {
+ val fileName = generateTextFile(100)
+ val sparkConf = createSparkConf
+ // lz4 conflicts, so use snappy here
+ sparkConf.set("spark.io.compression.codec", "snappy")
+ // 1 Run spark with remote sort rss
+ // 1.1 GRPC
+ val sparkConfWithRemoteSortRss = sparkConf.clone
+ updateSparkConfWithRssGrpc(sparkConfWithRemoteSortRss)
+ updateSparkConfCustomer(sparkConfWithRemoteSortRss)
+ sparkConfWithRemoteSortRss.set(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE.key, "true")
+ val rssResult = runSparkApp(sparkConfWithRemoteSortRss, fileName)
+ // 1.2 GRPC_NETTY
+ val sparkConfWithRemoteSortRssNetty = sparkConf.clone
+ updateSparkConfWithRssGrpc(sparkConfWithRemoteSortRssNetty)
+ updateSparkConfCustomer(sparkConfWithRemoteSortRssNetty)
+ sparkConfWithRemoteSortRssNetty.set(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE.key, "true")
+ sparkConfWithRemoteSortRssNetty.set(RssSparkConfig.RSS_CLIENT_TYPE.key, "GRPC_NETTY")
+ val rssResultNetty = runSparkApp(sparkConfWithRemoteSortRssNetty, fileName)
+
+ // 2 Run original spark
+ val sparkConfOriginal = sparkConf.clone
+ val originalResult = runSparkApp(sparkConfOriginal, fileName)
+
+ // 3 verify
+ assertEquals(originalResult.size(), rssResult.size())
+ assertEquals(originalResult.size(), rssResultNetty.size())
+ import scala.collection.JavaConverters._
+ for ((k, v) <- originalResult.asScala.toMap) {
+ assertEquals(v, rssResult.get(k))
+ assertEquals(v, rssResultNetty.get(k))
+ }
+ }
+
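+ // Writes `rows` random words from wordTable, one per line, to a temp file and returns its path.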
+ @throws[Exception]
+ def generateTextFile(rows: Int): String = {
+ val file = new File(IntegrationTestBase.tempDir, "wordcount.txt")
+ file.createNewFile
+ file.deleteOnExit()
+ val r = Random
+ val writer = new PrintWriter(new FileWriter(file))
+ try for (i <- 0 until rows) {
+ writer.println(wordTable(r.nextInt(wordTable.length)))
+ }
+ finally if (writer != null) writer.close()
+ file.getAbsolutePath
+ }
+
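+ // Counts words in the input file and returns the result as a Java map for comparison across runs.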
+ override def runTest(spark: SparkSession, fileName: String): util.Map[String, Int] = {
+ val sc = spark.sparkContext
+ val rdd = sc.textFile(fileName)
+ val counts = rdd.flatMap(_.split(" "))
+ .map(w => (w, 1))
+ .reduceByKey(_ + _)
+ .sortBy(_._1)
+ mapAsJavaMap(counts.collectAsMap())
+ }
+}
diff --git a/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkSQLTest.scala b/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkSQLTest.scala
new file mode 100644
index 0000000000..396a1c50fe
--- /dev/null
+++ b/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkSQLTest.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test
+
+import com.google.common.collect.Maps
+import org.apache.commons.lang3.StringUtils
+import org.apache.spark.SparkConf
+import org.apache.spark.shuffle.RssSparkConfig
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
+//import org.apache.uniffle.common.config.RssBaseConf.RSS_KRYO_REGISTRATION_CLASSES;
+import org.apache.uniffle.common.rpc.ServerType
+import org.apache.uniffle.coordinator.CoordinatorConf
+import org.apache.uniffle.server.ShuffleServerConf
+import org.apache.uniffle.server.buffer.ShuffleBufferType
+import org.apache.uniffle.storage.util.StorageType
+import org.apache.uniffle.test.IntegrationTestBase._
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.{BeforeAll, Test}
+
+import java.io.{File, FileWriter, PrintWriter}
+import java.util
+import java.util.Map
+
+object RMSparkSQLTest {
+
+ @BeforeAll
+ @throws[Exception]
+ def setupServers(): Unit = {
+ val coordinatorConf: CoordinatorConf = getCoordinatorConf
+ val dynamicConf: util.HashMap[String, String] = new util.HashMap[String, String]()
+ dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key, StorageType.MEMORY_LOCALFILE.name)
+ // dynamicConf.put(RSS_KRYO_REGISTRATION_CLASSES.key(),
+ // "org.apache.spark.sql.execution.RowKey,org.apache.spark.sql.execution.IntKey")
+ addDynamicConf(coordinatorConf, dynamicConf)
+ createCoordinatorServer(coordinatorConf)
+ val grpcShuffleServerConf = getShuffleServerConf(ServerType.GRPC)
+ val nettyShuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY)
+ grpcShuffleServerConf.setBoolean(ShuffleServerConf.SERVER_MERGE_ENABLE, true)
+ grpcShuffleServerConf.set(ShuffleServerConf.SERVER_SHUFFLE_BUFFER_TYPE, ShuffleBufferType.SKIP_LIST)
+ // grpcShuffleServerConf.set(RSS_KRYO_REGISTRATION_CLASSES,
+ // "org.apache.spark.sql.execution.RowKey,org.apache.spark.sql.execution.IntKey")
+ nettyShuffleServerConf.setBoolean(ShuffleServerConf.SERVER_MERGE_ENABLE, true)
+ nettyShuffleServerConf.set(ShuffleServerConf.SERVER_SHUFFLE_BUFFER_TYPE, ShuffleBufferType.SKIP_LIST)
+ // nettyShuffleServerConf.set(RSS_KRYO_REGISTRATION_CLASSES,
+ // "org.apache.spark.sql.execution.RowKey,org.apache.spark.sql.execution.IntKey")
+ createShuffleServer(grpcShuffleServerConf)
+ createShuffleServer(nettyShuffleServerConf)
+ startServers()
+ }
+}
+
+class RMSparkSQLTest extends SparkIntegrationTestBase {
+
+ @Test
+ @throws[Exception]
+ def sparkSQLTest(): Unit = {
+ run()
+ }
+
+ override def updateSparkConfCustomer(sparkConf: SparkConf): Unit = {
+ sparkConf.set("spark.sql.shuffle.partitions", "4")
+ }
+
+ @throws[Exception]
+ override def run(): Unit = {
+ val fileName = generateTestFile
+ val sparkConf = createSparkConf
+ // lz4 conflicts, so use snappy here
+ sparkConf.set("spark.io.compression.codec", "snappy")
+ // sparkConf.set("spark.sql.execution.sortedShuffle.enabled", "true")
+ // 1 Run spark with remote sort rss
+ // 1.1 GRPC
+ val sparkConfWithRemoteSortRss = sparkConf.clone
+ updateSparkConfWithRssGrpc(sparkConfWithRemoteSortRss)
+ updateSparkConfCustomer(sparkConfWithRemoteSortRss)
+ sparkConfWithRemoteSortRss.set(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE.key, "true")
+ val rssResult = runSparkApp(sparkConfWithRemoteSortRss, fileName)
+ // 1.2 GRPC_NETTY
+ val sparkConfWithRemoteSortRssNetty = sparkConf.clone
+ updateSparkConfWithRssGrpc(sparkConfWithRemoteSortRssNetty)
+ updateSparkConfCustomer(sparkConfWithRemoteSortRssNetty)
+ sparkConfWithRemoteSortRssNetty.set(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE.key, "true")
+ sparkConfWithRemoteSortRssNetty.set(RssSparkConfig.RSS_CLIENT_TYPE.key, "GRPC_NETTY")
+ val rssResultNetty = runSparkApp(sparkConfWithRemoteSortRssNetty, fileName)
+
+ // 2 Run original spark
+ val sparkConfOriginal = sparkConf.clone
+ val originalResult = runSparkApp(sparkConfOriginal, fileName)
+
+ // 3 verify
+ assertEquals(originalResult.size(), rssResult.size())
+ assertEquals(originalResult.size(), rssResultNetty.size())
+ import scala.collection.JavaConverters._
+ for ((k, v) <- originalResult.asScala.toMap) {
+ assertEquals(v, rssResult.get(k))
+ assertEquals(v, rssResultNetty.get(k))
+ }
+ }
+
+ @throws[Exception]
+ override def generateTestFile: String = generateCsvFile
+
+ @throws[Exception]
+ protected def generateCsvFile: String = {
+ val rows = 1000
+ val file = new File(IntegrationTestBase.tempDir, "test.csv")
+ file.createNewFile
+ file.deleteOnExit()
+ try {
+ val writer = new PrintWriter(new FileWriter(file))
+ try for (i <- 0 until rows) {
+ writer.println(generateRecord)
+ }
+ finally if (writer != null) writer.close()
+ }
+ file.getAbsolutePath
+ }
+
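+ // One CSV row: a name made of a random letter repeated 0-9 times, plus a random age below 100.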
+ private def generateRecord = {
+ val random = new java.util.Random
+ val ch = ('a' + random.nextInt(26)).toChar
+ val repeats = random.nextInt(10)
+ StringUtils.repeat(ch, repeats) + "," + random.nextInt(100)
+ }
+
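+ // Runs a group-by-name count over the CSV and collects it into a Java map keyed by name.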
+ override def runTest(spark: SparkSession, fileName: String): util.Map[String, Long] = {
+ val df = spark.read.schema("name STRING, age INT").csv(fileName)
+ df.createOrReplaceTempView("people")
+ val queryResult: Dataset[Row] =
+ spark.sql("SELECT name, count(age) FROM people group by name order by name");
+ val result:Map[String, Long] = Maps.newHashMap[String, Long]
+ queryResult.rdd.collect().foreach(
+ row => result.put(row.getString(0), row.getLong(1))
+ )
+ result
+ }
+}
\ No newline at end of file
diff --git a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
index 027b63d9df..a65b13d798 100644
--- a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
@@ -17,15 +17,18 @@
package org.apache.uniffle.server.merge;
+import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectStreamClass;
import java.lang.reflect.Constructor;
-import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
+import java.util.Base64;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
@@ -40,6 +43,7 @@
import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.ShuffleDataResult;
+import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.merger.Segment;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.serializer.SerOutputStream;
@@ -148,6 +152,28 @@ public ClassLoader getClassLoader(String label) {
return cachedClassLoader.getOrDefault(label, cachedClassLoader.get(""));
}
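+ // Decodes a Base64-encoded serialized object (the leading '#' marker is stripped first),
+ // resolving its classes against the supplied class loader.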
+ private Object decode(String encodeString, ClassLoader classLoader) {
+ try {
+ byte[] bytes = Base64.getDecoder().decode(encodeString.substring(1));
+ ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
+ ObjectInputStream ois =
+ new ObjectInputStream(bais) {
+ @Override
+ protected Class> resolveClass(ObjectStreamClass desc)
+ throws IOException, ClassNotFoundException {
+ try {
+ return Class.forName(desc.getName(), false, classLoader);
+ } catch (ClassNotFoundException e) {
+ return super.resolveClass(desc);
+ }
+ }
+ };
+ return ois.readObject();
+ } catch (Exception e) {
+ throw new RssException(e);
+ }
+ }
+
public StatusCode registerShuffle(String appId, int shuffleId, MergeContext mergeContext) {
try {
ClassLoader classLoader = getClassLoader(mergeContext.getMergeClassLoader());
@@ -155,11 +181,15 @@ public StatusCode registerShuffle(String appId, int shuffleId, MergeContext merg
Class vClass = ClassUtils.getClass(classLoader, mergeContext.getValueClass());
Comparator comparator;
if (StringUtils.isNotBlank(mergeContext.getComparatorClass())) {
- Constructor constructor =
- ClassUtils.getClass(classLoader, mergeContext.getComparatorClass())
- .getDeclaredConstructor();
- constructor.setAccessible(true);
- comparator = (Comparator) constructor.newInstance();
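+ // A comparator "class name" starting with '#' carries a serialized comparator instance;
+ // decode it rather than instantiating by reflection.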
+ if (mergeContext.getComparatorClass().startsWith("#")) {
+ comparator = (Comparator) decode(mergeContext.getComparatorClass(), classLoader);
+ } else {
+ Constructor constructor =
+ ClassUtils.getClass(classLoader, mergeContext.getComparatorClass())
+ .getDeclaredConstructor();
+ constructor.setAccessible(true);
+ comparator = (Comparator) constructor.newInstance();
+ }
} else {
comparator = defaultComparator;
}
@@ -179,12 +209,8 @@ public StatusCode registerShuffle(String appId, int shuffleId, MergeContext merg
comparator,
mergeContext.getMergedBlockSize(),
classLoader));
- } catch (ClassNotFoundException
- | InstantiationException
- | IllegalAccessException
- | NoSuchMethodException
- | InvocationTargetException e) {
- LOG.info("Cant register shuffle, caused by ", e);
+ } catch (Throwable e) {
+ LOG.info("Cannot register shuffle, caused by ", e);
removeBuffer(appId, shuffleId);
return StatusCode.INTERNAL_ERROR;
}