diff --git a/client-spark/common/pom.xml b/client-spark/common/pom.xml index 639864b196..99564b255b 100644 --- a/client-spark/common/pom.xml +++ b/client-spark/common/pom.xml @@ -89,6 +89,12 @@ net.jpountz.lz4 lz4 + + org.apache.uniffle + rss-common + test-jar + test + diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java index 4a7f653db6..e8998ed431 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java @@ -447,6 +447,27 @@ public class RssSparkConfig { + "sequence number, the partition id and the task attempt id.")) .createWithDefault(1048576); + public static final ConfigEntry RSS_REMOTE_MERGE_ENABLE = + createBooleanBuilder( + new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_REMOTE_MERGE_ENABLE) + .doc("Whether to enable remote merge")) + .createWithDefault(false); + + public static final ConfigEntry RSS_MERGED_BLOCK_SZIE = + createIntegerBuilder( + new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_MERGED_BLOCK_SZIE) + .internal() + .doc("The merged block size.")) + .createWithDefault(RssClientConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT); + + public static final ConfigEntry RSS_REMOTE_MERGE_CLASS_LOADER = + createStringBuilder( + new ConfigBuilder( + SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_REMOTE_MERGE_CLASS_LOADER) + .internal() + .doc("The class loader label for remote merge")) + .createWithDefault(null); + // spark2 doesn't have this key defined public static final String SPARK_SHUFFLE_COMPRESS_KEY = "spark.shuffle.compress"; diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/SparkCombiner.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/SparkCombiner.java new file mode 100644 index 0000000000..f292fc59d0 --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/SparkCombiner.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
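For illustration (not part of the patch): a job opts into remote merge by setting the first of the new config entries above. A minimal sketch follows, assuming the key resolves to "spark.rss.remote.merge.enable"; the real string is SPARK_RSS_CONFIG_PREFIX plus the constant defined in RssClientConfig, and the RSS_MERGED_BLOCK_SZIE spelling follows that existing client-side constant.

import org.apache.spark.SparkConf;

public class EnableRemoteMergeSketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
        .set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager")
        // Assumed literal for SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_REMOTE_MERGE_ENABLE (default: false).
        .set("spark.rss.remote.merge.enable", "true");
    System.out.println(conf.contains("spark.rss.remote.merge.enable")); // true
  }
}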
+ */ + +package org.apache.spark.shuffle; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.apache.spark.Aggregator; + +import org.apache.uniffle.client.record.Record; +import org.apache.uniffle.client.record.writer.Combiner; + +public class SparkCombiner extends Combiner { + + private Aggregator agg; + + public SparkCombiner(Aggregator agg) { + this.agg = agg; + } + + @Override + public List combineValues(Iterator>> recordIterator) { + List ret = new ArrayList<>(); + while (recordIterator.hasNext()) { + Map.Entry> entry = recordIterator.next(); + List records = entry.getValue(); + Record current = null; + for (Record record : records) { + if (current == null) { + // Handle new Key + current = Record.create(record.getKey(), agg.createCombiner().apply(record.getValue())); + } else { + // Combine the values + C newValue = agg.mergeValue().apply(current.getValue(), record.getValue()); + current.setValue(newValue); + } + } + ret.add(current); + } + return ret; + } + + @Override + public List combineCombiners(Iterator>> recordIterator) { + List ret = new ArrayList<>(); + while (recordIterator.hasNext()) { + Map.Entry> entry = recordIterator.next(); + List records = entry.getValue(); + Record current = null; + for (Record record : records) { + if (current == null) { + // Handle new Key + current = record; + } else { + // Combine the values + C newValue = agg.mergeCombiners().apply(current.getValue(), record.getValue()); + current.setValue(newValue); + } + } + ret.add(current); + } + return ret; + } +} diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleDataIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleDataIterator.java new file mode 100644 index 0000000000..dc8a429f3e --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleDataIterator.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
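SparkCombiner above adapts Spark's Aggregator to Uniffle's Combiner: for each key, the first value is seeded with createCombiner, later values are folded in with mergeValue, and pre-combined values from other map tasks are folded with mergeCombiners. A plain-Java sketch of that per-key fold, illustrative only and using java.util.function instead of Spark's Aggregator:

import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;

public class CombineFoldSketch {
  // Same fold shape as SparkCombiner.combineValues, applied to a single key's value list.
  static <V, C> C combineValues(List<V> values, Function<V, C> createCombiner, BiFunction<C, V, C> mergeValue) {
    C current = null;
    for (V v : values) {
      current = (current == null) ? createCombiner.apply(v) : mergeValue.apply(current, v);
    }
    return current;
  }

  public static void main(String[] args) {
    // Mirrors the aggregator in SparkCombinerTest later in this diff: the combiner is the value
    // rendered as a String, and mergeValue adds the next int to it. Values 0..4 combine to "10".
    String combined = combineValues(
        Arrays.asList(0, 1, 2, 3, 4),
        String::valueOf,
        (c, v) -> Integer.toString(Integer.parseInt(c) + v));
    System.out.println(combined); // 10
  }
}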
+ */ + +package org.apache.spark.shuffle.reader; + +import java.io.IOException; + +import scala.collection.AbstractIterator; +import scala.runtime.BoxedUnit; + +import org.apache.uniffle.client.record.Record; +import org.apache.uniffle.client.record.reader.KeyValueReader; +import org.apache.uniffle.client.record.reader.RMRecordsReader; +import org.apache.uniffle.common.exception.RssException; + +public class RMRssShuffleDataIterator extends AbstractIterator> { + + private RMRecordsReader reader; + private KeyValueReader keyValueReader; + + public RMRssShuffleDataIterator(RMRecordsReader reader) { + this.reader = reader; + this.keyValueReader = reader.keyValueReader(); + } + + @Override + public boolean hasNext() { + try { + return this.keyValueReader.hasNext(); + } catch (IOException e) { + throw new RssException(e); + } + } + + @Override + public Record next() { + try { + return this.keyValueReader.next(); + } catch (IOException e) { + throw new RssException(e); + } + } + + public BoxedUnit cleanup() { + reader.close(); + return BoxedUnit.UNIT; + } +} diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/RMWriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/RMWriteBufferManager.java new file mode 100644 index 0000000000..7a3c874097 --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/RMWriteBufferManager.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
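RMWriteBufferManager below writes the key and the value as two separate Kryo objects so it can record a per-record keyLength and valueLength; the WriterBuffer changes later in this diff use those offsets to deserialize only the keys when sorting. A self-contained sketch of that length bookkeeping, assuming a plain Kryo Output in place of the UnsafeOutput over WrappedByteArrayOutputStream that the patch uses:

import java.io.ByteArrayOutputStream;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Output;

public class KeyValueLengthSketch {
  public static void main(String[] args) {
    Kryo kryo = new Kryo();
    kryo.setRegistrationRequired(false); // keep the sketch self-contained
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    Output out = new Output(bytes);
    kryo.writeClassAndObject(out, "key0");
    out.flush();
    int keyLength = bytes.size();                 // bytes [0, keyLength) hold the key
    kryo.writeClassAndObject(out, 0);
    out.flush();
    int valueLength = bytes.size() - keyLength;   // bytes [keyLength, keyLength + valueLength) hold the value
    System.out.println(keyLength + " + " + valueLength + " = " + bytes.size());
  }
}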
+ */ + +package org.apache.spark.shuffle.writer; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Output; +import com.esotericsoftware.kryo.io.UnsafeOutput; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.Serializer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.common.ShuffleBlockInfo; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.serializer.kryo.KryoSerializerInstance; +import org.apache.uniffle.common.util.ChecksumUtils; + +public class RMWriteBufferManager extends WriteBufferManager { + + private static final Logger LOG = LoggerFactory.getLogger(RMWriteBufferManager.class); + + private Comparator comparator; + private KryoSerializerInstance instance; + private Kryo serKryo; + private Output serOutput; + + public RMWriteBufferManager( + int shuffleId, + String taskId, + long taskAttemptId, + BufferManagerOptions bufferManagerOptions, + Serializer serializer, + TaskMemoryManager taskMemoryManager, + ShuffleWriteMetrics shuffleWriteMetrics, + RssConf rssConf, + Function, List>> spillFunc, + Function> partitionAssignmentRetrieveFunc, + int stageAttemptNumber, + Comparator comparator) { + super( + shuffleId, + taskId, + taskAttemptId, + bufferManagerOptions, + serializer, + taskMemoryManager, + shuffleWriteMetrics, + rssConf, + spillFunc, + partitionAssignmentRetrieveFunc, + stageAttemptNumber); + this.comparator = comparator; + this.instance = new KryoSerializerInstance(this.rssConf); + this.serKryo = this.instance.borrowKryo(); + this.serOutput = new UnsafeOutput(this.arrayOutputStream); + } + + @Override + public List addRecord(int partitionId, Object key, Object value) { + try { + final long start = System.currentTimeMillis(); + arrayOutputStream.reset(); + this.serKryo.writeClassAndObject(this.serOutput, key); + this.serOutput.flush(); + final int keyLength = arrayOutputStream.size(); + this.serKryo.writeClassAndObject(this.serOutput, value); + this.serOutput.flush(); + serializeTime += System.currentTimeMillis() - start; + byte[] serializedData = arrayOutputStream.getBuf(); + int serializedDataLength = arrayOutputStream.size(); + if (serializedDataLength == 0) { + return null; + } + List shuffleBlockInfos = + addPartitionData( + partitionId, serializedData, keyLength, serializedDataLength - keyLength, start); + // records is a row based semantic, when in columnar shuffle records num should be taken from + // ColumnarBatch + // that is handled by rss shuffle writer implementation + if (isRowBased) { + shuffleWriteMetrics.incRecordsWritten(1L); + } + return shuffleBlockInfos; + } catch (Exception e) { + throw new RssException(e); + } + } + + private List addPartitionData( + int partitionId, byte[] serializedData, int keyLength, int valueLength, long start) { + List singleOrEmptySendingBlocks = + insertIntoBuffer(partitionId, serializedData, keyLength, valueLength); + + // check buffer size > spill threshold + if (usedBytes.get() - inSendListBytes.get() > spillSize) { + LOG.info( + "ShuffleBufferManager spill for buffer size exceeding spill threshold, " + + "usedBytes[{}], 
inSendListBytes[{}], spill size threshold[{}]", + usedBytes.get(), + inSendListBytes.get(), + spillSize); + List multiSendingBlocks = clear(bufferSpillRatio); + multiSendingBlocks.addAll(singleOrEmptySendingBlocks); + writeTime += System.currentTimeMillis() - start; + return multiSendingBlocks; + } + writeTime += System.currentTimeMillis() - start; + return singleOrEmptySendingBlocks; + } + + private List insertIntoBuffer( + int partitionId, byte[] serializedData, int keyLength, int valueLength) { + int recordLength = keyLength + valueLength; + long required = Math.max(bufferSegmentSize, recordLength); + // Asking memory from task memory manager for the existing writer buffer, + // this may trigger current WriteBufferManager spill method, which will + // make the current write buffer discard. So we have to recheck the buffer existence. + boolean hasRequested = false; + WriterBuffer wb = buffers.get(partitionId); + if (wb != null) { + if (wb.askForMemory(recordLength)) { + requestMemory(required); + hasRequested = true; + } + } + + // hasRequested is not true means spill method was not trigger, + // and we don't have to recheck the buffer existence in this case. + if (hasRequested) { + wb = buffers.get(partitionId); + } + + if (wb != null) { + if (hasRequested) { + usedBytes.addAndGet(required); + } + wb.addRecord(serializedData, keyLength, valueLength); + } else { + // The true of hasRequested means the former partitioned buffer has been flushed, that is + // triggered by the spill operation caused by asking for memory. So it needn't to re-request + // the memory. + if (!hasRequested) { + requestMemory(required); + } + usedBytes.addAndGet(required); + wb = new WriterBuffer(bufferSegmentSize); + wb.addRecord(serializedData, keyLength, valueLength); + buffers.put(partitionId, wb); + } + + if (wb.getMemoryUsed() > bufferSize) { + List sentBlocks = new ArrayList<>(1); + sentBlocks.add(createShuffleBlock(partitionId, wb)); + recordCounter.addAndGet(wb.getRecordCount()); + copyTime += wb.getCopyTime(); + sortRecordTime += wb.getSortTime(); + buffers.remove(partitionId); + if (LOG.isDebugEnabled()) { + LOG.debug( + "Single buffer is full for shuffleId[" + + shuffleId + + "] partition[" + + partitionId + + "] with memoryUsed[" + + wb.getMemoryUsed() + + "], dataLength[" + + wb.getDataLength() + + "]"); + } + return sentBlocks; + } + return Collections.emptyList(); + } + + // transform records to shuffleBlock + @Override + protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer wb) { + byte[] data = wb.getData(instance, comparator); + final int length = data.length; + final long crc32 = ChecksumUtils.getCrc32(data); + final long blockId = + blockIdLayout.getBlockId(getNextSeqNo(partitionId), partitionId, taskAttemptId); + blockCounter.incrementAndGet(); + uncompressedDataLen += data.length; + shuffleWriteMetrics.incBytesWritten(data.length); + // add memory to indicate bytes which will be sent to shuffle server + inSendListBytes.addAndGet(wb.getMemoryUsed()); + return new ShuffleBlockInfo( + shuffleId, + partitionId, + blockId, + length, + crc32, + data, + partitionAssignmentRetrieveFunc.apply(partitionId), + length, + wb.getMemoryUsed(), + taskAttemptId); + } + + @Override + public void freeAllMemory() { + super.freeAllMemory(); + if (this.instance != null && this.serKryo != null) { + this.instance.releaseKryo(this.serKryo); + } + if (this.serOutput != null) { + this.serOutput.close(); + } + } +} diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java index cf4b4bc51c..9143620dac 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java @@ -58,40 +58,42 @@ public class WriteBufferManager extends MemoryConsumer { private static final Logger LOG = LoggerFactory.getLogger(WriteBufferManager.class); - private int bufferSize; - private long spillSize; + protected int bufferSize; + protected long spillSize; // allocated bytes from executor memory private AtomicLong allocatedBytes = new AtomicLong(0); // bytes of shuffle data in memory - private AtomicLong usedBytes = new AtomicLong(0); + protected final AtomicLong usedBytes = new AtomicLong(0); // bytes of shuffle data which is in send list - private AtomicLong inSendListBytes = new AtomicLong(0); + protected final AtomicLong inSendListBytes = new AtomicLong(0); /** An atomic counter used to keep track of the number of records */ - private AtomicLong recordCounter = new AtomicLong(0); + protected final AtomicLong recordCounter = new AtomicLong(0); /** An atomic counter used to keep track of the number of blocks */ - private AtomicLong blockCounter = new AtomicLong(0); + protected AtomicLong blockCounter = new AtomicLong(0); // it's part of blockId private Map partitionToSeqNo = Maps.newHashMap(); private long askExecutorMemory; - private int shuffleId; + protected final int shuffleId; private String taskId; - private long taskAttemptId; + protected final long taskAttemptId; private SerializerInstance instance; - private ShuffleWriteMetrics shuffleWriteMetrics; + protected ShuffleWriteMetrics shuffleWriteMetrics; + protected final RssConf rssConf; // cache partition -> records - private Map buffers; + protected final Map buffers; private int serializerBufferSize; - private int bufferSegmentSize; - private long copyTime = 0; - private long serializeTime = 0; + protected final int bufferSegmentSize; + protected long copyTime = 0; + protected long sortRecordTime = 0; + protected long serializeTime = 0; private long compressTime = 0; private long sortTime = 0; - private long writeTime = 0; + protected long writeTime = 0; private long estimateTime = 0; private long requireMemoryTime = 0; private SerializationStream serializeStream; - private WrappedByteArrayOutputStream arrayOutputStream; - private long uncompressedDataLen = 0; + protected final WrappedByteArrayOutputStream arrayOutputStream; + protected long uncompressedDataLen = 0; private long compressedDataLen = 0; private long requireMemoryInterval; private int requireMemoryRetryMax; @@ -100,10 +102,10 @@ public class WriteBufferManager extends MemoryConsumer { private long sendSizeLimit; private boolean memorySpillEnabled; private int memorySpillTimeoutSec; - private boolean isRowBased; - private BlockIdLayout blockIdLayout; - private double bufferSpillRatio; - private Function> partitionAssignmentRetrieveFunc; + protected boolean isRowBased; + protected BlockIdLayout blockIdLayout; + protected double bufferSpillRatio; + protected Function> partitionAssignmentRetrieveFunc; private int stageAttemptNumber; public WriteBufferManager( @@ -174,6 +176,7 @@ public WriteBufferManager( this.taskId = taskId; this.taskAttemptId = taskAttemptId; this.shuffleWriteMetrics = shuffleWriteMetrics; + this.rssConf = rssConf; 
this.serializerBufferSize = bufferManagerOptions.getSerializerBufferSize(); this.bufferSegmentSize = bufferManagerOptions.getBufferSegmentSize(); this.askExecutorMemory = bufferManagerOptions.getPreAllocatedBufferSize(); @@ -310,6 +313,7 @@ private List insertIntoBuffer( sentBlocks.add(createShuffleBlock(partitionId, wb)); recordCounter.addAndGet(wb.getRecordCount()); copyTime += wb.getCopyTime(); + sortRecordTime += wb.getSortTime(); buffers.remove(partitionId); if (LOG.isDebugEnabled()) { LOG.debug( @@ -449,13 +453,13 @@ protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer wb) } // it's run in single thread, and is not thread safe - private int getNextSeqNo(int partitionId) { + protected int getNextSeqNo(int partitionId) { return partitionToSeqNo .computeIfAbsent(partitionId, k -> new AtomicInteger(0)) .getAndIncrement(); } - private void requestMemory(long requiredMem) { + protected void requestMemory(long requiredMem) { final long start = System.currentTimeMillis(); if (allocatedBytes.get() - usedBytes.get() < requiredMem) { requestExecutorMemory(requiredMem); @@ -644,6 +648,8 @@ public long getWriteTime() { public String getManagerCostInfo() { return "WriteBufferManager cost copyTime[" + copyTime + + "], sortRecordTime[" + + sortRecordTime + "], writeTime[" + writeTime + "], serializeTime[" diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java index ac6ac9e271..2fe21b98f9 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java @@ -17,23 +17,34 @@ package org.apache.spark.shuffle.writer; +import java.util.ArrayList; +import java.util.Comparator; import java.util.List; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.UnsafeInput; import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.serializer.kryo.KryoSerializerInstance; + public class WriterBuffer { private static final Logger LOG = LoggerFactory.getLogger(WriterBuffer.class); - private long copyTime = 0; - private byte[] buffer; - private int bufferSize; - private int nextOffset = 0; - private List buffers = Lists.newArrayList(); - private int dataLength = 0; - private int memoryUsed = 0; - private long recordCount = 0; + protected long copyTime = 0; + protected byte[] buffer; + protected int bufferSize; + protected int nextOffset = 0; + protected List buffers = Lists.newArrayList(); + protected int dataLength = 0; + protected int memoryUsed = 0; + protected long recordCount = 0; + + private long sortTime = 0; + private final List records = new ArrayList(); public WriterBuffer(int bufferSize) { this.bufferSize = bufferSize; @@ -70,6 +81,13 @@ public void addRecord(byte[] recordBuffer, int length) { recordCount++; } + public void addRecord(byte[] recordBuffer, int keyLength, int valueLength) { + this.addRecord(recordBuffer, keyLength + valueLength); + this.records.add( + new Record( + this.buffers.size(), nextOffset - keyLength - valueLength, keyLength, valueLength)); + } + public boolean askForMemory(long length) { return buffer == null || nextOffset + length > bufferSize; } @@ -88,6 +106,51 @@ public byte[] getData() { return data; } + public byte[] 
getData(KryoSerializerInstance instance, Comparator comparator) { + if (comparator != null) { + // deserialized key + long start = System.currentTimeMillis(); + Kryo derKryo = null; + try { + derKryo = instance.borrowKryo(); + Input input = new UnsafeInput(); + for (Record record : records) { + byte[] bytes = + record.index == this.buffers.size() ? buffer : this.buffers.get(record.index).buffer; + input.setBuffer(bytes, record.offset, record.keyLength); + record.key = derKryo.readClassAndObject(input); + } + } catch (Throwable e) { + throw new RssException(e); + } finally { + instance.releaseKryo(derKryo); + } + + // sort by key + this.records.sort( + new Comparator() { + @Override + public int compare(Record r1, Record r2) { + return comparator.compare(r1.key, r2.key); + } + }); + sortTime += System.currentTimeMillis() - start; + } + + // write + long start = System.currentTimeMillis(); + byte[] data = new byte[dataLength]; + int offset = 0; + for (Record record : records) { + byte[] bytes = + record.index == buffers.size() ? buffer : buffers.get(record.index).getBuffer(); + System.arraycopy(bytes, record.offset, data, offset, record.keyLength + record.valueLength); + offset += record.keyLength + record.valueLength; + } + copyTime += System.currentTimeMillis() - start; + return data; + } + public int getDataLength() { return dataLength; } @@ -104,6 +167,10 @@ public long getRecordCount() { return recordCount; } + public long getSortTime() { + return sortTime; + } + private static final class WrappedBuffer { byte[] buffer; @@ -122,4 +189,19 @@ public int getSize() { return size; } } + + private static final class Record { + private final int index; + private final int offset; + private final int keyLength; + private final int valueLength; + private Object key = null; + + Record(int keyIndex, int offset, int keyLength, int valueLength) { + this.index = keyIndex; + this.offset = offset; + this.keyLength = keyLength; + this.valueLength = valueLength; + } + } } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index d869c64fe6..d7e95ae2e1 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -17,11 +17,14 @@ package org.apache.uniffle.shuffle.manager; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.ObjectOutputStream; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; +import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -39,6 +42,8 @@ import java.util.stream.Collectors; import scala.Tuple2; +import scala.math.Ordering; +import scala.runtime.AbstractFunction1; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Maps; @@ -49,6 +54,7 @@ import org.apache.hadoop.security.UserGroupInformation; import org.apache.spark.MapOutputTracker; import org.apache.spark.MapOutputTrackerMaster; +import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.SparkException; @@ -100,6 +106,7 @@ import org.apache.uniffle.common.util.RetryUtils; import org.apache.uniffle.common.util.RssUtils; import 
org.apache.uniffle.common.util.ThreadUtils; +import org.apache.uniffle.proto.RssProtos.MergeContext; import org.apache.uniffle.shuffle.BlockIdManager; import static org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED; @@ -169,6 +176,9 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac protected GrpcServer shuffleManagerServer; protected DataPusher dataPusher; + protected boolean remoteMergeEnable; + protected Map dependencies = new HashMap(); + public RssShuffleManagerBase(SparkConf conf, boolean isDriver) { LOG.info( "Uniffle {} version: {}", this.getClass().getName(), Constants.VERSION_AND_REVISION_SHORT); @@ -327,6 +337,7 @@ public RssShuffleManagerBase(SparkConf conf, boolean isDriver) { rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM); this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); this.rssStageResubmitManager = new RssStageResubmitManager(); + this.remoteMergeEnable = sparkConf.get(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE); } @VisibleForTesting @@ -965,7 +976,8 @@ public boolean reassignOnStageResubmit( rssStageResubmitManager.getServerIdBlackList(), stageAttemptId, stageAttemptNumber, - false); + false, + buildMergeContext(dependencies.get(shuffleId))); MutableShuffleHandleInfo shuffleHandleInfo = new MutableShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo()); StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo = @@ -1082,7 +1094,11 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure( "Register the new partition->servers assignment on reassign. {}", newServerToPartitions); registerShuffleServers( - getAppId(), shuffleId, newServerToPartitions, getRemoteStorageInfo()); + getAppId(), + shuffleId, + newServerToPartitions, + getRemoteStorageInfo(), + buildMergeContext(dependencies.get(shuffleId))); } LOG.info( @@ -1248,7 +1264,8 @@ private Set reassignServerForTask( }, stageId, stageAttemptNumber, - reassign); + reassign, + buildMergeContext(dependencies.get(shuffleId))); return replacementsRef.get(); } @@ -1262,7 +1279,8 @@ private Map> requestShuffleAssignment( Function reassignmentHandler, int stageId, int stageAttemptNumber, - boolean reassign) { + boolean reassign, + MergeContext mergeContext) { Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); ClientUtils.validateClientType(clientType); assignmentTags.add(clientType); @@ -1290,7 +1308,11 @@ private Map> requestShuffleAssignment( response = reassignmentHandler.apply(response); } registerShuffleServers( - getAppId(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo()); + getAppId(), + shuffleId, + response.getServerToPartitionRanges(), + getRemoteStorageInfo(), + mergeContext); return response.getPartitionToServers(); } catch (Throwable throwable) { throw new RssException("registerShuffle failed!", throwable); @@ -1306,7 +1328,8 @@ protected Map> requestShuffleAssignment( Set faultyServerIds, int stageId, int stageAttemptNumber, - boolean reassign) { + boolean reassign, + MergeContext mergeContext) { Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); ClientUtils.validateClientType(clientType); assignmentTags.add(clientType); @@ -1339,7 +1362,8 @@ protected Map> requestShuffleAssignment( shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo(), - stageAttemptNumber); + stageAttemptNumber, + mergeContext); return response.getPartitionToServers(); }, retryInterval, @@ -1349,32 +1373,13 @@ protected Map> requestShuffleAssignment( } 
} - protected Map> requestShuffleAssignment( - int shuffleId, - int partitionNum, - int partitionNumPerRange, - int assignmentShuffleServerNumber, - int estimateTaskConcurrency, - Set faultyServerIds, - int stageAttemptNumber) { - return requestShuffleAssignment( - shuffleId, - partitionNum, - partitionNumPerRange, - assignmentShuffleServerNumber, - estimateTaskConcurrency, - faultyServerIds, - -1, - stageAttemptNumber, - false); - } - protected void registerShuffleServers( String appId, int shuffleId, Map> serverToPartitionRanges, RemoteStorageInfo remoteStorage, - int stageAttemptNumber) { + int stageAttemptNumber, + MergeContext mergeContext) { if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) { return; } @@ -1393,7 +1398,7 @@ protected void registerShuffleServers( ShuffleDataDistributionType.NORMAL, maxConcurrencyPerPartitionToWrite, stageAttemptNumber, - null, + mergeContext, sparkConfMap); }); LOG.info( @@ -1405,7 +1410,8 @@ protected void registerShuffleServers( String appId, int shuffleId, Map> serverToPartitionRanges, - RemoteStorageInfo remoteStorage) { + RemoteStorageInfo remoteStorage, + MergeContext mergeContext) { if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) { return; } @@ -1425,6 +1431,8 @@ protected void registerShuffleServers( remoteStorage, dataDistributionType, maxConcurrencyPerPartitionToWrite, + 0, + mergeContext, sparkConfMap); }); LOG.info( @@ -1538,6 +1546,48 @@ public boolean isValidTask(String taskId) { return !failedTaskIds.contains(taskId); } + protected String encode(Ordering obj) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(baos); + oos.writeObject(obj); + oos.close(); + return "#" + Base64.getEncoder().encodeToString(baos.toByteArray()); + } catch (Exception e) { + throw new RssException(e); + } + } + + protected MergeContext buildMergeContext(ShuffleDependency dependency) { + if (remoteMergeEnable) { + MergeContext mergeContext = + MergeContext.newBuilder() + .setKeyClass(dependency.keyClassName()) + .setValueClass( + dependency.mapSideCombine() + ? dependency.combinerClassName().get() + : dependency.valueClassName()) + .setComparatorClass( + dependency + .keyOrdering() + .map( + new AbstractFunction1, String>() { + @Override + public String apply(Ordering o) { + return encode(o); + } + }) + .get()) + .setMergedBlockSize(sparkConf.get(RssSparkConfig.RSS_MERGED_BLOCK_SZIE)) + .setMergeClassLoader( + sparkConf.get(RssSparkConfig.RSS_REMOTE_MERGE_CLASS_LOADER.key(), "")) + .build(); + return mergeContext; + } else { + return null; + } + } + @VisibleForTesting public void setDataPusher(DataPusher dataPusher) { this.dataPusher = dataPusher; diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/SparkCombinerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/SparkCombinerTest.java new file mode 100644 index 0000000000..24f739ce48 --- /dev/null +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/SparkCombinerTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
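The encode() helper above Java-serializes the key Ordering and Base64-encodes it behind a leading '#', so buildMergeContext can carry the comparator as a string in MergeContext. A round-trip sketch of that encoding; the decode side is illustrative only and not taken from the patch:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Base64;

public class OrderingCodecSketch {
  // Same shape as RssShuffleManagerBase#encode, minus the Ordering type.
  static String encode(Serializable obj) throws IOException {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
      oos.writeObject(obj);
    }
    return "#" + Base64.getEncoder().encodeToString(baos.toByteArray());
  }

  // Hypothetical inverse, showing how a consumer could restore the serialized object.
  static Object decode(String encoded) throws IOException, ClassNotFoundException {
    byte[] bytes = Base64.getDecoder().decode(encoded.substring(1)); // strip the leading '#'
    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
      return ois.readObject();
    }
  }

  public static void main(String[] args) throws Exception {
    System.out.println(decode(encode("stand-in for an Ordering"))); // stand-in for an Ordering
  }
}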
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle; + +import java.util.Iterator; + +import scala.runtime.AbstractFunction1; +import scala.runtime.AbstractFunction2; + +import org.apache.spark.Aggregator; +import org.junit.jupiter.api.Test; + +import org.apache.uniffle.client.record.Record; +import org.apache.uniffle.client.record.RecordBlob; +import org.apache.uniffle.client.record.RecordBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SparkCombinerTest { + + @Test + public void testSparkCombiner() { + RecordBuffer recordBuffer = new RecordBuffer(-1); + for (int i = 0; i < 5; i++) { + for (int j = 0; j <= i; j++) { + String key = "key" + i; + recordBuffer.addRecord(key, j); + } + } + SparkCombiner combiner = + new SparkCombiner<>( + new Aggregator<>( + new AbstractFunction1() { + @Override + public String apply(Integer v1) { + return v1.toString(); + } + }, + new AbstractFunction2() { + @Override + public String apply(String c, Integer v) { + return Integer.valueOf(Integer.parseInt(c) + v).toString(); + } + }, + new AbstractFunction2() { + @Override + public String apply(String c1, String c2) { + return Integer.valueOf(Integer.parseInt(c1) + Integer.parseInt(c2)).toString(); + } + })); + RecordBlob recordBlob = new RecordBlob(-1); + recordBlob.addRecords(recordBuffer); + recordBlob.combine(combiner, false); + Iterator> newRecords = recordBlob.getResult().iterator(); + int index = 0; + while (newRecords.hasNext()) { + Record record = newRecords.next(); + int expectedValue = 0; + for (int i = 0; i <= index; i++) { + expectedValue += i; + } + assertEquals("Record{key=key" + index + ", value=" + expectedValue + "}", record.toString()); + index++; + } + assertEquals(5, index); + } +} diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java index 97639d3c44..8056278ac8 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java @@ -20,7 +20,10 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; @@ -28,6 +31,8 @@ import java.util.stream.Stream; import com.google.common.collect.Maps; +import io.netty.buffer.ByteBuf; +import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.reflect.FieldUtils; import org.apache.spark.SparkConf; import org.apache.spark.executor.ShuffleWriteMetrics; @@ -47,13 +52,19 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.common.ShuffleBlockInfo; +import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.compression.Codec; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; +import 
org.apache.uniffle.common.records.RecordsReader; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.util.BlockIdLayout; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; @@ -62,6 +73,8 @@ public class WriteBufferManagerTest { + private static final int RECORD_NUM = 1009; + private WriteBufferManager createManager(SparkConf conf) { Serializer kryoSerializer = new KryoSerializer(conf); TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class); @@ -622,4 +635,68 @@ public void addFirstRecordWithLargeSizeTest() { List shuffleBlockInfos2 = wbm.addRecord(1, testKey, testValue2); assertEquals(0, shuffleBlockInfos2.size()); } + + @Test + public void testWriteRemoteMerge() throws Exception { + final SparkConf conf = new SparkConf(); + final RssConf rssConf = RssSparkConfig.toRssConf(conf); + BufferManagerOptions bufferOptions = new BufferManagerOptions(conf); + TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class); + doReturn(16L * 1024L * 1024L) + .when(mockTaskMemoryManager) + .acquireExecutionMemory(anyLong(), any()); + + Map> partitionToServers = new HashMap<>(); + WriteBufferManager manager = + new RMWriteBufferManager( + 0, + "", + 0, + bufferOptions, + new KryoSerializer(conf), + mockTaskMemoryManager, + new ShuffleWriteMetrics(), + RssSparkConfig.toRssConf(conf), + null, + pid -> partitionToServers.get(pid), + 0, + SerializerUtils.getComparator(String.class)); + List shuffleBlockInfos; + List indexes = new ArrayList<>(); + for (int i = 0; i < RECORD_NUM; i++) { + indexes.add(i); + } + Collections.shuffle(indexes); + for (Integer index : indexes) { + shuffleBlockInfos = + manager.addRecord( + 0, + SerializerUtils.genData(String.class, index), + SerializerUtils.genData(Integer.class, index)); + assertTrue(CollectionUtils.isEmpty(shuffleBlockInfos)); + } + + shuffleBlockInfos = manager.clear(); + assertFalse(CollectionUtils.isEmpty(shuffleBlockInfos)); + + // check blocks + List events = manager.buildBlockEvents(shuffleBlockInfos); + assertEquals(1, events.size()); + List blocks = events.get(0).getShuffleDataInfoList(); + assertEquals(1, blocks.size()); + + ByteBuf buf = blocks.get(0).getData(); + RecordsReader reader = + new RecordsReader<>( + rssConf, SerInputStream.newInputStream(buf), String.class, Integer.class, false, false); + reader.init(); + int index = 0; + while (reader.next()) { + assertEquals(SerializerUtils.genData(String.class, index), reader.getCurrentKey()); + assertEquals(SerializerUtils.genData(Integer.class, index), reader.getCurrentValue()); + index++; + } + reader.close(); + assertEquals(RECORD_NUM, index); + } } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java index d4533efaf6..9fd797fc1e 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java @@ -17,18 +17,38 @@ package org.apache.spark.shuffle.writer; +import java.io.IOException; +import 
java.util.ArrayList; +import java.util.Collections; +import java.util.List; + import scala.reflect.ClassTag$; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Output; +import com.esotericsoftware.kryo.io.UnsafeOutput; +import io.netty.buffer.Unpooled; import org.apache.spark.SparkConf; import org.apache.spark.serializer.KryoSerializer; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.Serializer; import org.junit.jupiter.api.Test; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.serializer.DeserializationStream; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.serializer.SerializerUtils; +import org.apache.uniffle.common.serializer.kryo.KryoSerializerInstance; + +import static org.apache.uniffle.common.serializer.SerializerUtils.genData; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; public class WriteBufferTest { + private static final int RECORDS_NUM = 200; + private SparkConf conf = new SparkConf(false); private Serializer kryoSerializer = new KryoSerializer(conf); private WrappedByteArrayOutputStream arrayOutputStream = new WrappedByteArrayOutputStream(32); @@ -37,6 +57,10 @@ public class WriteBufferTest { private byte[] serializedData; private int serializedDataLength; + private int keyLength; + private int valueLength; + private byte[] sortSerializedData; + @Test public void test() { WriterBuffer wb = new WriterBuffer(32); @@ -83,6 +107,50 @@ public void test() { assertEquals(91, wb.getData().length); } + @Test + public void testSortRecords() throws IOException { + arrayOutputStream.reset(); + KryoSerializerInstance instance = new KryoSerializerInstance(new RssConf()); + Kryo serKryo = instance.borrowKryo(); + Output serOutput = new UnsafeOutput(this.arrayOutputStream); + + WriterBuffer wb = new WriterBuffer(1024); + assertEquals(0, wb.getMemoryUsed()); + assertEquals(0, wb.getDataLength()); + + List arrays = new ArrayList<>(); + for (int i = 0; i < RECORDS_NUM; i++) { + arrays.add(i); + } + Collections.shuffle(arrays); + for (int i : arrays) { + String key = (String) SerializerUtils.genData(String.class, i); + int value = (int) SerializerUtils.genData(int.class, i); + serializeData(key, value, serKryo, serOutput); + wb.addRecord(sortSerializedData, keyLength, valueLength); + } + assertEquals(RECORDS_NUM, wb.getRecordCount()); + assertEquals(15 * RECORDS_NUM, wb.getDataLength()); + + byte[] data = wb.getData(instance, SerializerUtils.getComparator(String.class)); + assertEquals(15 * RECORDS_NUM, data.length); + // deserialized + DeserializationStream dStream = + instance.deserializeStream( + SerInputStream.newInputStream(Unpooled.wrappedBuffer(data)), + String.class, + int.class, + false, + false); + dStream.init(); + for (int i = 0; i < RECORDS_NUM; i++) { + assertTrue(dStream.nextRecord()); + assertEquals(genData(String.class, i), dStream.getCurrentKey()); + assertEquals(i, dStream.getCurrentValue()); + } + assertFalse(dStream.nextRecord()); + } + private void serializeData(Object key, Object value) { arrayOutputStream.reset(); serializeStream.writeKey(key, ClassTag$.MODULE$.apply(key.getClass())); @@ -91,4 +159,15 @@ private void serializeData(Object key, Object value) { serializedData = arrayOutputStream.getBuf(); serializedDataLength = arrayOutputStream.size(); } + + private void 
serializeData(Object key, Object value, Kryo serKryo, Output serOutput) { + arrayOutputStream.reset(); + serKryo.writeClassAndObject(serOutput, key); + serOutput.flush(); + keyLength = arrayOutputStream.size(); + serKryo.writeClassAndObject(serOutput, value); + serOutput.flush(); + valueLength = arrayOutputStream.size() - keyLength; + sortSerializedData = arrayOutputStream.getBuf(); + } } diff --git a/client-spark/common/src/test/resources/log4j2.xml b/client-spark/common/src/test/resources/log4j2.xml new file mode 100644 index 0000000000..8db107a47f --- /dev/null +++ b/client-spark/common/src/test/resources/log4j2.xml @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 8e6b8dfca3..a669dd74c0 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -154,7 +154,10 @@ public ShuffleHandle registerShuffle( requiredShuffleServerNumber, estimateTaskConcurrency, rssStageResubmitManager.getServerIdBlackList(), - 0); + -1, + 0, + false, + null); startHeartbeat(); @@ -379,17 +382,6 @@ private Roaring64NavigableMap getShuffleResult( } } - private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) { - Set faultyServerIds = Sets.newHashSet(faultyShuffleServerId); - faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList()); - Map> partitionToServers = - requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds, 0); - if (partitionToServers.get(0) != null && partitionToServers.get(0).size() == 1) { - return partitionToServers.get(0).get(0); - } - return null; - } - @Override protected ShuffleWriteClient createShuffleWriteClient() { int unregisterThreadPoolSize = diff --git a/client-spark/spark3/pom.xml b/client-spark/spark3/pom.xml index 94a124d0bc..f29b68598b 100644 --- a/client-spark/spark3/pom.xml +++ b/client-spark/spark3/pom.xml @@ -98,5 +98,17 @@ ${spark.version} provided + + org.apache.uniffle + rss-common + test-jar + test + + + org.apache.uniffle + rss-client + test-jar + test + diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 5e2a941029..b8cf0769bb 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -45,6 +45,7 @@ import org.apache.spark.shuffle.handle.ShuffleHandleInfo; import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo; import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; +import org.apache.spark.shuffle.reader.RMRssShuffleReader; import org.apache.spark.shuffle.reader.RssShuffleReader; import org.apache.spark.shuffle.writer.DataPusher; import org.apache.spark.shuffle.writer.RssShuffleWriter; @@ -150,6 +151,7 @@ public ShuffleHandle registerShuffle( return new RssShuffleHandle<>( shuffleId, id.get(), dependency.rdd().getNumPartitions(), dependency, hdlInfoBd); } + dependencies.put(shuffleId, dependency); String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()); RemoteStorageInfo defaultRemoteStorage = getDefaultRemoteStorageInfo(sparkConf); @@ -173,7 +175,10 @@ public ShuffleHandle registerShuffle( 
requiredShuffleServerNumber, estimateTaskConcurrency, rssStageResubmitManager.getServerIdBlackList(), - 0); + -1, + 0, + false, + buildMergeContext(dependency)); startHeartbeat(); shuffleIdToPartitionNum.putIfAbsent(shuffleId, dependency.partitioner().numPartitions()); shuffleIdToNumMapTasks.putIfAbsent(shuffleId, dependency.rdd().partitions().length); @@ -396,6 +401,22 @@ public ShuffleReader getReaderImpl( Configuration readerHadoopConf = RssSparkShuffleUtils.getRemoteStorageHadoopConf(sparkConf, shuffleRemoteStorageInfo); + if (remoteMergeEnable) { + return new RMRssShuffleReader<>( + startPartition, + endPartition, + context, + rssShuffleHandle, + shuffleWriteClient, + shuffleHandleInfo.getAllPartitionServersForReader(), + RssUtils.generatePartitionToBitmap( + blockIdBitmap, startPartition, endPartition, blockIdLayout), + taskIdBitmap, + readMetrics, + managerClientSupplier, + RssSparkConfig.toRssConf(sparkConf), + clientType); + } return new RssShuffleReader( startPartition, endPartition, diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleReader.java new file mode 100644 index 0000000000..ae45b4e1a7 --- /dev/null +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RMRssShuffleReader.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.reader; + +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; + +import scala.Product2; +import scala.Tuple2; +import scala.collection.AbstractIterator; +import scala.collection.Iterator; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.commons.lang3.ClassUtils; +import org.apache.spark.Aggregator; +import org.apache.spark.InterruptibleIterator; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleReadMetrics; +import org.apache.spark.shuffle.RssShuffleHandle; +import org.apache.spark.shuffle.ShuffleReader; +import org.apache.spark.shuffle.SparkCombiner; +import org.apache.spark.util.CompletionIterator; +import org.apache.spark.util.CompletionIterator$; +import org.roaringbitmap.longlong.Roaring64NavigableMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.client.api.ShuffleManagerClient; +import org.apache.uniffle.client.api.ShuffleWriteClient; +import org.apache.uniffle.client.record.Record; +import org.apache.uniffle.client.record.reader.RMRecordsReader; +import org.apache.uniffle.client.record.writer.Combiner; +import org.apache.uniffle.client.util.DefaultIdHelper; +import org.apache.uniffle.client.util.RssClientConfig; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.config.RssClientConf; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.util.BlockIdLayout; + +public class RMRssShuffleReader implements ShuffleReader { + private static final Logger LOG = LoggerFactory.getLogger(RMRssShuffleReader.class); + private final Map> partitionToShuffleServers; + + private final String appId; + private final int shuffleId; + private final int startPartition; + private final int endPartition; + private final TaskContext context; + private final ShuffleDependency shuffleDependency; + private final int numMaps; + private final String taskId; + private final ShuffleWriteClient shuffleWriteClient; + private final Map> partitionToServerInfos; + private final Map partitionToExpectBlocks; + private final Roaring64NavigableMap taskIdBitmap; + private final ShuffleReadMetrics readMetrics; + private final Supplier managerClientSupplier; + + private final RssConf rssConf; + private final String clientType; + + private final Class keyClass; + private final Class valueClass; + private final boolean isMapCombine; + private Comparator comparator = null; + private Combiner combiner = null; + private DefaultIdHelper idHelper; + + private Object metricsLock = new Object(); + + public RMRssShuffleReader( + int startPartition, + int endPartition, + TaskContext context, + RssShuffleHandle rssShuffleHandle, + ShuffleWriteClient shuffleWriteClient, + Map> partitionToServerInfos, + Map partitionToExpectBlocks, + Roaring64NavigableMap taskIdBitmap, + ShuffleReadMetrics readMetrics, + Supplier managerClientSupplier, + RssConf rssConf, + String clientType) { + this.appId = rssShuffleHandle.getAppId(); + this.startPartition = startPartition; + this.endPartition = endPartition; + this.context = context; + this.numMaps = rssShuffleHandle.getNumMaps(); + this.shuffleDependency = rssShuffleHandle.getDependency(); + this.shuffleId = shuffleDependency.shuffleId(); + this.taskId = "" + context.taskAttemptId() + "_" 
+ context.attemptNumber(); + this.shuffleWriteClient = shuffleWriteClient; + this.partitionToServerInfos = partitionToServerInfos; + this.partitionToExpectBlocks = partitionToExpectBlocks; + this.taskIdBitmap = taskIdBitmap; + this.readMetrics = readMetrics; + this.managerClientSupplier = managerClientSupplier; + this.partitionToShuffleServers = rssShuffleHandle.getPartitionToServers(); + this.rssConf = rssConf; + this.clientType = clientType; + this.idHelper = new DefaultIdHelper(BlockIdLayout.from(rssConf)); + try { + this.keyClass = ClassUtils.getClass(rssShuffleHandle.getDependency().keyClassName()); + this.isMapCombine = rssShuffleHandle.getDependency().mapSideCombine(); + this.valueClass = + isMapCombine + ? ClassUtils.getClass( + rssShuffleHandle + .getDependency() + .combinerClassName() + .getOrElse( + () -> { + throw new RssException( + "Can not find combine class even though map combine is enabled!"); + })) + : ClassUtils.getClass(rssShuffleHandle.getDependency().valueClassName()); + comparator = rssShuffleHandle.getDependency().keyOrdering().getOrElse(() -> null); + Aggregator agg = rssShuffleHandle.getDependency().aggregator().getOrElse(() -> null); + if (agg != null) { + combiner = new SparkCombiner(agg); + } + } catch (ClassNotFoundException e) { + throw new RssException(e); + } + } + + private void reportUniqueBlocks(Set partitionIds) { + for (int partitionId : partitionIds) { + Roaring64NavigableMap blockIdBitmap = partitionToExpectBlocks.get(partitionId); + Roaring64NavigableMap uniqueBlockIdBitMap = Roaring64NavigableMap.bitmapOf(); + blockIdBitmap.forEach( + blockId -> { + long taId = idHelper.getTaskAttemptId(blockId); + if (taskIdBitmap.contains(taId)) { + uniqueBlockIdBitMap.add(blockId); + } + }); + shuffleWriteClient.startSortMerge( + new HashSet<>(partitionToServerInfos.get(partitionId)), + appId, + shuffleId, + partitionId, + uniqueBlockIdBitMap); + } + } + + @Override + public Iterator> read() { + LOG.info("Shuffle read started:" + getReadInfo()); + + Iterator> resultIter = new MultiPartitionIterator(); + resultIter = new InterruptibleIterator>(context, resultIter); + + // resubmit stage and shuffle manager server port are both set + if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false) + && rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0) { + resultIter = + RssFetchFailedIterator.newBuilder() + .appId(appId) + .shuffleId(shuffleId) + .partitionId(startPartition) + .stageAttemptId(context.stageAttemptNumber()) + .managerClientSupplier(managerClientSupplier) + .build(resultIter); + } + return resultIter; + } + + private String getReadInfo() { + return "appId=" + + appId + + ", shuffleId=" + + shuffleId + + ",taskId=" + + taskId + + ", partitions: [" + + startPartition + + ", " + + endPartition + + ")"; + } + + class MultiPartitionIterator extends AbstractIterator> { + + CompletionIterator, RMRssShuffleDataIterator> dataIterator; + + MultiPartitionIterator() { + if (numMaps <= 0) { + return; + } + Set partitionIds = new HashSet<>(); + for (int partition = startPartition; partition < endPartition; partition++) { + if (partitionToExpectBlocks.get(partition).isEmpty()) { + LOG.info("{} partition is empty partition", partition); + continue; + } + partitionIds.add(partition); + } + if (partitionIds.size() == 0) { + return; + } + // report unique blockIds + reportUniqueBlocks(partitionIds); + RMRecordsReader reader = createRMRecordsReader(partitionIds, partitionToShuffleServers); + reader.start(); + RMRssShuffleDataIterator iter = new 
RMRssShuffleDataIterator<>(reader); + this.dataIterator = + CompletionIterator$.MODULE$.apply( + iter, + () -> { + context.taskMetrics().mergeShuffleReadMetrics(); + return iter.cleanup(); + }); + context.addTaskCompletionListener( + (taskContext) -> { + if (dataIterator != null) { + dataIterator.completion(); + } + }); + } + + @Override + public boolean hasNext() { + if (dataIterator == null) { + return false; + } + return dataIterator.hasNext(); + } + + @Override + public Product2 next() { + Record record = dataIterator.next(); + Product2 result = new Tuple2(record.getKey(), record.getValue()); + return result; + } + } + + @VisibleForTesting + public RMRecordsReader createRMRecordsReader( + Set partitionIds, Map> serverInfoMap) { + return new RMRecordsReader( + appId, + shuffleId, + partitionIds, + serverInfoMap, + rssConf, + keyClass, + valueClass, + comparator, + false, + combiner, + isMapCombine, + inc -> { + // ShuffleReadMetrics is not thread-safe. Many Fetcher thread will update this value. + synchronized (metricsLock) { + readMetrics.incRecordsRead(inc); + } + }, + clientType); + } +} diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 09d88c1ca7..fd7dec6af7 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -88,6 +88,7 @@ import static org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED; import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES; +import static org.apache.spark.shuffle.RssSparkConfig.RSS_REMOTE_MERGE_ENABLE; import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED; public class RssShuffleWriter extends ShuffleWriter { @@ -261,20 +262,38 @@ public RssShuffleWriter( context); this.taskAttemptAssignment = new TaskAttemptAssignment(taskAttemptId, shuffleHandleInfo); BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf); - final WriteBufferManager bufferManager = - new WriteBufferManager( - shuffleId, - taskId, - taskAttemptId, - bufferOptions, - rssHandle.getDependency().serializer(), - context.taskMemoryManager(), - shuffleWriteMetrics, - RssSparkConfig.toRssConf(sparkConf), - this::processShuffleBlockInfos, - this::getPartitionAssignedServers, - context.stageAttemptNumber()); - this.bufferManager = bufferManager; + if (sparkConf.get(RSS_REMOTE_MERGE_ENABLE)) { + final WriteBufferManager bufferManager = + new RMWriteBufferManager( + shuffleId, + taskId, + taskAttemptId, + bufferOptions, + rssHandle.getDependency().serializer(), + context.taskMemoryManager(), + shuffleWriteMetrics, + RssSparkConfig.toRssConf(sparkConf), + this::processShuffleBlockInfos, + this::getPartitionAssignedServers, + context.stageAttemptNumber(), + shuffleDependency.keyOrdering().getOrElse(() -> null)); + this.bufferManager = bufferManager; + } else { + final WriteBufferManager bufferManager = + new WriteBufferManager( + shuffleId, + taskId, + taskAttemptId, + bufferOptions, + rssHandle.getDependency().serializer(), + context.taskMemoryManager(), + shuffleWriteMetrics, + RssSparkConfig.toRssConf(sparkConf), + this::processShuffleBlockInfos, + this::getPartitionAssignedServers, + context.stageAttemptNumber()); + this.bufferManager = bufferManager; + } } @VisibleForTesting diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RMRssShuffleReaderTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RMRssShuffleReaderTest.java new file mode 100644 index 0000000000..887eb3a5fd --- /dev/null +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RMRssShuffleReaderTest.java @@ -0,0 +1,536 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.reader; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +import scala.Option; +import scala.Product2; +import scala.collection.Iterator; +import scala.math.Ordering; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import io.netty.buffer.ByteBuf; +import org.apache.hadoop.io.IntWritable; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleReadMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.shuffle.RssShuffleHandle; +import org.junit.jupiter.api.Test; +import org.roaringbitmap.longlong.Roaring64NavigableMap; + +import org.apache.uniffle.client.api.ShuffleServerClient; +import org.apache.uniffle.client.api.ShuffleWriteClient; +import org.apache.uniffle.client.record.reader.MockedShuffleServerClient; +import org.apache.uniffle.client.record.reader.RMRecordsReader; +import org.apache.uniffle.client.record.writer.Combiner; +import org.apache.uniffle.client.record.writer.SumByKeyCombiner; +import org.apache.uniffle.common.ClientType; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.merger.Merger; +import org.apache.uniffle.common.merger.Segment; +import org.apache.uniffle.common.serializer.DynBufferSerOutputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; +import org.apache.uniffle.common.serializer.SerializerUtils; +import org.apache.uniffle.common.util.BlockIdLayout; + +import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anySet; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class RMRssShuffleReaderTest { + + private static final int RECORDS_NUM = 1009; + private static final String APP_ID = "app1"; + private static final int 
SHUFFLE_ID = 0; + private static final int PARTITION_ID = 0; + + @Test + public void testReadShuffleWithoutCombine() throws Exception { + // 1 basic parameter + final Class keyClass = String.class; + final Class valueClass = Integer.class; + final Comparator comparator = Ordering.String$.MODULE$; + final List serverInfos = + Lists.newArrayList(new ShuffleServerInfo("dummy", -1)); + final RssConf rssConf = new RssConf(); + final int taskAttemptId = 0; + BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf); + final long[] blockIds = new long[] {blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId)}; + final Combiner combiner = null; + + // 2 mock TaskContext, ShuffleDependency, RssShuffleHandle + // 2.1 mock TaskContext + TaskContext context = mock(TaskContext.class); + when(context.attemptNumber()).thenReturn(1); + when(context.taskAttemptId()).thenReturn(1L); + when(context.taskMetrics()).thenReturn(new TaskMetrics()); + doNothing().when(context).killTaskIfInterrupted(); + // 2.2 mock ShuffleDependency + ShuffleDependency dependency = mock(ShuffleDependency.class); + when(dependency.mapSideCombine()).thenReturn(false); + when(dependency.aggregator()).thenReturn(Option.empty()); + when(dependency.keyClassName()).thenReturn(keyClass.getName()); + when(dependency.valueClassName()).thenReturn(valueClass.getName()); + when(dependency.keyOrdering()).thenReturn(Option.empty()); + // 2.3 mock RssShuffleHandler + RssShuffleHandle handle = mock(RssShuffleHandle.class); + when(handle.getAppId()).thenReturn(APP_ID); + when(handle.getDependency()).thenReturn(dependency); + when(handle.getShuffleId()).thenReturn(SHUFFLE_ID); + when(handle.getNumMaps()).thenReturn(1); + when(handle.getPartitionToServers()).thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos)); + + // 3 construct remote merge records reader + ShuffleReadMetrics readMetrics = new ShuffleReadMetrics(); + RMRecordsReader recordsReader = + new RMRecordsReader( + APP_ID, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID), + ImmutableMap.of(PARTITION_ID, serverInfos), + rssConf, + keyClass, + valueClass, + comparator, + false, + combiner, + false, + inc -> readMetrics.incRecordsRead(inc)); + final RMRecordsReader recordsReaderSpy = spy(recordsReader); + ByteBuf byteBuf = genSortedRecordBuffer(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1); + ByteBuf[][] buffers = new ByteBuf[][] {{byteBuf}}; + ShuffleServerClient serverClient = + new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds); + doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any()); + + // 4 construct spark shuffle reader + ShuffleWriteClient writeClient = mock(ShuffleWriteClient.class); + RMRssShuffleReader shuffleReader = + new RMRssShuffleReader( + PARTITION_ID, + PARTITION_ID + 1, + context, + handle, + writeClient, + ImmutableMap.of(PARTITION_ID, serverInfos), + ImmutableMap.of(PARTITION_ID, Roaring64NavigableMap.bitmapOf(blockIds)), + Roaring64NavigableMap.bitmapOf(taskAttemptId), + readMetrics, + null, + rssConf, + ClientType.GRPC.name()); + RMRssShuffleReader shuffleReaderSpy = spy(shuffleReader); + doReturn(recordsReaderSpy).when(shuffleReaderSpy).createRMRecordsReader(anySet(), anyMap()); + + // 5 read and verify result + Iterator> iterator = shuffleReaderSpy.read(); + int index = 0; + while (iterator.hasNext()) { + Product2 record = iterator.next(); + assertEquals(SerializerUtils.genData(keyClass, index), record._1()); + assertEquals(SerializerUtils.genData(valueClass, index), record._2()); + index++; + } + 
assertEquals(RECORDS_NUM, index); + assertEquals(RECORDS_NUM, readMetrics._recordsRead().value()); + byteBuf.release(); + } + + @Test + public void testReadShuffleWithCombine() throws Exception { + // 1 basic parameter + final Class keyClass = String.class; + final Class valueClass = Integer.class; + final Comparator comparator = Ordering.String$.MODULE$; + final List serverInfos = + Lists.newArrayList(new ShuffleServerInfo("dummy", -1)); + final RssConf rssConf = new RssConf(); + final int taskAttemptId = 0; + BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf); + final long[] blockIds = new long[] {blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId)}; + final Combiner combiner = new SumByKeyCombiner(); + + // 2 mock TaskContext, ShuffleDependency, RssShuffleHandle + // 2.1 mock TaskContext + TaskContext context = mock(TaskContext.class); + when(context.attemptNumber()).thenReturn(1); + when(context.taskAttemptId()).thenReturn(1L); + when(context.taskMetrics()).thenReturn(new TaskMetrics()); + doNothing().when(context).killTaskIfInterrupted(); + // 2.2 mock ShuffleDependency + ShuffleDependency dependency = mock(ShuffleDependency.class); + when(dependency.mapSideCombine()).thenReturn(false); + when(dependency.aggregator()).thenReturn(Option.empty()); + when(dependency.keyClassName()).thenReturn(keyClass.getName()); + when(dependency.valueClassName()).thenReturn(valueClass.getName()); + when(dependency.keyOrdering()).thenReturn(Option.empty()); + // 2.3 mock RssShuffleHandler + RssShuffleHandle handle = mock(RssShuffleHandle.class); + when(handle.getAppId()).thenReturn(APP_ID); + when(handle.getDependency()).thenReturn(dependency); + when(handle.getShuffleId()).thenReturn(SHUFFLE_ID); + when(handle.getNumMaps()).thenReturn(1); + when(handle.getPartitionToServers()).thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos)); + + // 3 construct remote merge records reader + ShuffleReadMetrics readMetrics = new ShuffleReadMetrics(); + RMRecordsReader recordsReader = + new RMRecordsReader( + APP_ID, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID), + ImmutableMap.of(PARTITION_ID, serverInfos), + rssConf, + keyClass, + valueClass, + comparator, + false, + combiner, + false, + inc -> readMetrics.incRecordsRead(inc)); + final RMRecordsReader recordsReaderSpy = spy(recordsReader); + List segments = new ArrayList<>(); + segments.add( + SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM)); + segments.add( + SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM)); + segments.add( + SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM)); + segments.forEach(segment -> segment.init()); + SerOutputStream output = new DynBufferSerOutputStream(); + Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, false); + output.close(); + ByteBuf byteBuf = output.toByteBuf(); + ByteBuf[][] buffers = new ByteBuf[][] {{byteBuf}}; + ShuffleServerClient serverClient = + new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds); + doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any()); + + // 4 construct spark shuffle reader + ShuffleWriteClient writeClient = mock(ShuffleWriteClient.class); + RMRssShuffleReader shuffleReader = + new RMRssShuffleReader( + PARTITION_ID, + PARTITION_ID + 1, + context, + handle, + writeClient, + ImmutableMap.of(PARTITION_ID, serverInfos), + ImmutableMap.of(PARTITION_ID, Roaring64NavigableMap.bitmapOf(blockIds)), + 
Roaring64NavigableMap.bitmapOf(taskAttemptId), + readMetrics, + null, + rssConf, + ClientType.GRPC.name()); + RMRssShuffleReader shuffleReaderSpy = spy(shuffleReader); + doReturn(recordsReaderSpy).when(shuffleReaderSpy).createRMRecordsReader(anySet(), anyMap()); + + // 5 read and verify result + Iterator> iterator = shuffleReaderSpy.read(); + int index = 0; + while (iterator.hasNext()) { + Product2 record = iterator.next(); + assertEquals(SerializerUtils.genData(keyClass, index), record._1()); + Object value = SerializerUtils.genData(valueClass, index); + Object newValue = value; + if (index % 2 == 0) { + if (value instanceof IntWritable) { + newValue = new IntWritable(((IntWritable) value).get() * 2); + } else { + newValue = (int) value * 2; + } + } + assertEquals(newValue, record._2()); + index++; + } + assertEquals(RECORDS_NUM * 2, index); + assertEquals(RECORDS_NUM * 3, readMetrics._recordsRead().value()); + byteBuf.release(); + } + + @Test + public void testReadMulitPartitionShuffleWithoutCombine() throws Exception { + // 1 basic parameter + final Class keyClass = String.class; + final Class valueClass = Integer.class; + final Comparator comparator = Ordering.String$.MODULE$; + final List serverInfos = + Lists.newArrayList(new ShuffleServerInfo("dummy", -1)); + final RssConf rssConf = new RssConf(); + final int taskAttemptId = 0; + BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf); + final long[] blockIds = + new long[] { + blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId), + blockIdLayout.getBlockId(1, PARTITION_ID, taskAttemptId), + blockIdLayout.getBlockId(0, PARTITION_ID + 1, taskAttemptId), + blockIdLayout.getBlockId(1, PARTITION_ID + 1, taskAttemptId), + blockIdLayout.getBlockId(0, PARTITION_ID + 2, taskAttemptId), + blockIdLayout.getBlockId(1, PARTITION_ID + 2, taskAttemptId) + }; + final Combiner combiner = null; + + // 2 mock TaskContext, ShuffleDependency, RssShuffleHandle + // 2.1 mock TaskContext + TaskContext context = mock(TaskContext.class); + when(context.attemptNumber()).thenReturn(1); + when(context.taskAttemptId()).thenReturn(1L); + when(context.taskMetrics()).thenReturn(new TaskMetrics()); + doNothing().when(context).killTaskIfInterrupted(); + // 2.2 mock ShuffleDependency + ShuffleDependency dependency = mock(ShuffleDependency.class); + when(dependency.mapSideCombine()).thenReturn(false); + when(dependency.aggregator()).thenReturn(Option.empty()); + when(dependency.keyClassName()).thenReturn(keyClass.getName()); + when(dependency.valueClassName()).thenReturn(valueClass.getName()); + when(dependency.keyOrdering()).thenReturn(Option.empty()); + // 2.3 mock RssShuffleHandler + RssShuffleHandle handle = mock(RssShuffleHandle.class); + when(handle.getAppId()).thenReturn(APP_ID); + when(handle.getDependency()).thenReturn(dependency); + when(handle.getShuffleId()).thenReturn(SHUFFLE_ID); + when(handle.getNumMaps()).thenReturn(1); + when(handle.getPartitionToServers()).thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos)); + + // 3 construct remote merge records reader + ShuffleReadMetrics readMetrics = new ShuffleReadMetrics(); + RMRecordsReader recordsReader = + new RMRecordsReader( + APP_ID, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2), + ImmutableMap.of( + PARTITION_ID, + serverInfos, + PARTITION_ID + 1, + serverInfos, + PARTITION_ID + 2, + serverInfos), + rssConf, + keyClass, + valueClass, + comparator, + false, + combiner, + false, + inc -> { + synchronized (this) { + readMetrics.incRecordsRead(inc); + } + }); + 
final RMRecordsReader recordsReaderSpy = spy(recordsReader); + ByteBuf[][] buffers = new ByteBuf[3][2]; + for (int i = 0; i < 3; i++) { + buffers[i][0] = genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1); + buffers[i][1] = + genSortedRecordBuffer( + rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1); + } + ShuffleServerClient serverClient = + new MockedShuffleServerClient( + new int[] {PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2}, buffers, blockIds); + doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any()); + + // 4 construct spark shuffle reader + ShuffleWriteClient writeClient = mock(ShuffleWriteClient.class); + RMRssShuffleReader shuffleReader = + new RMRssShuffleReader( + PARTITION_ID, + PARTITION_ID + 3, + context, + handle, + writeClient, + ImmutableMap.of( + PARTITION_ID, + serverInfos, + PARTITION_ID + 1, + serverInfos, + PARTITION_ID + 2, + serverInfos), + ImmutableMap.of( + PARTITION_ID, + Roaring64NavigableMap.bitmapOf(blockIds[0], blockIds[1]), + PARTITION_ID + 1, + Roaring64NavigableMap.bitmapOf(blockIds[2], blockIds[3]), + PARTITION_ID + 2, + Roaring64NavigableMap.bitmapOf(blockIds[4], blockIds[5])), + Roaring64NavigableMap.bitmapOf(taskAttemptId), + readMetrics, + null, + rssConf, + ClientType.GRPC.name()); + RMRssShuffleReader shuffleReaderSpy = spy(shuffleReader); + doReturn(recordsReaderSpy).when(shuffleReaderSpy).createRMRecordsReader(anySet(), anyMap()); + + // 5 read and verify result + Iterator> iterator = shuffleReaderSpy.read(); + int index = 0; + while (iterator.hasNext()) { + Product2 record = iterator.next(); + assertEquals(SerializerUtils.genData(keyClass, index), record._1()); + assertEquals(SerializerUtils.genData(valueClass, index), record._2()); + index++; + } + assertEquals(RECORDS_NUM * 6, index); + assertEquals(RECORDS_NUM * 6, readMetrics._recordsRead().value()); + Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release())); + } + + @Test + public void testReadMulitPartitionShuffleWithCombine() throws Exception { + // 1 basic parameter + final Class keyClass = String.class; + final Class valueClass = Integer.class; + final Comparator comparator = Ordering.String$.MODULE$; + final List serverInfos = + Lists.newArrayList(new ShuffleServerInfo("dummy", -1)); + final RssConf rssConf = new RssConf(); + final int taskAttemptId = 0; + BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf); + final long[] blockIds = + new long[] { + blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId), + blockIdLayout.getBlockId(1, PARTITION_ID, taskAttemptId), + blockIdLayout.getBlockId(0, PARTITION_ID + 1, taskAttemptId), + blockIdLayout.getBlockId(1, PARTITION_ID + 1, taskAttemptId), + blockIdLayout.getBlockId(0, PARTITION_ID + 2, taskAttemptId), + blockIdLayout.getBlockId(1, PARTITION_ID + 2, taskAttemptId) + }; + final Combiner combiner = new SumByKeyCombiner(); + + // 2 mock TaskContext, ShuffleDependency, RssShuffleHandle + // 2.1 mock TaskContext + TaskContext context = mock(TaskContext.class); + when(context.attemptNumber()).thenReturn(1); + when(context.taskAttemptId()).thenReturn(1L); + when(context.taskMetrics()).thenReturn(new TaskMetrics()); + doNothing().when(context).killTaskIfInterrupted(); + // 2.2 mock ShuffleDependency + ShuffleDependency dependency = mock(ShuffleDependency.class); + when(dependency.mapSideCombine()).thenReturn(false); + when(dependency.aggregator()).thenReturn(Option.empty()); + 
when(dependency.keyClassName()).thenReturn(keyClass.getName()); + when(dependency.valueClassName()).thenReturn(valueClass.getName()); + when(dependency.keyOrdering()).thenReturn(Option.empty()); + // 2.3 mock RssShuffleHandler + RssShuffleHandle handle = mock(RssShuffleHandle.class); + when(handle.getAppId()).thenReturn(APP_ID); + when(handle.getDependency()).thenReturn(dependency); + when(handle.getShuffleId()).thenReturn(SHUFFLE_ID); + when(handle.getNumMaps()).thenReturn(1); + when(handle.getPartitionToServers()).thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos)); + + // 3 construct remote merge records reader + ShuffleReadMetrics readMetrics = new ShuffleReadMetrics(); + RMRecordsReader recordsReader = + new RMRecordsReader( + APP_ID, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2), + ImmutableMap.of( + PARTITION_ID, + serverInfos, + PARTITION_ID + 1, + serverInfos, + PARTITION_ID + 2, + serverInfos), + rssConf, + keyClass, + valueClass, + comparator, + false, + combiner, + false, + inc -> { + synchronized (this) { + readMetrics.incRecordsRead(inc); + } + }); + final RMRecordsReader recordsReaderSpy = spy(recordsReader); + ByteBuf[][] buffers = new ByteBuf[3][2]; + for (int i = 0; i < 3; i++) { + buffers[i][0] = genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 2); + buffers[i][1] = + genSortedRecordBuffer( + rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 2); + } + ShuffleServerClient serverClient = + new MockedShuffleServerClient( + new int[] {PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2}, buffers, blockIds); + doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any()); + + // 4 construct spark shuffle reader + ShuffleWriteClient writeClient = mock(ShuffleWriteClient.class); + RMRssShuffleReader shuffleReader = + new RMRssShuffleReader( + PARTITION_ID, + PARTITION_ID + 3, + context, + handle, + writeClient, + ImmutableMap.of( + PARTITION_ID, + serverInfos, + PARTITION_ID + 1, + serverInfos, + PARTITION_ID + 2, + serverInfos), + ImmutableMap.of( + PARTITION_ID, + Roaring64NavigableMap.bitmapOf(blockIds[0], blockIds[1]), + PARTITION_ID + 1, + Roaring64NavigableMap.bitmapOf(blockIds[2], blockIds[3]), + PARTITION_ID + 2, + Roaring64NavigableMap.bitmapOf(blockIds[4], blockIds[5])), + Roaring64NavigableMap.bitmapOf(taskAttemptId), + readMetrics, + null, + rssConf, + ClientType.GRPC.name()); + RMRssShuffleReader shuffleReaderSpy = spy(shuffleReader); + doReturn(recordsReaderSpy).when(shuffleReaderSpy).createRMRecordsReader(anySet(), anyMap()); + + // 5 read and verify result + Iterator> iterator = shuffleReaderSpy.read(); + int index = 0; + while (iterator.hasNext()) { + Product2 record = iterator.next(); + assertEquals(SerializerUtils.genData(keyClass, index), record._1()); + assertEquals(SerializerUtils.genData(valueClass, index * 2), record._2()); + index++; + } + assertEquals(RECORDS_NUM * 6, index); + assertEquals(RECORDS_NUM * 12, readMetrics._recordsRead().value()); + Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release())); + } +} diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java index 33980c92f1..6084effb4e 100644 --- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java +++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java @@ -77,28 +77,6 @@ default void registerShuffle( 
Collections.emptyMap()); } - default void registerShuffle( - ShuffleServerInfo shuffleServerInfo, - String appId, - int shuffleId, - List partitionRanges, - RemoteStorageInfo remoteStorage, - ShuffleDataDistributionType dataDistributionType, - int maxConcurrencyPerPartitionToWrite, - Map properties) { - registerShuffle( - shuffleServerInfo, - appId, - shuffleId, - partitionRanges, - remoteStorage, - dataDistributionType, - maxConcurrencyPerPartitionToWrite, - 0, - null, - properties); - } - default void registerShuffle( ShuffleServerInfo shuffleServerInfo, String appId, diff --git a/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkOrderedWordCountTest.scala b/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkOrderedWordCountTest.scala new file mode 100644 index 0000000000..3bcc5cb89b --- /dev/null +++ b/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkOrderedWordCountTest.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.test + +import org.apache.spark.shuffle.RssSparkConfig +import org.apache.spark.sql.SparkSession +import org.apache.uniffle.common.rpc.ServerType +import org.apache.uniffle.coordinator.CoordinatorConf +import org.apache.uniffle.server.ShuffleServerConf +import org.apache.uniffle.server.buffer.ShuffleBufferType +import org.apache.uniffle.storage.util.StorageType +import org.apache.uniffle.test.IntegrationTestBase._ +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.{BeforeAll, Test} + +import java.io.{File, FileWriter, PrintWriter} +import java.util +import scala.collection.JavaConverters.mapAsJavaMap +import scala.util.Random + +object RMSparkOrderedWordCountTest { + + @BeforeAll + @throws[Exception] + def setupServers(): Unit = { + val coordinatorConf: CoordinatorConf = getCoordinatorConf + val dynamicConf: util.HashMap[String, String] = new util.HashMap[String, String]() + dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key, StorageType.MEMORY_LOCALFILE.name) + addDynamicConf(coordinatorConf, dynamicConf) + createCoordinatorServer(coordinatorConf) + val grpcShuffleServerConf = getShuffleServerConf(ServerType.GRPC) + val nettyShuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY) + grpcShuffleServerConf.setBoolean(ShuffleServerConf.SERVER_MERGE_ENABLE, true) + grpcShuffleServerConf.set(ShuffleServerConf.SERVER_SHUFFLE_BUFFER_TYPE, ShuffleBufferType.SKIP_LIST) + nettyShuffleServerConf.setBoolean(ShuffleServerConf.SERVER_MERGE_ENABLE, true) + nettyShuffleServerConf.set(ShuffleServerConf.SERVER_SHUFFLE_BUFFER_TYPE, ShuffleBufferType.SKIP_LIST) + createShuffleServer(grpcShuffleServerConf) + createShuffleServer(nettyShuffleServerConf) + startServers() + } +} + +class 
RMSparkOrderedWordCountTest extends SparkIntegrationTestBase { + + private[test] val inputPath: String = "word_count_input" + private[test] val wordTable: Array[String] = Array("apple", + "banana", "fruit", "tomato", "pineapple", "grape", "lemon", "orange", "peach", "mango") + + @Test + @throws[Exception] + def orderedWordCountTest(): Unit = { + run() + } + + @throws[Exception] + override def run(): Unit = { + val fileName = generateTextFile(100) + val sparkConf = createSparkConf + // lz4 conflict, so use snappy here + sparkConf.set("spark.io.compression.codec", "snappy") + // 1 Run spark with remote sort rss + // 1.1 GRPC + val sparkConfWithRemoteSortRss = sparkConf.clone + updateSparkConfWithRssGrpc(sparkConfWithRemoteSortRss) + updateSparkConfCustomer(sparkConfWithRemoteSortRss) + sparkConfWithRemoteSortRss.set(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE.key, "true") + val rssResult = runSparkApp(sparkConfWithRemoteSortRss, fileName) + // 1.2 GRPC_NETTY + val sparkConfWithRemoteSortRssNetty = sparkConf.clone + updateSparkConfWithRssGrpc(sparkConfWithRemoteSortRssNetty) + updateSparkConfCustomer(sparkConfWithRemoteSortRssNetty) + sparkConfWithRemoteSortRssNetty.set(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE.key, "true") + sparkConfWithRemoteSortRssNetty.set(RssSparkConfig.RSS_CLIENT_TYPE.key, "GRPC_NETTY") + val rssResultNetty = runSparkApp(sparkConfWithRemoteSortRssNetty, fileName) + + // 2 Run original spark + val sparkConfOriginal = sparkConf.clone + val originalResult = runSparkApp(sparkConfOriginal, fileName) + + // 3 verify + assertEquals(originalResult.size(), rssResult.size()) + assertEquals(originalResult.size(), rssResultNetty.size()) + import scala.collection.JavaConverters._ + for ((k, v) <- originalResult.asScala.toMap) { + assertEquals(v, rssResult.get(k)) + assertEquals(v, rssResultNetty.get(k)) + } + } + + @throws[Exception] + def generateTextFile(rows: Int): String = { + val file = new File(IntegrationTestBase.tempDir, "wordcount.txt") + file.createNewFile + file.deleteOnExit() + val r = Random + val writer = new PrintWriter(new FileWriter(file)) + try for (i <- 0 until rows) { + writer.println(wordTable(r.nextInt(wordTable.length))) + } + finally if (writer != null) writer.close() + file.getAbsolutePath + } + + override def runTest(spark: SparkSession, fileName: String): util.Map[String, Int] = { + val sc = spark.sparkContext + val rdd = sc.textFile(fileName) + val counts = rdd.flatMap(_.split(" ")). + map(w => (w, 1)). + reduceByKey(_ + _) + .sortBy(_._1) + mapAsJavaMap(counts.collectAsMap()) + } +} diff --git a/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkSQLTest.scala b/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkSQLTest.scala new file mode 100644 index 0000000000..396a1c50fe --- /dev/null +++ b/integration-test/spark3/src/test/scala/org/apache/uniffle/test/RMSparkSQLTest.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.test + +import com.google.common.collect.Maps +import org.apache.commons.lang3.StringUtils +import org.apache.spark.SparkConf +import org.apache.spark.shuffle.RssSparkConfig +import org.apache.spark.sql.{Dataset, Row, SparkSession} +//import org.apache.uniffle.common.config.RssBaseConf.RSS_KRYO_REGISTRATION_CLASSES; +import org.apache.uniffle.common.rpc.ServerType +import org.apache.uniffle.coordinator.CoordinatorConf +import org.apache.uniffle.server.ShuffleServerConf +import org.apache.uniffle.server.buffer.ShuffleBufferType +import org.apache.uniffle.storage.util.StorageType +import org.apache.uniffle.test.IntegrationTestBase._ +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.{BeforeAll, Test} + +import java.io.{File, FileWriter, PrintWriter} +import java.util +import java.util.Map + +object RMSparkSQLTest { + + @BeforeAll + @throws[Exception] + def setupServers(): Unit = { + val coordinatorConf: CoordinatorConf = getCoordinatorConf + val dynamicConf: util.HashMap[String, String] = new util.HashMap[String, String]() + dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key, StorageType.MEMORY_LOCALFILE.name) + // dynamicConf.put(RSS_KRYO_REGISTRATION_CLASSES.key(), + // "org.apache.spark.sql.execution.RowKey,org.apache.spark.sql.execution.IntKey") + addDynamicConf(coordinatorConf, dynamicConf) + createCoordinatorServer(coordinatorConf) + val grpcShuffleServerConf = getShuffleServerConf(ServerType.GRPC) + val nettyShuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY) + grpcShuffleServerConf.setBoolean(ShuffleServerConf.SERVER_MERGE_ENABLE, true) + grpcShuffleServerConf.set(ShuffleServerConf.SERVER_SHUFFLE_BUFFER_TYPE, ShuffleBufferType.SKIP_LIST) + // grpcShuffleServerConf.set(RSS_KRYO_REGISTRATION_CLASSES, + // "org.apache.spark.sql.execution.RowKey,org.apache.spark.sql.execution.IntKey") + nettyShuffleServerConf.setBoolean(ShuffleServerConf.SERVER_MERGE_ENABLE, true) + nettyShuffleServerConf.set(ShuffleServerConf.SERVER_SHUFFLE_BUFFER_TYPE, ShuffleBufferType.SKIP_LIST) + // nettyShuffleServerConf.set(RSS_KRYO_REGISTRATION_CLASSES, + // "org.apache.spark.sql.execution.RowKey,org.apache.spark.sql.execution.IntKey") + createShuffleServer(grpcShuffleServerConf) + createShuffleServer(nettyShuffleServerConf) + startServers() + } +} + +class RMSparkSQLTest extends SparkIntegrationTestBase { + + @Test + @throws[Exception] + def sparkSQLTest(): Unit = { + run() + } + + override def updateSparkConfCustomer(sparkConf: SparkConf): Unit = { + sparkConf.set("spark.sql.shuffle.partitions", "4") + } + + @throws[Exception] + override def run(): Unit = { + val fileName = generateTestFile + val sparkConf = createSparkConf + // lz4 conflict, so use snappy here + sparkConf.set("spark.io.compression.codec", "snappy") + // sparkConf.set("spark.sql.execution.sortedShuffle.enabled", "true") + // 1 Run spark with remote sort rss + // 1.1 GRPC + val sparkConfWithRemoteSortRss = sparkConf.clone + updateSparkConfWithRssGrpc(sparkConfWithRemoteSortRss) + updateSparkConfCustomer(sparkConfWithRemoteSortRss) + 
sparkConfWithRemoteSortRss.set(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE.key, "true") + val rssResult = runSparkApp(sparkConfWithRemoteSortRss, fileName) + // 1.2 GRPC_NETTY + val sparkConfWithRemoteSortRssNetty = sparkConf.clone + updateSparkConfWithRssGrpc(sparkConfWithRemoteSortRssNetty) + updateSparkConfCustomer(sparkConfWithRemoteSortRssNetty) + sparkConfWithRemoteSortRssNetty.set(RssSparkConfig.RSS_REMOTE_MERGE_ENABLE.key, "true") + sparkConfWithRemoteSortRssNetty.set(RssSparkConfig.RSS_CLIENT_TYPE.key, "GRPC_NETTY") + val rssResultNetty = runSparkApp(sparkConfWithRemoteSortRssNetty, fileName) + + // 2 Run original spark + val sparkConfOriginal = sparkConf.clone + val originalResult = runSparkApp(sparkConfOriginal, fileName) + + // 3 verify + assertEquals(originalResult.size(), rssResult.size()) + assertEquals(originalResult.size(), rssResultNetty.size()) + import scala.collection.JavaConverters._ + for ((k, v) <- originalResult.asScala.toMap) { + assertEquals(v, rssResult.get(k)) + assertEquals(v, rssResultNetty.get(k)) + } + } + + @throws[Exception] + override def generateTestFile: String = generateCsvFile + + @throws[Exception] + protected def generateCsvFile: String = { + val rows = 1000 + val file = new File(IntegrationTestBase.tempDir, "test.csv") + file.createNewFile + file.deleteOnExit() + try { + val writer = new PrintWriter(new FileWriter(file)) + try for (i <- 0 until rows) { + writer.println(generateRecord) + } + finally if (writer != null) writer.close() + } + file.getAbsolutePath + } + + private def generateRecord = { + val random = new java.util.Random + val ch = ('a' + random.nextInt(26)).toChar + val repeats = random.nextInt(10) + StringUtils.repeat(ch, repeats) + "," + random.nextInt(100) + } + + override def runTest(spark: SparkSession, fileName: String): util.Map[String, Long] = { + val df = spark.read.schema("name STRING, age INT").csv(fileName) + df.createOrReplaceTempView("people") + val queryResult: Dataset[Row] = + spark.sql("SELECT name, count(age) FROM people group by name order by name"); + val result:Map[String, Long] = Maps.newHashMap[String, Long] + queryResult.rdd.collect().foreach( + row => result.put(row.getString(0), row.getLong(1)) + ) + result + } +} \ No newline at end of file diff --git a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java index 027b63d9df..a65b13d798 100644 --- a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java +++ b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java @@ -17,15 +17,18 @@ package org.apache.uniffle.server.merge; +import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectStreamClass; import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; import java.net.URL; import java.net.URLClassLoader; import java.security.AccessController; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; +import java.util.Base64; import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -40,6 +43,7 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.common.ShuffleDataResult; +import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.merger.Segment; import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.serializer.SerOutputStream; @@ -148,6 +152,28 @@ 
public ClassLoader getClassLoader(String label) { return cachedClassLoader.getOrDefault(label, cachedClassLoader.get("")); } + private Object decode(String encodeString, ClassLoader classLoader) { + try { + byte[] bytes = Base64.getDecoder().decode(encodeString.substring(1)); + ByteArrayInputStream bais = new ByteArrayInputStream(bytes); + ObjectInputStream ois = + new ObjectInputStream(bais) { + @Override + protected Class resolveClass(ObjectStreamClass desc) + throws IOException, ClassNotFoundException { + try { + return Class.forName(desc.getName(), false, classLoader); + } catch (ClassNotFoundException e) { + return super.resolveClass(desc); + } + } + }; + return ois.readObject(); + } catch (Exception e) { + throw new RssException(e); + } + } + public StatusCode registerShuffle(String appId, int shuffleId, MergeContext mergeContext) { try { ClassLoader classLoader = getClassLoader(mergeContext.getMergeClassLoader()); @@ -155,11 +181,15 @@ public StatusCode registerShuffle(String appId, int shuffleId, MergeContext merg Class vClass = ClassUtils.getClass(classLoader, mergeContext.getValueClass()); Comparator comparator; if (StringUtils.isNotBlank(mergeContext.getComparatorClass())) { - Constructor constructor = - ClassUtils.getClass(classLoader, mergeContext.getComparatorClass()) - .getDeclaredConstructor(); - constructor.setAccessible(true); - comparator = (Comparator) constructor.newInstance(); + if (mergeContext.getComparatorClass().startsWith("#")) { + comparator = (Comparator) decode(mergeContext.getComparatorClass(), classLoader); + } else { + Constructor constructor = + ClassUtils.getClass(classLoader, mergeContext.getComparatorClass()) + .getDeclaredConstructor(); + constructor.setAccessible(true); + comparator = (Comparator) constructor.newInstance(); + } } else { comparator = defaultComparator; } @@ -179,12 +209,8 @@ public StatusCode registerShuffle(String appId, int shuffleId, MergeContext merg comparator, mergeContext.getMergedBlockSize(), classLoader)); - } catch (ClassNotFoundException - | InstantiationException - | IllegalAccessException - | NoSuchMethodException - | InvocationTargetException e) { - LOG.info("Cant register shuffle, caused by ", e); + } catch (Throwable e) { + LOG.info("Cannot register shuffle, caused by ", e); removeBuffer(appId, shuffleId); return StatusCode.INTERNAL_ERROR; }
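
Note on the comparator handling added above: registerShuffle now accepts either a comparator class name or an inline serialized comparator. Any comparatorClass value that starts with "#" is treated as a Base64-encoded, Java-serialized object and is deserialized by decode with the per-shuffle class loader; any other value keeps the original reflective instantiation by class name. The sketch below shows the client-side encoding that this decode path implies; the helper class and method names are illustrative only (not part of this patch) and assume the comparator is Serializable.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Base64;
import java.util.Comparator;

// Illustrative helper (not part of this patch): produce the "#"-prefixed,
// Base64-encoded payload that ShuffleMergeManager's decode path expects.
public final class ComparatorEncodingSketch {

  private ComparatorEncodingSketch() {}

  public static <T> String encodeComparator(Comparator<T> comparator) throws IOException {
    if (!(comparator instanceof Serializable)) {
      throw new IllegalArgumentException("Comparator must be Serializable to be shipped inline");
    }
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
      // Plain Java serialization; the server strips the leading '#', Base64-decodes
      // the rest, and reads the object back with the merge class loader.
      oos.writeObject(comparator);
    }
    return "#" + Base64.getEncoder().encodeToString(baos.toByteArray());
  }
}
```

Switching the catch clause from the previous list of checked reflection exceptions to Throwable means a malformed or unresolvable comparator payload degrades to StatusCode.INTERNAL_ERROR (after removeBuffer) instead of propagating out of registerShuffle.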