diff --git a/auron-build.sh b/auron-build.sh
index 805551290..3d4c31b5e 100755
--- a/auron-build.sh
+++ b/auron-build.sh
@@ -117,10 +117,10 @@ while [[ $# -gt 0 ]]; do
SPARK_VER="$2"
if [ "$SPARK_VER" = "3.0" ] || [ "$SPARK_VER" = "3.1" ] \
|| [ "$SPARK_VER" = "3.2" ] || [ "$SPARK_VER" = "3.3" ] \
- || [ "$SPARK_VER" = "3.4" ] || [ "$SPARK_VER" = "3.5" ]; then
+ || [ "$SPARK_VER" = "3.4" ] || [ "$SPARK_VER" = "3.5" ] || [ "$SPARK_VER" = "4.0" ]; then
echo "Building for Spark $SPARK_VER"
else
- echo "ERROR: Invalid Spark version: $SPARK_VER. The currently supported versions are: 3.0 / 3.1 / 3.2 / 3.3 / 3.4 / 3.5."
+ echo "ERROR: Invalid Spark version: $SPARK_VER. The currently supported versions are: 3.0 / 3.1 / 3.2 / 3.3 / 3.4 / 3.5 / 4.0."
exit 1
fi
shift 2
diff --git a/pom.xml b/pom.xml
index a717d7101..9ad0b4380 100644
--- a/pom.xml
+++ b/pom.xml
@@ -621,6 +621,47 @@
+    <profile>
+      <id>spark-4.0</id>
+      <properties>
+        <!-- NOTE: the property names below are assumed; only the values are verbatim from this patch -->
+        <sparkName>spark-4.0</sparkName>
+        <shimsModule>spark-extension-shims-spark4</shimsModule>
+        <scalatestVersion>3.2.10</scalatestVersion>
+        <scalaLoggingVersion>3.9.9</scalaLoggingVersion>
+        <sparkVersion>4.0.1</sparkVersion>
+      </properties>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>org.apache.maven.plugins</groupId>
+            <artifactId>maven-enforcer-plugin</artifactId>
+            <version>3.4.1</version>
+            <executions>
+              <execution>
+                <id>enforce-java-scala-version</id>
+                <goals>
+                  <goal>enforce</goal>
+                </goals>
+                <configuration>
+                  <rules>
+                    <requireJavaVersion>
+                      <version>[17,)</version>
+                      <message>Spark 4.0 requires JDK 17 or higher to compile!</message>
+                    </requireJavaVersion>
+                    <requireProperty>
+                      <property>scalaVersion</property>
+                      <regex>^2.13.*</regex>
+                      <regexMessage>Spark 4.0 requires Scala 2.13 and is not compatible with Scala 2.12!</regexMessage>
+                    </requireProperty>
+                  </rules>
+                </configuration>
+              </execution>
+            </executions>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
     <profile>
       <id>jdk-8</id>
diff --git a/spark-extension-shims-spark4/pom.xml b/spark-extension-shims-spark4/pom.xml
new file mode 100644
index 000000000..633cf7cf6
--- /dev/null
+++ b/spark-extension-shims-spark4/pom.xml
@@ -0,0 +1,112 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  Licensed to the Apache Software Foundation (ASF) under one or more
+  contributor license agreements.  See the NOTICE file distributed with
+  this work for additional information regarding copyright ownership.
+  The ASF licenses this file to You under the Apache License, Version 2.0
+  (the "License"); you may not use this file except in compliance with
+  the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+
+  <parent>
+    <groupId>org.apache.auron</groupId>
+    <artifactId>auron-parent_${scalaVersion}</artifactId>
+    <version>${project.version}</version>
+    <relativePath>../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-extension-shims-spark4_${scalaVersion}</artifactId>
+  <packaging>jar</packaging>
+
+  <properties>
+    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.auron</groupId>
+      <artifactId>auron-common_${scalaVersion}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.auron</groupId>
+      <artifactId>spark-extension_${scalaVersion}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang.modules</groupId>
+      <artifactId>scala-java8-compat_${scalaVersion}</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scalaVersion}</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-hive_${scalaVersion}</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scalaVersion}</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-c-data</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-compression</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-memory-unsafe</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-vector</artifactId>
+    </dependency>
+
+    <dependency>
+      <groupId>net.bytebuddy</groupId>
+      <artifactId>byte-buddy</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>net.bytebuddy</groupId>
+      <artifactId>byte-buddy-agent</artifactId>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.auron</groupId>
+      <artifactId>auron-common_${scalaVersion}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scalaVersion}</artifactId>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scalaVersion}</artifactId>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scalaVersion}</artifactId>
+      <type>test-jar</type>
+    </dependency>
+  </dependencies>
+</project>
diff --git a/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInjector.java b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInjector.java
new file mode 100644
index 000000000..00352cc41
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInjector.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron;
+
+import static net.bytebuddy.matcher.ElementMatchers.named;
+
+import net.bytebuddy.ByteBuddy;
+import net.bytebuddy.agent.ByteBuddyAgent;
+import net.bytebuddy.description.type.TypeDescription;
+import net.bytebuddy.dynamic.ClassFileLocator;
+import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
+import net.bytebuddy.implementation.MethodDelegation;
+import net.bytebuddy.pool.TypePool;
+
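+/**
+ * Uses Byte Buddy to redefine Spark's JoinSelectionHelper so that calls to
+ * forceApplyShuffledHashJoin() are delegated to {@link ForceApplyShuffledHashJoinInterceptor},
+ * allowing Auron to force shuffled hash joins independently of Spark's own join selection.
+ */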
+public class ForceApplyShuffledHashJoinInjector {
+ public static void inject() {
+ ByteBuddyAgent.install();
+ ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
+ TypeDescription typeDescription = TypePool.Default.of(contextClassLoader)
+ .describe("org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper")
+ .resolve();
+ new ByteBuddy()
+ .redefine(typeDescription, ClassFileLocator.ForClassLoader.of(contextClassLoader))
+ .method(named("forceApplyShuffledHashJoin"))
+ .intercept(MethodDelegation.to(ForceApplyShuffledHashJoinInterceptor.class))
+ .make()
+ .load(contextClassLoader, ClassLoadingStrategy.Default.INJECTION);
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInterceptor.java b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInterceptor.java
new file mode 100644
index 000000000..8d7d0dd52
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInterceptor.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron;
+
+import net.bytebuddy.implementation.bind.annotation.Argument;
+import net.bytebuddy.implementation.bind.annotation.RuntimeType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
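+/**
+ * Byte Buddy interceptor backing {@link ForceApplyShuffledHashJoinInjector}: the intercepted
+ * SQLConf argument is ignored and the value of Auron's FORCE_SHUFFLED_HASH_JOIN configuration
+ * is returned instead.
+ */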
+public class ForceApplyShuffledHashJoinInterceptor {
+ private static final Logger logger = LoggerFactory.getLogger(ForceApplyShuffledHashJoinInterceptor.class);
+
+ @RuntimeType
+ public static Object intercept(@Argument(0) Object conf) {
+ logger.debug("calling JoinSelectionHelper.forceApplyShuffledHashJoin() intercepted by auron");
+ return AuronConf.FORCE_SHUFFLED_HASH_JOIN.booleanConf();
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanApplyInterceptor.java b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanApplyInterceptor.java
new file mode 100644
index 000000000..ea6a6603f
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanApplyInterceptor.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron;
+
+import net.bytebuddy.implementation.bind.annotation.Argument;
+import net.bytebuddy.implementation.bind.annotation.RuntimeType;
+import org.apache.spark.sql.execution.SparkPlan;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
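+/**
+ * Byte Buddy interceptor for ValidateSparkPlan.apply(): validation is delegated to
+ * InterceptedValidateSparkPlan, which also accepts Auron's native broadcast plan nodes,
+ * and the plan is returned unchanged.
+ */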
+public class ValidateSparkPlanApplyInterceptor {
+ private static final Logger logger = LoggerFactory.getLogger(ValidateSparkPlanApplyInterceptor.class);
+
+ @RuntimeType
+ public static Object intercept(@Argument(0) Object plan) {
+ logger.debug("calling ValidateSparkPlan.apply() intercepted by auron");
+ InterceptedValidateSparkPlan$.MODULE$.validate((SparkPlan) plan);
+ return plan;
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanInjector.java b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanInjector.java
new file mode 100644
index 000000000..6685b1d7d
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanInjector.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron;
+
+import static net.bytebuddy.matcher.ElementMatchers.named;
+
+import net.bytebuddy.ByteBuddy;
+import net.bytebuddy.agent.ByteBuddyAgent;
+import net.bytebuddy.description.type.TypeDescription;
+import net.bytebuddy.dynamic.ClassFileLocator;
+import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
+import net.bytebuddy.implementation.MethodDelegation;
+import net.bytebuddy.pool.TypePool;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
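+/**
+ * Uses Byte Buddy to redefine Spark's AQE rule ValidateSparkPlan so that apply() is delegated
+ * to {@link ValidateSparkPlanApplyInterceptor}. Injection runs at most once, and is skipped
+ * (with a debug log) when the ValidateSparkPlan class is not present.
+ */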
+public class ValidateSparkPlanInjector {
+
+ private static final Logger logger = LoggerFactory.getLogger(ValidateSparkPlanInjector.class);
+ private static boolean injected = false;
+
+ public static synchronized void inject() {
+ if (injected) {
+ logger.warn("ValidateSparkPlan already injected, skipping.");
+ return;
+ }
+ try {
+ ByteBuddyAgent.install();
+ ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
+ TypeDescription typeDescription = TypePool.Default.of(contextClassLoader)
+ .describe("org.apache.spark.sql.execution.adaptive.ValidateSparkPlan$")
+ .resolve();
+ new ByteBuddy()
+ .redefine(typeDescription, ClassFileLocator.ForClassLoader.of(contextClassLoader))
+ .method(named("apply"))
+ .intercept(MethodDelegation.to(ValidateSparkPlanApplyInterceptor.class))
+ .make()
+ .load(contextClassLoader, ClassLoadingStrategy.Default.INJECTION);
+ logger.info("Successfully injected ValidateSparkPlan.");
+ injected = true;
+ } catch (TypePool.Resolution.NoSuchTypeException e) {
+ logger.debug("No such type of ValidateSparkPlan", e);
+ }
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/InterceptedValidateSparkPlan.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/InterceptedValidateSparkPlan.scala
new file mode 100644
index 000000000..61480bd3d
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/InterceptedValidateSparkPlan.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
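+/**
+ * Replacement body for Spark's ValidateSparkPlan rule. It mirrors the original check for
+ * invalid broadcast query stages, but additionally treats Auron's NativeBroadcastJoinExec
+ * (whose build side may be wrapped in a NativeRenameColumnsBase) like a regular broadcast
+ * hash join, so AQE does not reject plans converted to native broadcast joins.
+ */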
+object InterceptedValidateSparkPlan extends Logging {
+
+ @sparkver("4.0")
+ def validate(plan: SparkPlan): Unit = {
+ import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
+ import org.apache.spark.sql.execution.auron.plan.NativeRenameColumnsBase
+ import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
+ import org.apache.spark.sql.execution.joins.auron.plan.NativeBroadcastJoinExec
+ import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
+ import org.apache.spark.sql.catalyst.optimizer.BuildLeft
+ import org.apache.spark.sql.catalyst.optimizer.BuildRight
+
+ plan match {
+ case b: BroadcastHashJoinExec =>
+ val (buildPlan, probePlan) = b.buildSide match {
+ case BuildLeft => (b.left, b.right)
+ case BuildRight => (b.right, b.left)
+ }
+ if (!buildPlan.isInstanceOf[BroadcastQueryStageExec]) {
+ validate(buildPlan)
+ }
+ validate(probePlan)
+
+ case b: NativeBroadcastJoinExec => // same as non-native BHJ
+ var (buildPlan, probePlan) = b.buildSide match {
+ case BuildLeft => (b.left, b.right)
+ case BuildRight => (b.right, b.left)
+ }
+ if (buildPlan.isInstanceOf[NativeRenameColumnsBase]) {
+ buildPlan = buildPlan.children.head
+ }
+ if (!buildPlan.isInstanceOf[BroadcastQueryStageExec]) {
+ validate(buildPlan)
+ }
+ validate(probePlan)
+
+ case b: BroadcastNestedLoopJoinExec =>
+ val (buildPlan, probePlan) = b.buildSide match {
+ case BuildLeft => (b.left, b.right)
+ case BuildRight => (b.right, b.left)
+ }
+ if (!buildPlan.isInstanceOf[BroadcastQueryStageExec]) {
+ validate(buildPlan)
+ }
+ validate(probePlan)
+ case q: BroadcastQueryStageExec => errorOnInvalidBroadcastQueryStage(q)
+ case _ => plan.children.foreach(validate)
+ }
+ }
+
+ @sparkver("4.0")
+ private def errorOnInvalidBroadcastQueryStage(plan: SparkPlan): Unit = {
+ import org.apache.spark.sql.execution.adaptive.InvalidAQEPlanException
+ throw InvalidAQEPlanException("Invalid broadcast query stage", plan)
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
new file mode 100644
index 000000000..a10d0cdae
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
@@ -0,0 +1,731 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron
+
+import java.io.File
+import java.util.UUID
+
+import org.apache.commons.lang3.reflect.FieldUtils
+import org.apache.spark.OneToOneDependency
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.SparkEnv
+import org.apache.spark.SparkException
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
+import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.auron.AuronConverters.ForceNativeExecutionWrapperBase
+import org.apache.spark.sql.auron.NativeConverters.NativeExprWrapperBase
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.Generator
+import org.apache.spark.sql.catalyst.expressions.Like
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.expressions.StringSplit
+import org.apache.spark.sql.catalyst.expressions.TaggingExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
+import org.apache.spark.sql.catalyst.expressions.aggregate.First
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.CoalescedPartitionSpec
+import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.PartialMapperPartitionSpec
+import org.apache.spark.sql.execution.PartialReducerPartitionSpec
+import org.apache.spark.sql.execution.ShuffledRowRDD
+import org.apache.spark.sql.execution.ShufflePartitionSpec
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.UnaryExecNode
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.auron.plan._
+import org.apache.spark.sql.execution.auron.plan.ConvertToNativeExec
+import org.apache.spark.sql.execution.auron.plan.NativeAggBase
+import org.apache.spark.sql.execution.auron.plan.NativeAggBase.AggExecMode
+import org.apache.spark.sql.execution.auron.plan.NativeAggExec
+import org.apache.spark.sql.execution.auron.plan.NativeBroadcastExchangeBase
+import org.apache.spark.sql.execution.auron.plan.NativeBroadcastExchangeExec
+import org.apache.spark.sql.execution.auron.plan.NativeExpandBase
+import org.apache.spark.sql.execution.auron.plan.NativeExpandExec
+import org.apache.spark.sql.execution.auron.plan.NativeFilterBase
+import org.apache.spark.sql.execution.auron.plan.NativeFilterExec
+import org.apache.spark.sql.execution.auron.plan.NativeGenerateBase
+import org.apache.spark.sql.execution.auron.plan.NativeGenerateExec
+import org.apache.spark.sql.execution.auron.plan.NativeGlobalLimitBase
+import org.apache.spark.sql.execution.auron.plan.NativeGlobalLimitExec
+import org.apache.spark.sql.execution.auron.plan.NativeLocalLimitBase
+import org.apache.spark.sql.execution.auron.plan.NativeLocalLimitExec
+import org.apache.spark.sql.execution.auron.plan.NativeOrcScanExec
+import org.apache.spark.sql.execution.auron.plan.NativeParquetInsertIntoHiveTableBase
+import org.apache.spark.sql.execution.auron.plan.NativeParquetInsertIntoHiveTableExec
+import org.apache.spark.sql.execution.auron.plan.NativeParquetScanBase
+import org.apache.spark.sql.execution.auron.plan.NativeParquetScanExec
+import org.apache.spark.sql.execution.auron.plan.NativeProjectBase
+import org.apache.spark.sql.execution.auron.plan.NativeRenameColumnsBase
+import org.apache.spark.sql.execution.auron.plan.NativeShuffleExchangeBase
+import org.apache.spark.sql.execution.auron.plan.NativeShuffleExchangeExec
+import org.apache.spark.sql.execution.auron.plan.NativeSortBase
+import org.apache.spark.sql.execution.auron.plan.NativeSortExec
+import org.apache.spark.sql.execution.auron.plan.NativeTakeOrderedBase
+import org.apache.spark.sql.execution.auron.plan.NativeTakeOrderedExec
+import org.apache.spark.sql.execution.auron.plan.NativeUnionBase
+import org.apache.spark.sql.execution.auron.plan.NativeUnionExec
+import org.apache.spark.sql.execution.auron.plan.NativeWindowBase
+import org.apache.spark.sql.execution.auron.plan.NativeWindowExec
+import org.apache.spark.sql.execution.auron.shuffle.{AuronBlockStoreShuffleReaderBase, AuronRssShuffleManagerBase, RssPartitionWriterBase}
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
+import org.apache.spark.sql.execution.joins.auron.plan.NativeBroadcastJoinExec
+import org.apache.spark.sql.execution.joins.auron.plan.NativeShuffledHashJoinExecProvider
+import org.apache.spark.sql.execution.joins.auron.plan.NativeSortMergeJoinExecProvider
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter}
+import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.FileSegment
+
+import org.apache.auron.{protobuf => pb, sparkver}
+
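+/**
+ * Spark 4.0 implementation of the Auron [[Shims]] layer: it creates the Spark-4.0 flavors of
+ * the native plan nodes, adapts shuffle reading/writing and expression conversion, and
+ * installs the Byte Buddy injectors in initExtension().
+ */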
+class ShimsImpl extends Shims with Logging {
+
+ @sparkver("4.0")
+ override def shimVersion: String = "spark-4.0"
+
+ @sparkver("4.0")
+ override def initExtension(): Unit = {
+ ValidateSparkPlanInjector.inject()
+
+ if (AuronConf.FORCE_SHUFFLED_HASH_JOIN.booleanConf()) {
+ ForceApplyShuffledHashJoinInjector.inject()
+ }
+
+ // disable the MultiCommutativeOp canonicalization memory optimization introduced in Spark 3.4+
+ if (shimVersion >= "spark-3.4") {
+ val confName = "spark.sql.analyzer.canonicalization.multiCommutativeOpMemoryOptThreshold"
+ SparkEnv.get.conf.set(confName, Int.MaxValue.toString)
+ }
+ }
+
+ override def createConvertToNativeExec(child: SparkPlan): ConvertToNativeBase =
+ ConvertToNativeExec(child)
+
+ override def createNativeAggExec(
+ execMode: AggExecMode,
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ child: SparkPlan): NativeAggBase =
+ NativeAggExec(
+ execMode,
+ requiredChildDistributionExpressions,
+ groupingExpressions,
+ aggregateExpressions,
+ aggregateAttributes,
+ initialInputBufferOffset,
+ child)
+
+ override def createNativeBroadcastExchangeExec(
+ mode: BroadcastMode,
+ child: SparkPlan): NativeBroadcastExchangeBase =
+ NativeBroadcastExchangeExec(mode, child)
+
+ override def createNativeBroadcastJoinExec(
+ left: SparkPlan,
+ right: SparkPlan,
+ outputPartitioning: Partitioning,
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ broadcastSide: BroadcastSide): NativeBroadcastJoinBase =
+ NativeBroadcastJoinExec(
+ left,
+ right,
+ outputPartitioning,
+ leftKeys,
+ rightKeys,
+ joinType,
+ broadcastSide)
+
+ override def createNativeSortMergeJoinExec(
+ left: SparkPlan,
+ right: SparkPlan,
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ isSkewJoin: Boolean): NativeSortMergeJoinBase =
+ NativeSortMergeJoinExecProvider.provide(
+ left,
+ right,
+ leftKeys,
+ rightKeys,
+ joinType,
+ isSkewJoin)
+
+ override def createNativeShuffledHashJoinExec(
+ left: SparkPlan,
+ right: SparkPlan,
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ buildSide: BuildSide,
+ isSkewJoin: Boolean): SparkPlan =
+ NativeShuffledHashJoinExecProvider.provide(
+ left,
+ right,
+ leftKeys,
+ rightKeys,
+ joinType,
+ buildSide,
+ isSkewJoin)
+
+ override def createNativeExpandExec(
+ projections: Seq[Seq[Expression]],
+ output: Seq[Attribute],
+ child: SparkPlan): NativeExpandBase =
+ NativeExpandExec(projections, output, child)
+
+ override def createNativeFilterExec(condition: Expression, child: SparkPlan): NativeFilterBase =
+ NativeFilterExec(condition, child)
+
+ override def createNativeGenerateExec(
+ generator: Generator,
+ requiredChildOutput: Seq[Attribute],
+ outer: Boolean,
+ generatorOutput: Seq[Attribute],
+ child: SparkPlan): NativeGenerateBase =
+ NativeGenerateExec(generator, requiredChildOutput, outer, generatorOutput, child)
+
+ override def createNativeGlobalLimitExec(limit: Long, child: SparkPlan): NativeGlobalLimitBase =
+ NativeGlobalLimitExec(limit, child)
+
+ override def createNativeLocalLimitExec(limit: Long, child: SparkPlan): NativeLocalLimitBase =
+ NativeLocalLimitExec(limit, child)
+
+ override def createNativeParquetInsertIntoHiveTableExec(
+ cmd: InsertIntoHiveTable,
+ child: SparkPlan): NativeParquetInsertIntoHiveTableBase =
+ NativeParquetInsertIntoHiveTableExec(cmd, child)
+
+ override def createNativeParquetScanExec(
+ basedFileScan: FileSourceScanExec): NativeParquetScanBase =
+ NativeParquetScanExec(basedFileScan)
+
+ override def createNativeOrcScanExec(basedFileScan: FileSourceScanExec): NativeOrcScanBase =
+ NativeOrcScanExec(basedFileScan)
+
+ override def createNativeProjectExec(
+ projectList: Seq[NamedExpression],
+ child: SparkPlan): NativeProjectBase =
+ NativeProjectExecProvider.provide(projectList, child)
+
+ override def createNativeRenameColumnsExec(
+ child: SparkPlan,
+ newColumnNames: Seq[String]): NativeRenameColumnsBase =
+ NativeRenameColumnsExecProvider.provide(child, newColumnNames)
+
+ override def createNativeShuffleExchangeExec(
+ outputPartitioning: Partitioning,
+ child: SparkPlan,
+ shuffleOrigin: Option[Any] = None): NativeShuffleExchangeBase =
+ NativeShuffleExchangeExec(outputPartitioning, child, shuffleOrigin)
+
+ override def createNativeSortExec(
+ sortOrder: Seq[SortOrder],
+ global: Boolean,
+ child: SparkPlan): NativeSortBase =
+ NativeSortExec(sortOrder, global, child)
+
+ override def createNativeTakeOrderedExec(
+ limit: Long,
+ sortOrder: Seq[SortOrder],
+ child: SparkPlan): NativeTakeOrderedBase =
+ NativeTakeOrderedExec(limit, sortOrder, child)
+
+ override def createNativePartialTakeOrderedExec(
+ limit: Long,
+ sortOrder: Seq[SortOrder],
+ child: SparkPlan,
+ metrics: Map[String, SQLMetric]): NativePartialTakeOrderedBase =
+ NativePartialTakeOrderedExec(limit, sortOrder, child, metrics)
+
+ override def createNativeUnionExec(
+ children: Seq[SparkPlan],
+ output: Seq[Attribute]): NativeUnionBase =
+ NativeUnionExec(children, output)
+
+ override def createNativeWindowExec(
+ windowExpression: Seq[NamedExpression],
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ groupLimit: Option[Int],
+ outputWindowCols: Boolean,
+ child: SparkPlan): NativeWindowBase =
+ NativeWindowExec(windowExpression, partitionSpec, orderSpec, groupLimit, child)
+
+ override def createNativeParquetSinkExec(
+ sparkSession: SparkSession,
+ table: CatalogTable,
+ partition: Map[String, Option[String]],
+ child: SparkPlan,
+ metrics: Map[String, SQLMetric]): NativeParquetSinkBase =
+ NativeParquetSinkExec(sparkSession, table, partition, child, metrics)
+
+ override def getUnderlyingBroadcast(plan: SparkPlan): BroadcastExchangeLike = {
+ plan match {
+ case exec: BroadcastExchangeLike => exec
+ case exec: UnaryExecNode => getUnderlyingBroadcast(exec.child)
+ case exec: BroadcastQueryStageExec => getUnderlyingBroadcast(exec.broadcast)
+ case exec: ReusedExchangeExec => getUnderlyingBroadcast(exec.child)
+ }
+ }
+
+ override def isNative(plan: SparkPlan): Boolean =
+ plan match {
+ case _: NativeSupports => true
+ case plan if isAQEShuffleRead(plan) => isNative(plan.children.head)
+ case plan: QueryStageExec => isNative(plan.plan)
+ case plan: ReusedExchangeExec => isNative(plan.child)
+ case _ => false
+ }
+
+ override def getUnderlyingNativePlan(plan: SparkPlan): NativeSupports = {
+ plan match {
+ case plan: NativeSupports => plan
+ case plan if isAQEShuffleRead(plan) => getUnderlyingNativePlan(plan.children.head)
+ case plan: QueryStageExec => getUnderlyingNativePlan(plan.plan)
+ case plan: ReusedExchangeExec => getUnderlyingNativePlan(plan.child)
+ case _ => throw new RuntimeException("unreachable: plan is not native")
+ }
+ }
+
+ override def executeNative(plan: SparkPlan): NativeRDD = {
+ plan match {
+ case plan: NativeSupports => plan.executeNative()
+ case plan if isAQEShuffleRead(plan) => executeNativeAQEShuffleReader(plan)
+ case plan: QueryStageExec => executeNative(plan.plan)
+ case plan: ReusedExchangeExec => executeNative(plan.child)
+ case _ =>
+ throw new SparkException(s"Underlying plan is not NativeSupports: ${plan}")
+ }
+ }
+
+ override def isQueryStageInput(plan: SparkPlan): Boolean = {
+ plan.isInstanceOf[QueryStageExec]
+ }
+
+ override def isShuffleQueryStageInput(plan: SparkPlan): Boolean = {
+ plan.isInstanceOf[ShuffleQueryStageExec]
+ }
+
+ override def getChildStage(plan: SparkPlan): SparkPlan =
+ plan.asInstanceOf[QueryStageExec].plan
+
+ override def simpleStringWithNodeId(plan: SparkPlan): String = plan.simpleStringWithNodeId()
+
+ override def setLogicalLink(exec: SparkPlan, basedExec: SparkPlan): SparkPlan = {
+ basedExec.logicalLink.foreach(logicalLink => exec.setLogicalLink(logicalLink))
+ exec
+ }
+
+ override def getRDDShuffleReadFull(rdd: RDD[_]): Boolean = true
+
+ override def setRDDShuffleReadFull(rdd: RDD[_], shuffleReadFull: Boolean): Unit = {}
+
+ override def createFileSegment(
+ file: File,
+ offset: Long,
+ length: Long,
+ numRecords: Long): FileSegment = new FileSegment(file, offset, length)
+
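+ // Commits a natively written shuffle data file: the partition lengths are registered with the
+ // IndexShuffleBlockResolver (no checksums are produced by the native writer) and a MapStatus
+ // pointing at the local shuffle server is returned.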
+ @sparkver("4.0")
+ override def commit(
+ dep: ShuffleDependency[_, _, _],
+ shuffleBlockResolver: IndexShuffleBlockResolver,
+ tempDataFile: File,
+ mapId: Long,
+ partitionLengths: Array[Long],
+ dataSize: Long,
+ context: TaskContext): MapStatus = {
+
+ val checksums = Array[Long]()
+ shuffleBlockResolver.writeMetadataFileAndCommit(
+ dep.shuffleId,
+ mapId,
+ partitionLengths,
+ checksums,
+ tempDataFile)
+ MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, partitionLengths, mapId)
+ }
+
+ override def getRssPartitionWriter(
+ handle: ShuffleHandle,
+ mapId: Int,
+ metrics: ShuffleWriteMetricsReporter,
+ numPartitions: Int): Option[RssPartitionWriterBase] = None
+
+ override def getMapStatus(
+ shuffleServerId: BlockManagerId,
+ partitionLengthMap: Array[Long],
+ mapId: Long): MapStatus =
+ MapStatus.apply(shuffleServerId, partitionLengthMap, mapId)
+
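+ // Selects the native shuffle writer node: an RssShuffleWriterExecNode when an Auron RSS
+ // shuffle manager is installed, otherwise the regular ShuffleWriterExecNode.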
+ override def getShuffleWriteExec(
+ input: pb.PhysicalPlanNode,
+ nativeOutputPartitioning: pb.PhysicalRepartition.Builder): pb.PhysicalPlanNode = {
+
+ if (SparkEnv.get.shuffleManager.isInstanceOf[AuronRssShuffleManagerBase]) {
+ return pb.PhysicalPlanNode
+ .newBuilder()
+ .setRssShuffleWriter(
+ pb.RssShuffleWriterExecNode
+ .newBuilder()
+ .setInput(input)
+ .setOutputPartitioning(nativeOutputPartitioning)
+ .buildPartial()
+ ) // shuffleId is not set here; it will be set in ShuffleWriteProcessor
+ .build()
+ }
+
+ pb.PhysicalPlanNode
+ .newBuilder()
+ .setShuffleWriter(
+ pb.ShuffleWriterExecNode
+ .newBuilder()
+ .setInput(input)
+ .setOutputPartitioning(nativeOutputPartitioning)
+ .buildPartial()
+ ) // shuffleId is not set here; it will be set in ShuffleWriteProcessor
+ .build()
+ }
+
+ override def convertMoreExprWithFallback(
+ e: Expression,
+ isPruningExpr: Boolean,
+ fallback: Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = {
+ e match {
+ case StringSplit(str, pat @ Literal(_, StringType), Literal(-1, IntegerType))
+ // the native StringSplit implementation does not support regex, so only the most
+ // frequently used separators without regex semantics are supported
+ if Seq(",", ", ", ":", ";", "#", "@", "_", "-", "\\|", "\\.").contains(pat.value) =>
+ val nativePat = pat.value match {
+ case "\\|" => "|"
+ case "\\." => "."
+ case other => other
+ }
+ Some(
+ pb.PhysicalExprNode
+ .newBuilder()
+ .setScalarFunction(
+ pb.PhysicalScalarFunctionNode
+ .newBuilder()
+ .setFun(pb.ScalarFunction.SparkExtFunctions)
+ .setName("StringSplit")
+ .addArgs(NativeConverters.convertExprWithFallback(str, isPruningExpr, fallback))
+ .addArgs(NativeConverters
+ .convertExprWithFallback(Literal(nativePat), isPruningExpr, fallback))
+ .setReturnType(NativeConverters.convertDataType(StringType)))
+ .build())
+
+ case e: TaggingExpression =>
+ Some(NativeConverters.convertExprWithFallback(e.child, isPruningExpr, fallback))
+ case e =>
+ convertPromotePrecision(e, isPruningExpr, fallback) match {
+ case Some(v) => return Some(v)
+ case None =>
+ }
+ convertBloomFilterMightContain(e, isPruningExpr, fallback) match {
+ case Some(v) => return Some(v)
+ case None =>
+ }
+ None
+ }
+ }
+
+ override def getLikeEscapeChar(expr: Expression): Char = {
+ expr.asInstanceOf[Like].escapeChar
+ }
+
+ override def convertMoreAggregateExpr(e: AggregateExpression): Option[pb.PhysicalExprNode] = {
+ assert(getAggregateExpressionFilter(e).isEmpty)
+
+ e.aggregateFunction match {
+ case First(child, ignoresNull) =>
+ val aggExpr = pb.PhysicalAggExprNode
+ .newBuilder()
+ .setReturnType(NativeConverters.convertDataType(e.dataType))
+ .setAggFunction(if (ignoresNull) {
+ pb.AggFunction.FIRST_IGNORES_NULL
+ } else {
+ pb.AggFunction.FIRST
+ })
+ .addChildren(NativeConverters.convertExpr(child))
+ Some(pb.PhysicalExprNode.newBuilder().setAggExpr(aggExpr).build())
+
+ case agg =>
+ convertBloomFilterAgg(agg) match {
+ case Some(aggExpr) =>
+ return Some(
+ pb.PhysicalExprNode
+ .newBuilder()
+ .setAggExpr(aggExpr)
+ .build())
+ case None =>
+ }
+ None
+ }
+ }
+
+ override def getAggregateExpressionFilter(expr: Expression): Option[Expression] = {
+ expr.asInstanceOf[AggregateExpression].filter
+ }
+
+ @sparkver("4.0")
+ private def isAQEShuffleRead(exec: SparkPlan): Boolean = {
+ import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
+ exec.isInstanceOf[AQEShuffleReadExec]
+ }
+
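+ // Executes an AQE shuffle read over a native shuffle exchange: a Spark shuffle reader is
+ // created for each partition spec (coalesced, coalesced-mapper, partial-reducer,
+ // partial-mapper), registered with JniBridge, and exposed to the native plan through an
+ // IpcReaderExecNode.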
+ @sparkver("4.0")
+ private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = {
+ import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
+ import org.apache.spark.sql.execution.CoalescedMapperPartitionSpec
+
+ exec match {
+ case AQEShuffleReadExec(child, _) if isNative(child) =>
+ val shuffledRDD = exec.execute().asInstanceOf[ShuffledRowRDD]
+ val shuffleHandle = shuffledRDD.dependency.shuffleHandle
+
+ val inputRDD = executeNative(child)
+ val nativeShuffle = getUnderlyingNativePlan(child).asInstanceOf[NativeShuffleExchangeExec]
+ val nativeSchema: pb.Schema = nativeShuffle.nativeSchema
+
+ val requiredMetrics = nativeShuffle.readMetrics ++
+ nativeShuffle.metrics.filterKeys(_ == "shuffle_read_total_time")
+ val metrics = MetricNode(
+ requiredMetrics,
+ inputRDD.metrics :: Nil,
+ Some({
+ case ("output_rows", v) =>
+ val tempMetrics = TaskContext.get.taskMetrics().createTempShuffleReadMetrics()
+ new SQLShuffleReadMetricsReporter(tempMetrics, requiredMetrics).incRecordsRead(v)
+ TaskContext.get().taskMetrics().mergeShuffleReadMetrics()
+ case ("elapsed_compute", v) => requiredMetrics("shuffle_read_total_time") += v
+ case _ =>
+ }))
+
+ new NativeRDD(
+ shuffledRDD.sparkContext,
+ metrics,
+ shuffledRDD.partitions,
+ shuffledRDD.partitioner,
+ new OneToOneDependency(shuffledRDD) :: Nil,
+ true,
+ (partition, taskContext) => {
+
+ // use reflection to get partitionSpec because ShuffledRowRDDPartition is private
+ val sqlMetricsReporter = taskContext.taskMetrics().createTempShuffleReadMetrics()
+ val spec = FieldUtils
+ .readDeclaredField(partition, "spec", true)
+ .asInstanceOf[ShufflePartitionSpec]
+ val reader = spec match {
+ case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) =>
+ SparkEnv.get.shuffleManager.getReader(
+ shuffleHandle,
+ startReducerIndex,
+ endReducerIndex,
+ taskContext,
+ sqlMetricsReporter)
+
+ case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) =>
+ SparkEnv.get.shuffleManager.getReader(
+ shuffleHandle,
+ startMapIndex,
+ endMapIndex,
+ 0,
+ numReducers,
+ taskContext,
+ sqlMetricsReporter)
+
+ case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) =>
+ SparkEnv.get.shuffleManager.getReader(
+ shuffleHandle,
+ startMapIndex,
+ endMapIndex,
+ reducerIndex,
+ reducerIndex + 1,
+ taskContext,
+ sqlMetricsReporter)
+
+ case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) =>
+ SparkEnv.get.shuffleManager.getReader(
+ shuffleHandle,
+ mapIndex,
+ mapIndex + 1,
+ startReducerIndex,
+ endReducerIndex,
+ taskContext,
+ sqlMetricsReporter)
+ }
+
+ // store fetch iterator in jni resource before native compute
+ val jniResourceId = s"NativeShuffleReadExec:${UUID.randomUUID().toString}"
+ JniBridge.resourcesMap.put(
+ jniResourceId,
+ () => {
+ reader.asInstanceOf[AuronBlockStoreShuffleReaderBase[_, _]].readIpc()
+ })
+
+ pb.PhysicalPlanNode
+ .newBuilder()
+ .setIpcReader(
+ pb.IpcReaderExecNode
+ .newBuilder()
+ .setSchema(nativeSchema)
+ .setNumPartitions(shuffledRDD.getNumPartitions)
+ .setIpcProviderResourceId(jniResourceId)
+ .build())
+ .build()
+ })
+ }
+ }
+
+ override def convertMoreSparkPlan(exec: SparkPlan): Option[SparkPlan] = {
+ exec match {
+ case exec if isAQEShuffleRead(exec) && isNative(exec) =>
+ Some(ForceNativeExecutionWrapper(AuronConverters.addRenameColumnsExec(exec)))
+ case _: ReusedExchangeExec if isNative(exec) =>
+ Some(ForceNativeExecutionWrapper(AuronConverters.addRenameColumnsExec(exec)))
+ case _ => None
+ }
+ }
+
+ @sparkver("4.0")
+ override def getSqlContext(sparkPlan: SparkPlan): SQLContext = ???
+
+ override def createNativeExprWrapper(
+ nativeExpr: pb.PhysicalExprNode,
+ dataType: DataType,
+ nullable: Boolean): Expression = {
+ NativeExprWrapper(nativeExpr, dataType, nullable)
+ }
+
+ @sparkver("4.0")
+ override def getPartitionedFile(
+ partitionValues: InternalRow,
+ filePath: String,
+ offset: Long,
+ size: Long): PartitionedFile = {
+ import org.apache.hadoop.fs.Path
+ import org.apache.spark.paths.SparkPath
+ PartitionedFile(partitionValues, SparkPath.fromPath(new Path(filePath)), offset, size)
+ }
+
+ @sparkver("4.0")
+ override def getMinPartitionNum(sparkSession: SparkSession): Int =
+ sparkSession.sessionState.conf.filesMinPartitionNum
+ .getOrElse(sparkSession.sparkContext.defaultParallelism)
+
+ @sparkver("4.0")
+ private def convertPromotePrecision(
+ e: Expression,
+ isPruningExpr: Boolean,
+ fallback: Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = None
+
+ @sparkver("4.0")
+ private def convertBloomFilterAgg(agg: AggregateFunction): Option[pb.PhysicalAggExprNode] = {
+ import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
+ agg match {
+ case BloomFilterAggregate(child, estimatedNumItemsExpression, numBitsExpression, _, _) =>
+ // ensure numBits is a power of 2
+ val estimatedNumItems =
+ estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue()
+ val numBits = numBitsExpression.eval().asInstanceOf[Number].longValue()
+ val numBitsNextPowerOf2 = numBits match {
+ case 1 => 1L
+ case n => Integer.highestOneBit(n.toInt - 1) << 1
+ }
+ Some(
+ pb.PhysicalAggExprNode
+ .newBuilder()
+ .setReturnType(NativeConverters.convertDataType(agg.dataType))
+ .setAggFunction(pb.AggFunction.BLOOM_FILTER)
+ .addChildren(NativeConverters.convertExpr(child))
+ .addChildren(NativeConverters.convertExpr(Literal(estimatedNumItems)))
+ .addChildren(NativeConverters.convertExpr(Literal(numBitsNextPowerOf2)))
+ .build())
+ case _ => None
+ }
+ }
+
+ @sparkver("4.0")
+ private def convertBloomFilterMightContain(
+ e: Expression,
+ isPruningExpr: Boolean,
+ fallback: Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = {
+ import org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain
+ e match {
+ case e: BloomFilterMightContain =>
+ val uuid = UUID.randomUUID().toString
+ Some(NativeConverters.buildExprNode {
+ _.setBloomFilterMightContainExpr(
+ pb.BloomFilterMightContainExprNode
+ .newBuilder()
+ .setUuid(uuid)
+ .setBloomFilterExpr(NativeConverters
+ .convertExprWithFallback(e.bloomFilterExpression, isPruningExpr, fallback))
+ .setValueExpr(NativeConverters
+ .convertExprWithFallback(e.valueExpression, isPruningExpr, fallback)))
+ })
+ case _ => None
+ }
+ }
+
+ @sparkver("4.0")
+ override def getAdaptiveInputPlan(exec: AdaptiveSparkPlanExec): SparkPlan = {
+ exec.inputPlan
+ }
+}
+
+case class ForceNativeExecutionWrapper(override val child: SparkPlan)
+ extends ForceNativeExecutionWrapperBase(child) {
+
+ @sparkver("4.0")
+ override def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
+
+case class NativeExprWrapper(
+ nativeExpr: pb.PhysicalExprNode,
+ override val dataType: DataType,
+ override val nullable: Boolean)
+ extends NativeExprWrapperBase(nativeExpr, dataType, nullable) {
+
+ @sparkver("4.0")
+ override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy()
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeExec.scala
new file mode 100644
index 000000000..6ffe0a836
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeExec.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class ConvertToNativeExec(override val child: SparkPlan) extends ConvertToNativeBase(child) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala
new file mode 100644
index 000000000..bde4bbd9a
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.ExprId
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.Final
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.execution.auron.plan.NativeAggBase.AggExecMode
+import org.apache.spark.sql.types.BinaryType
+
+import org.apache.auron.sparkver
+
+case class NativeAggExec(
+ execMode: AggExecMode,
+ theRequiredChildDistributionExpressions: Option[Seq[Expression]],
+ override val groupingExpressions: Seq[NamedExpression],
+ override val aggregateExpressions: Seq[AggregateExpression],
+ override val aggregateAttributes: Seq[Attribute],
+ theInitialInputBufferOffset: Int,
+ override val child: SparkPlan)
+ extends NativeAggBase(
+ execMode,
+ theRequiredChildDistributionExpressions,
+ groupingExpressions,
+ aggregateExpressions,
+ aggregateAttributes,
+ theInitialInputBufferOffset,
+ child)
+ with BaseAggregateExec {
+
+ @sparkver("4.0")
+ override val requiredChildDistributionExpressions: Option[Seq[Expression]] =
+ theRequiredChildDistributionExpressions
+
+ @sparkver("4.0")
+ override val initialInputBufferOffset: Int = theInitialInputBufferOffset
+
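+ // When no aggregate expression is in Final mode, the native aggregation buffer is exposed as
+ // a single binary column (AGG_BUF) so it can be shuffled to and merged by the final stage.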
+ override def output: Seq[Attribute] =
+ if (aggregateExpressions.map(_.mode).contains(Final)) {
+ groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
+ } else {
+ groupingExpressions.map(_.toAttribute) :+
+ AttributeReference(NativeAggBase.AGG_BUF_COLUMN_NAME, BinaryType, nullable = false)(
+ ExprId.apply(NativeAggBase.AGG_BUF_COLUMN_EXPR_ID))
+ }
+
+ @sparkver("4.0")
+ override def isStreaming: Boolean = false
+
+ @sparkver("4.0")
+ override def numShufflePartitions: Option[Int] = None
+
+ override def resultExpressions: Seq[NamedExpression] = output
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeExec.scala
new file mode 100644
index 000000000..d9fccfe43
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeExec.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import java.util.UUID
+
+import org.apache.spark.broadcast
+import org.apache.spark.sql.auron.NativeSupports
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class NativeBroadcastExchangeExec(mode: BroadcastMode, override val child: SparkPlan)
+ extends NativeBroadcastExchangeBase(mode, child)
+ with NativeSupports {
+
+ override val getRunId: UUID = UUID.randomUUID()
+
+ override def runtimeStatistics: Statistics = {
+ val dataSize = metrics("dataSize").value
+ val rowCount = metrics("numOutputRows").value
+ Statistics(dataSize, Some(rowCount))
+ }
+
+ @transient
+ override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = {
+ relationFuturePromise.future
+ }
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeExpandExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeExpandExec.scala
new file mode 100644
index 000000000..002249154
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeExpandExec.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class NativeExpandExec(
+ projections: Seq[Seq[Expression]],
+ override val output: Seq[Attribute],
+ override val child: SparkPlan)
+ extends NativeExpandBase(projections, output, child) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeFilterExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeFilterExec.scala
new file mode 100644
index 000000000..0f72420fb
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeFilterExec.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class NativeFilterExec(condition: Expression, override val child: SparkPlan)
+ extends NativeFilterBase(condition, child) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGenerateExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGenerateExec.scala
new file mode 100644
index 000000000..8c9c02c5e
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGenerateExec.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.Generator
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class NativeGenerateExec(
+ generator: Generator,
+ requiredChildOutput: Seq[Attribute],
+ outer: Boolean,
+ generatorOutput: Seq[Attribute],
+ override val child: SparkPlan)
+ extends NativeGenerateBase(generator, requiredChildOutput, outer, generatorOutput, child) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGlobalLimitExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGlobalLimitExec.scala
new file mode 100644
index 000000000..1dea592e2
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGlobalLimitExec.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class NativeGlobalLimitExec(limit: Long, override val child: SparkPlan)
+ extends NativeGlobalLimitBase(limit, child) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeLocalLimitExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeLocalLimitExec.scala
new file mode 100644
index 000000000..db6f5a0f6
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeLocalLimitExec.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class NativeLocalLimitExec(limit: Long, override val child: SparkPlan)
+ extends NativeLocalLimitBase(limit, child) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcScanExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcScanExec.scala
new file mode 100644
index 000000000..06a63f6e9
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcScanExec.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.execution.FileSourceScanExec
+
+case class NativeOrcScanExec(basedFileScan: FileSourceScanExec)
+ extends NativeOrcScanBase(basedFileScan) {
+
+ override def simpleString(maxFields: Int): String =
+ s"$nodeName (${basedFileScan.simpleString(maxFields)})"
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetInsertIntoHiveTableExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetInsertIntoHiveTableExec.scala
new file mode 100644
index 000000000..3c5b7be4c
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetInsertIntoHiveTableExec.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.auron.Shims
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
+
+import org.apache.auron.sparkver
+
+case class NativeParquetInsertIntoHiveTableExec(
+ cmd: InsertIntoHiveTable,
+ override val child: SparkPlan)
+ extends NativeParquetInsertIntoHiveTableBase(cmd, child) {
+
+ @sparkver("4.0")
+ override protected def getInsertIntoHiveTableCommand(
+ table: CatalogTable,
+ partition: Map[String, Option[String]],
+ query: LogicalPlan,
+ overwrite: Boolean,
+ ifPartitionNotExists: Boolean,
+ outputColumnNames: Seq[String],
+ metrics: Map[String, SQLMetric]): InsertIntoHiveTable = {
+ new AuronInsertIntoHiveTable40(
+ table,
+ partition,
+ query,
+ overwrite,
+ ifPartitionNotExists,
+ outputColumnNames,
+ metrics)
+ }
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+
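+ // NOTE: a plain InsertIntoHiveTable is built first in the early-definition block below so that
+ // the derived constructor arguments Spark 4.0 expects (partitionColumns, bucketSpec, options,
+ // fileFormat, hiveTmpPath) are available before the superclass constructor runs.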
+ @sparkver("4.0")
+ class AuronInsertIntoHiveTable40(
+ table: CatalogTable,
+ partition: Map[String, Option[String]],
+ query: LogicalPlan,
+ overwrite: Boolean,
+ ifPartitionNotExists: Boolean,
+ outputColumnNames: Seq[String],
+ outerMetrics: Map[String, SQLMetric])
+ extends {
+ private val insertIntoHiveTable = InsertIntoHiveTable(
+ table,
+ partition,
+ query,
+ overwrite,
+ ifPartitionNotExists,
+ outputColumnNames)
+ private val initPartitionColumns = insertIntoHiveTable.partitionColumns
+ private val initBucketSpec = insertIntoHiveTable.bucketSpec
+ private val initOptions = insertIntoHiveTable.options
+ private val initFileFormat = insertIntoHiveTable.fileFormat
+ private val initHiveTmpPath = insertIntoHiveTable.hiveTmpPath
+
+ }
+ with InsertIntoHiveTable(
+ table,
+ partition,
+ query,
+ overwrite,
+ ifPartitionNotExists,
+ outputColumnNames,
+ initPartitionColumns,
+ initBucketSpec,
+ initOptions,
+ initFileFormat,
+ initHiveTmpPath) {
+
+ override lazy val metrics: Map[String, SQLMetric] = outerMetrics
+
+ override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
+ val nativeParquetSink =
+ Shims.get.createNativeParquetSinkExec(sparkSession, table, partition, child, metrics)
+ super.run(sparkSession, nativeParquetSink)
+ }
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetScanExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetScanExec.scala
new file mode 100644
index 000000000..a116a5430
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetScanExec.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.execution.FileSourceScanExec
+
+case class NativeParquetScanExec(basedFileScan: FileSourceScanExec)
+ extends NativeParquetScanBase(basedFileScan) {
+
+ override def simpleString(maxFields: Int): String =
+ s"$nodeName (${basedFileScan.simpleString(maxFields)})"
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetSinkExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetSinkExec.scala
new file mode 100644
index 000000000..cc44a9dbd
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetSinkExec.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.SQLMetric
+
+import org.apache.auron.sparkver
+
+case class NativeParquetSinkExec(
+ sparkSession: SparkSession,
+ table: CatalogTable,
+ partition: Map[String, Option[String]],
+ override val child: SparkPlan,
+ override val metrics: Map[String, SQLMetric])
+ extends NativeParquetSinkBase(sparkSession, table, partition, child, metrics) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativePartialTakeOrderedExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativePartialTakeOrderedExec.scala
new file mode 100644
index 000000000..60b151206
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativePartialTakeOrderedExec.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.SQLMetric
+
+import org.apache.auron.sparkver
+
+case class NativePartialTakeOrderedExec(
+ limit: Long,
+ sortOrder: Seq[SortOrder],
+ override val child: SparkPlan,
+ override val metrics: Map[String, SQLMetric])
+ extends NativePartialTakeOrderedBase(limit, sortOrder, child, metrics) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeProjectExecProvider.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeProjectExecProvider.scala
new file mode 100644
index 000000000..3274fd9eb
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeProjectExecProvider.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case object NativeProjectExecProvider {
+ @sparkver("4.0")
+ def provide(projectList: Seq[NamedExpression], child: SparkPlan): NativeProjectBase = {
+ import org.apache.spark.sql.execution.OrderPreservingUnaryExecNode
+ import org.apache.spark.sql.execution.PartitioningPreservingUnaryExecNode
+
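+ // The concrete exec node is declared locally so that the Spark-4-only mixins imported above are
+ // only referenced inside this @sparkver("4.0") method.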
+ case class NativeProjectExec(projectList: Seq[NamedExpression], override val child: SparkPlan)
+ extends NativeProjectBase(projectList, child)
+ with PartitioningPreservingUnaryExecNode
+ with OrderPreservingUnaryExecNode {
+
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+
+ override protected def outputExpressions = projectList
+
+ override protected def orderingExpressions = child.outputOrdering
+
+ override def nodeName: String = "NativeProject"
+ }
+ NativeProjectExec(projectList, child)
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeRenameColumnsExecProvider.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeRenameColumnsExecProvider.scala
new file mode 100644
index 000000000..eabb98727
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeRenameColumnsExecProvider.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case object NativeRenameColumnsExecProvider {
+ @sparkver("4.0")
+ def provide(child: SparkPlan, renamedColumnNames: Seq[String]): NativeRenameColumnsBase = {
+ import org.apache.spark.sql.catalyst.expressions.NamedExpression
+ import org.apache.spark.sql.catalyst.expressions.SortOrder
+ import org.apache.spark.sql.execution.OrderPreservingUnaryExecNode
+ import org.apache.spark.sql.execution.PartitioningPreservingUnaryExecNode
+
+ case class NativeRenameColumnsExec(
+ override val child: SparkPlan,
+ renamedColumnNames: Seq[String])
+ extends NativeRenameColumnsBase(child, renamedColumnNames)
+ with PartitioningPreservingUnaryExecNode
+ with OrderPreservingUnaryExecNode {
+
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+
+ override protected def outputExpressions: Seq[NamedExpression] = output
+
+ override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering
+ }
+ NativeRenameColumnsExec(child, renamedColumnNames)
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala
new file mode 100644
index 000000000..3a7988aa5
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import scala.collection.mutable
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.Future
+
+import org.apache.spark._
+import org.apache.spark.rdd.MapPartitionsRDD
+import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
+import org.apache.spark.shuffle.ShuffleWriteProcessor
+import org.apache.spark.sql.auron.NativeHelper
+import org.apache.spark.sql.auron.NativeRDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.auron.shuffle.AuronRssShuffleWriterBase
+import org.apache.spark.sql.execution.auron.shuffle.AuronShuffleWriterBase
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
+import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter
+
+import org.apache.auron.sparkver
+
+case class NativeShuffleExchangeExec(
+ override val outputPartitioning: Partitioning,
+ override val child: SparkPlan,
+ _shuffleOrigin: Option[Any] = None)
+ extends NativeShuffleExchangeBase(outputPartitioning, child) {
+
+ // NOTE: coordinator can be null after serialization/deserialization,
+ // e.g. it can be null on the Executor side
+ lazy val writeMetrics: Map[String, SQLMetric] = (mutable.LinkedHashMap[String, SQLMetric]() ++
+ SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) ++
+ mutable.LinkedHashMap(
+ NativeHelper
+ .getDefaultNativeMetrics(sparkContext)
+ .filterKeys(Set(
+ "stage_id",
+ "mem_spill_count",
+ "mem_spill_size",
+ "mem_spill_iotime",
+ "disk_spill_size",
+ "disk_spill_iotime",
+ "shuffle_write_total_time",
+ "shuffle_read_total_time"))
+ .toSeq: _*)).toMap
+
+ lazy val readMetrics: Map[String, SQLMetric] =
+ SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+
+ override lazy val metrics: Map[String, SQLMetric] =
+ (mutable.LinkedHashMap[String, SQLMetric]() ++
+ readMetrics ++
+ writeMetrics ++
+ Map(
+ "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+ "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions"))).toMap
+
+ // 'mapOutputStatisticsFuture' is only needed when AQE is enabled.
+ @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
+ if (inputRDD.getNumPartitions == 0) {
+ Future.successful(null)
+ } else {
+ sparkContext
+ .submitMapStage(shuffleDependency)
+ .map(stat => new MapOutputStatistics(stat.shuffleId, stat.bytesByPartitionId))
+ }
+ }
+
+ override def numMappers: Int = shuffleDependency.rdd.getNumPartitions
+
+ override def numPartitions: Int = shuffleDependency.partitioner.numPartitions
+
+ override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = {
+ new ShuffledRowRDD(shuffleDependency, readMetrics, partitionSpecs)
+ }
+
+ override def runtimeStatistics: Statistics = {
+ val dataSize = metrics("dataSize").value
+ val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
+ Statistics(dataSize, Some(rowCount))
+ }
+
+ /**
+ * Caches the created ShuffledRowRDD so it can be reused.
+ */
+ private var cachedShuffleRDD: ShuffledRowRDD = _
+
+ override protected def doExecuteNonNative(): RDD[InternalRow] = {
+ // Returns the same ShuffledRowRDD if this plan is used by multiple plans.
+ if (cachedShuffleRDD == null) {
+ cachedShuffleRDD = new ShuffledRowRDD(shuffleDependency, readMetrics)
+ }
+ cachedShuffleRDD
+ }
+
+ override def createNativeShuffleWriteProcessor(
+ metrics: Map[String, SQLMetric],
+ numPartitions: Int): ShuffleWriteProcessor = {
+
+ new ShuffleWriteProcessor {
+ override protected def createMetricsReporter(
+ context: TaskContext): ShuffleWriteMetricsReporter = {
+ new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics)
+ }
+
+ override def write(
+ inputs: Iterator[_],
+ dep: ShuffleDependency[_, _, _],
+ mapId: Long,
+ mapIndex: Int,
+ context: TaskContext): MapStatus = {
+
+ // [SPARK-44605][CORE] refined the internal ShuffleWriteProcessor API: write() no longer
+ // receives the Partition directly, so it is resolved here from the RDD by map index.
+ val rdd = dep.rdd
+ val partition = rdd.partitions(mapIndex)
+
+ val writer = SparkEnv.get.shuffleManager.getWriter(
+ dep.shuffleHandle,
+ mapId,
+ context,
+ createMetricsReporter(context))
+
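+ // Both writer flavours consume the upstream NativeRDD directly rather than the row iterator
+ // passed in as `inputs`.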
+ writer match {
+ case writer: AuronRssShuffleWriterBase[_, _] =>
+ writer.nativeRssShuffleWrite(
+ rdd.asInstanceOf[MapPartitionsRDD[_, _]].prev.asInstanceOf[NativeRDD],
+ dep,
+ mapId.toInt,
+ context,
+ partition,
+ numPartitions)
+
+ case writer: AuronShuffleWriterBase[_, _] =>
+ writer.nativeShuffleWrite(
+ rdd.asInstanceOf[MapPartitionsRDD[_, _]].prev.asInstanceOf[NativeRDD],
+ dep,
+ mapId.toInt,
+ context,
+ partition)
+ }
+ writer.stop(true).get
+ }
+ }
+ }
+
+ @sparkver("4.0")
+ override def advisoryPartitionSize: Option[Long] = None
+
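+ // _shuffleOrigin is carried as Option[Any] (presumably to keep the version-agnostic base class
+ // free of the Spark ShuffleOrigin type) and cast back to ShuffleOrigin here.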
+ @sparkver("4.0")
+ override def shuffleOrigin = {
+ import org.apache.spark.sql.execution.exchange.ShuffleOrigin
+ _shuffleOrigin.get.asInstanceOf[ShuffleOrigin]
+ }
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+
+ override def shuffleId: Int = shuffleDependency.shuffleId
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala
new file mode 100644
index 000000000..b532aac17
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class NativeSortExec(
+ sortOrder: Seq[SortOrder],
+ global: Boolean,
+ override val child: SparkPlan)
+ extends NativeSortBase(sortOrder, global, child) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala
new file mode 100644
index 000000000..395c1527c
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class NativeTakeOrderedExec(
+ limit: Long,
+ sortOrder: Seq[SortOrder],
+ override val child: SparkPlan)
+ extends NativeTakeOrderedBase(limit, sortOrder, child) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala
new file mode 100644
index 000000000..30e092cd5
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class NativeUnionExec(
+ override val children: Seq[SparkPlan],
+ override val output: Seq[Attribute])
+ extends NativeUnionBase(children, output) {
+
+ @sparkver("4.0")
+ override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
+ copy(children = newChildren)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala
new file mode 100644
index 000000000..c26a00f84
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.auron.sparkver
+
+case class NativeWindowExec(
+ windowExpression: Seq[NamedExpression],
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ groupLimit: Option[Int],
+ override val child: SparkPlan)
+ extends NativeWindowBase(windowExpression, partitionSpec, orderSpec, groupLimit, child) {
+
+ @sparkver("4.0")
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronBlockStoreShuffleReader.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronBlockStoreShuffleReader.scala
new file mode 100644
index 000000000..270d17b39
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronBlockStoreShuffleReader.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.shuffle
+
+import java.io.InputStream
+
+import org.apache.spark.{MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
+
+import org.apache.auron.sparkver
+
+class AuronBlockStoreShuffleReader[K, C](
+ handle: BaseShuffleHandle[K, _, C],
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ context: TaskContext,
+ readMetrics: ShuffleReadMetricsReporter,
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
+ shouldBatchFetch: Boolean = false)
+ extends AuronBlockStoreShuffleReaderBase[K, C](handle, context)
+ with Logging {
+
+ override def readBlocks(): Iterator[InputStream] = {
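+ // Builds a ShuffleBlockFetcherIterator with the standard reducer-side configs and keeps only
+ // the raw InputStreams, dropping the BlockId of each (BlockId, InputStream) pair.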
+ @sparkver("4.0")
+ def fetchIterator = new ShuffleBlockFetcherIterator(
+ context,
+ blockManager.blockStoreClient,
+ blockManager,
+ mapOutputTracker,
+ blocksByAddress,
+ (_, inputStream) => inputStream,
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
+ SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+ SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+ SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
+ false, // checksums not supported
+ "ChecksumAlgorithmsNotSupported",
+ readMetrics,
+ fetchContinuousBlocksInBatch).toCompletionIterator.map(_._2)
+
+ fetchIterator
+ }
+
+ private def fetchContinuousBlocksInBatch: Boolean = {
+ val conf = SparkEnv.get.conf
+ val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
+ val compressed = conf.get(config.SHUFFLE_COMPRESS)
+ val codecConcatenation = if (compressed) {
+ CompressionCodec.supportsConcatenationOfSerializedStreams(
+ CompressionCodec.createCodec(conf))
+ } else {
+ true
+ }
+ val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)
+
+ val doBatchFetch = shouldBatchFetch && serializerRelocatable &&
+ (!compressed || codecConcatenation) && !useOldFetchProtocol
+ if (shouldBatchFetch && !doBatchFetch) {
+ logDebug(
+ "The feature tag of continuous shuffle block fetching is set to true, but " +
+ "we can not enable the feature because other conditions are not satisfied. " +
+ s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " +
+ s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " +
+ s"$useOldFetchProtocol.")
+ }
+ doBatchFetch
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleManagerBase.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleManagerBase.scala
new file mode 100644
index 000000000..ab9aabc3a
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleManagerBase.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.shuffle
+
+import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle._
+import org.apache.spark.sql.execution.auron.shuffle.AuronShuffleDependency.isArrowShuffle
+
+import org.apache.auron.sparkver
+
+abstract class AuronRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManager with Logging {
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle
+
+ override def unregisterShuffle(shuffleId: Int): Boolean
+
+ def getAuronRssShuffleReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): AuronRssShuffleReaderBase[K, C]
+
+ def getAuronRssShuffleReader[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): AuronRssShuffleReaderBase[K, C]
+
+ def getRssShuffleReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]
+
+ def getRssShuffleReader[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]
+
+ def getAuronRssShuffleWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Long,
+ context: TaskContext,
+ metrics: ShuffleWriteMetricsReporter): AuronRssShuffleWriterBase[K, V]
+
+ def getRssShuffleWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Long,
+ context: TaskContext,
+ metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]
+
+ @sparkver("4.0")
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+
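+ // Arrow (Auron) shuffles go through the native RSS reader; everything else falls back to the
+ // plain RSS reader supplied by the concrete shuffle manager.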
+ if (isArrowShuffle(handle)) {
+ getAuronRssShuffleReader(
+ handle,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition,
+ context,
+ metrics)
+ } else {
+ getRssShuffleReader(
+ handle,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition,
+ context,
+ metrics)
+ }
+ }
+
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Long,
+ context: TaskContext,
+ metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+
+ if (isArrowShuffle(handle)) {
+ getAuronRssShuffleWriter(handle, mapId, context, metrics)
+ } else {
+ getRssShuffleWriter(handle, mapId, context, metrics)
+ }
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleManager.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleManager.scala
new file mode 100644
index 000000000..852136213
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleManager.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.shuffle
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
+import org.apache.spark.sql.execution.auron.shuffle.AuronShuffleDependency.isArrowShuffle
+
+import org.apache.auron.sparkver
+
+class AuronShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+ val sortShuffleManager = new SortShuffleManager(conf)
+
+ // disable other off-heap memory usages
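+ // (presumably so that Netty/Spark direct allocations do not compete with the native engine's
+ // own off-heap memory)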
+ System.setProperty("spark.memory.offHeap.enabled", "false")
+ System.setProperty("io.netty.maxDirectMemory", "0")
+ System.setProperty("io.netty.noPreferDirect", "true")
+ System.setProperty("io.netty.noUnsafe", "true")
+
+ if (!conf.getBoolean("spark.shuffle.spill", defaultValue = true)) {
+ logWarning(
+ "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." +
+ " Shuffle will continue to spill to disk when necessary.")
+ }
+
+ override val shuffleBlockResolver: ShuffleBlockResolver =
+ sortShuffleManager.shuffleBlockResolver
+
+ /**
+ * Obtains a [[ShuffleHandle]] to pass to tasks.
+ */
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ sortShuffleManager.registerShuffle(shuffleId, dependency)
+ }
+
+ @sparkver("4.0")
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+
+ if (isArrowShuffle(handle)) {
+ val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]]
+
+ @sparkver("4.0")
+ def shuffleMergeFinalized = baseShuffleHandle.dependency.isShuffleMergeFinalizedMarked
+
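+ // Once push-based shuffle has finalized the merge, block locations come from
+ // getPushBasedShuffleMapSizesByExecutorId, which also decides whether batch fetch is still
+ // possible; otherwise the regular per-executor map sizes are used and batch fetch stays allowed.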
+ val (blocksByAddress, canEnableBatchFetch) =
+ if (shuffleMergeFinalized) {
+ val res = SparkEnv.get.mapOutputTracker.getPushBasedShuffleMapSizesByExecutorId(
+ handle.shuffleId,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition)
+ (res.iter, res.enableBatchFetch)
+ } else {
+ val address = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
+ handle.shuffleId,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition)
+ (address, true)
+ }
+
+ new AuronBlockStoreShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+ blocksByAddress.map(tup => (tup._1, tup._2.toSeq)),
+ context,
+ metrics,
+ SparkEnv.get.blockManager,
+ SparkEnv.get.mapOutputTracker,
+ shouldBatchFetch =
+ canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context))
+ } else {
+ sortShuffleManager.getReader(
+ handle,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition,
+ context,
+ metrics)
+ }
+ }
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Long,
+ context: TaskContext,
+ metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+
+ if (isArrowShuffle(handle)) {
+ new AuronShuffleWriter(metrics)
+ } else {
+ sortShuffleManager.getWriter(handle, mapId, context, metrics)
+ }
+ }
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ override def unregisterShuffle(shuffleId: Int): Boolean = {
+ sortShuffleManager.unregisterShuffle(shuffleId)
+ }
+
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {
+ shuffleBlockResolver.stop()
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleWriter.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleWriter.scala
new file mode 100644
index 000000000..2394ba132
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleWriter.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.shuffle
+
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
+
+import org.apache.auron.sparkver
+
+class AuronShuffleWriter[K, V](metrics: ShuffleWriteMetricsReporter)
+ extends AuronShuffleWriterBase[K, V](metrics) {
+
+ @sparkver("4.0")
+ override def getPartitionLengths(): Array[Long] = partitionLengths
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
new file mode 100644
index 000000000..85d2564b6
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.joins.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.auron.plan.BroadcastLeft
+import org.apache.spark.sql.execution.auron.plan.BroadcastRight
+import org.apache.spark.sql.execution.auron.plan.BroadcastSide
+import org.apache.spark.sql.execution.auron.plan.NativeBroadcastJoinBase
+import org.apache.spark.sql.execution.joins.HashJoin
+
+import org.apache.auron.sparkver
+
+case class NativeBroadcastJoinExec(
+ override val left: SparkPlan,
+ override val right: SparkPlan,
+ override val outputPartitioning: Partitioning,
+ override val leftKeys: Seq[Expression],
+ override val rightKeys: Seq[Expression],
+ override val joinType: JoinType,
+ broadcastSide: BroadcastSide)
+ extends NativeBroadcastJoinBase(
+ left,
+ right,
+ outputPartitioning,
+ leftKeys,
+ rightKeys,
+ joinType,
+ broadcastSide)
+ with HashJoin {
+
+ override val condition: Option[Expression] = None
+
+ @sparkver("4.0")
+ override def buildSide: org.apache.spark.sql.catalyst.optimizer.BuildSide =
+ broadcastSide match {
+ case BroadcastLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
+ case BroadcastRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
+ }
+
+ @sparkver("4.0")
+ override def requiredChildDistribution = {
+ import org.apache.spark.sql.catalyst.plans.physical.BroadcastDistribution
+ import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
+ import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+
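+ // The broadcast side must arrive as a HashedRelation broadcast; the streamed side carries no
+ // distribution requirement.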
+ def mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAware = false)
+ broadcastSide match {
+ case BroadcastLeft =>
+ BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
+ case BroadcastRight =>
+ UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
+ }
+ }
+
+ override def rewriteKeyExprToLong(exprs: Seq[Expression]): Seq[Expression] =
+ HashJoin.rewriteKeyExpr(exprs)
+
+ @sparkver("4.0")
+ override def supportCodegen: Boolean = false
+
+ @sparkver("4.0")
+ override def inputRDDs() = {
+ throw new NotImplementedError("NativeBroadcastJoin dose not support codegen")
+ }
+
+ @sparkver("4.0")
+ override protected def prepareRelation(
+ ctx: org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext)
+ : org.apache.spark.sql.execution.joins.HashedRelationInfo = {
+ throw new NotImplementedError("NativeBroadcastJoin dose not support codegen")
+ }
+
+ @sparkver("4.0")
+ override protected def withNewChildrenInternal(
+ newLeft: SparkPlan,
+ newRight: SparkPlan): SparkPlan =
+ copy(left = newLeft, right = newRight)
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala
new file mode 100644
index 000000000..d24f2d6d9
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.joins.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.auron.plan.BuildSide
+import org.apache.spark.sql.execution.auron.plan.NativeShuffledHashJoinBase
+import org.apache.spark.sql.execution.joins.HashJoin
+
+import org.apache.auron.sparkver
+
+case object NativeShuffledHashJoinExecProvider {
+
+ @sparkver("4.0")
+ def provide(
+ left: SparkPlan,
+ right: SparkPlan,
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ buildSide: BuildSide,
+ isSkewJoin: Boolean): NativeShuffledHashJoinBase = {
+
+ import org.apache.spark.rdd.RDD
+ import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+
+ case class NativeShuffledHashJoinExec(
+ override val left: SparkPlan,
+ override val right: SparkPlan,
+ override val leftKeys: Seq[Expression],
+ override val rightKeys: Seq[Expression],
+ override val joinType: JoinType,
+ buildSide: BuildSide,
+ skewJoin: Boolean)
+ extends NativeShuffledHashJoinBase(left, right, leftKeys, rightKeys, joinType, buildSide)
+ with org.apache.spark.sql.execution.joins.ShuffledJoin {
+
+ override def condition: Option[Expression] = None
+
+ override def isSkewJoin: Boolean = false
+
+ override def supportCodegen: Boolean = false
+
+ override def rewriteKeyExprToLong(exprs: Seq[Expression]): Seq[Expression] =
+ HashJoin.rewriteKeyExpr(exprs)
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ throw new NotImplementedError("NativeShuffledHash dose not support codegen")
+ }
+
+ override protected def doProduce(ctx: CodegenContext): String = {
+ throw new NotImplementedError("NativeShuffledHash dose not support codegen")
+ }
+
+ override protected def withNewChildrenInternal(
+ newLeft: SparkPlan,
+ newRight: SparkPlan): SparkPlan =
+ copy(left = newLeft, right = newRight)
+
+ override def nodeName: String =
+ "NativeShuffledHashJoin" + (if (skewJoin) "(skew=true)" else "")
+ }
+ NativeShuffledHashJoinExec(left, right, leftKeys, rightKeys, joinType, buildSide, isSkewJoin)
+ }
+}
diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeSortMergeJoinExecProvider.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeSortMergeJoinExecProvider.scala
new file mode 100644
index 000000000..e59ca8f6e
--- /dev/null
+++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeSortMergeJoinExecProvider.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.joins.auron.plan
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.auron.plan.NativeSortMergeJoinBase
+
+import org.apache.auron.sparkver
+
+case object NativeSortMergeJoinExecProvider {
+
+ @sparkver("4.0")
+ def provide(
+ left: SparkPlan,
+ right: SparkPlan,
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ skewJoin: Boolean): NativeSortMergeJoinBase = {
+
+ import org.apache.spark.rdd.RDD
+ import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+
+ case class NativeSortMergeJoinExec(
+ override val left: SparkPlan,
+ override val right: SparkPlan,
+ override val leftKeys: Seq[Expression],
+ override val rightKeys: Seq[Expression],
+ override val joinType: JoinType,
+ skewJoin: Boolean)
+ extends NativeSortMergeJoinBase(left, right, leftKeys, rightKeys, joinType)
+ with org.apache.spark.sql.execution.joins.ShuffledJoin {
+
+ override def condition: Option[Expression] = None
+
+ override def isSkewJoin: Boolean = false
+
+ override def supportCodegen: Boolean = false
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ throw new NotImplementedError("NativeSortMergeJoin dose not support codegen")
+ }
+
+ override protected def doProduce(ctx: CodegenContext): String = {
+ throw new NotImplementedError("NativeSortMergeJoin dose not support codegen")
+ }
+
+ override protected def withNewChildrenInternal(
+ newLeft: SparkPlan,
+ newRight: SparkPlan): SparkPlan =
+ copy(left = newLeft, right = newRight)
+
+ override def nodeName: String =
+ "NativeSortMergeJoin" + (if (skewJoin) "(skew=true)" else "")
+ }
+ NativeSortMergeJoinExec(left, right, leftKeys, rightKeys, joinType, skewJoin)
+ }
+}
diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala
new file mode 100644
index 000000000..96e39b3e1
--- /dev/null
+++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron
+
+import org.apache.auron.sparkverEnableMembers
+
+@sparkverEnableMembers("4.0")
+class AuronAdaptiveQueryExecSuite
+ extends org.apache.spark.sql.QueryTest
+ with BaseAuronSQLSuite
+ with org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper {
+
+ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+ import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, SparkPlan}
+ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec}
+ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+ import org.apache.spark.sql.execution.exchange.Exchange
+ import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates}
+ import org.apache.spark.sql.internal.SQLConf
+ import org.apache.spark.sql.test.SQLTestData.TestData
+ import testImplicits._
+
+ test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") {
+ withTempView("v") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") {
+
+ spark.sparkContext
+ .parallelize((1 to 10).map(i => TestData(if (i > 4) 5 else i, i.toString)), 3)
+ .toDF("c1", "c2")
+ .createOrReplaceTempView("v")
+
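+ // Runs `query` under AQE and checks that the single AQEShuffleReadExec contains
+ // `skewedPartitionNumber` PartialReducerPartitionSpec entries among `totalNumber`
+ // partition specs in total.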
+ def checkPartitionNumber(
+ query: String,
+ skewedPartitionNumber: Int,
+ totalNumber: Int): Unit = {
+ val (_, adaptive) = runAdaptiveAndVerifyResult(query)
+ val read = collect(adaptive) { case read: AQEShuffleReadExec =>
+ read
+ }
+ assert(read.size == 1)
+ assert(
+ read.head.partitionSpecs.count(_.isInstanceOf[PartialReducerPartitionSpec]) ==
+ skewedPartitionNumber)
+ assert(read.head.partitionSpecs.size == totalNumber)
+ }
+
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") {
+ // partition size [0, 75, 45, 68, 34]
+ checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 2, 4)
+ checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 0, 3)
+ }
+
+ // no skewed partition should be optimized
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10000") {
+ checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 0, 1)
+ }
+ }
+ }
+ }
+
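+ // Executes the query with AQE enabled, re-runs it with AQE disabled and verifies both
+ // results match, checks the adaptive listener events, asserts the final plan contains
+ // no leftover Exchange nodes, and returns (initial plan, final adaptive plan).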
+ private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
+ var finalPlanCnt = 0
+ var hasMetricsEvent = false
+ val listener = new SparkListener {
+ override def onOtherEvent(event: SparkListenerEvent): Unit = {
+ event match {
+ case SparkListenerSQLAdaptiveExecutionUpdate(_, _, sparkPlanInfo) =>
+ if (sparkPlanInfo.simpleString.startsWith("AdaptiveSparkPlan isFinalPlan=true")) {
+ finalPlanCnt += 1
+ }
+ case _: SparkListenerSQLAdaptiveSQLMetricUpdates =>
+ hasMetricsEvent = true
+ case _ => // ignore other events
+ }
+ }
+ }
+ spark.sparkContext.addSparkListener(listener)
+
+ val dfAdaptive = sql(query)
+ val planBefore = dfAdaptive.queryExecution.executedPlan
+ assert(planBefore.toString.startsWith("AdaptiveSparkPlan isFinalPlan=false"))
+ val result = dfAdaptive.collect()
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+ val df = sql(query)
+ checkAnswer(df, result)
+ }
+ val planAfter = dfAdaptive.queryExecution.executedPlan
+ assert(planAfter.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))
+ val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+ // AQE will post `SparkListenerSQLAdaptiveExecutionUpdate` twice in case of subqueries that
+ // exist out of query stages.
+ val expectedFinalPlanCnt = adaptivePlan.find(_.subqueries.nonEmpty).map(_ => 2).getOrElse(1)
+ assert(finalPlanCnt == expectedFinalPlanCnt)
+ spark.sparkContext.removeSparkListener(listener)
+
+ val expectedMetrics = findInMemoryTable(planAfter).nonEmpty ||
+ subqueriesAll(planAfter).nonEmpty
+ assert(hasMetricsEvent == expectedMetrics)
+
+ val exchanges = adaptivePlan.collect { case e: Exchange =>
+ e
+ }
+ assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.")
+ (dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
+ }
+
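+ // Collects cached in-memory scans whose cached plan is itself adaptive; together with
+ // subqueries, this determines whether SQL metric update events are expected above.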
+ private def findInMemoryTable(plan: SparkPlan): Seq[InMemoryTableScanExec] = {
+ collect(plan) {
+ case c: InMemoryTableScanExec
+ if c.relation.cachedPlan.isInstanceOf[AdaptiveSparkPlanExec] =>
+ c
+ }
+ }
+}
diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
new file mode 100644
index 000000000..2f7e57074
--- /dev/null
+++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+
+import org.apache.auron.util.AuronTestUtils
+
+class AuronFunctionSuite
+ extends org.apache.spark.sql.QueryTest
+ with BaseAuronSQLSuite
+ with AdaptiveSparkPlanHelper {
+
+ test("sum function with float input") {
+ if (AuronTestUtils.isSparkV31OrGreater) {
+ withTable("t1") {
+ sql("create table t1 using parquet as select 1.0f as c1")
+ val df = sql("select sum(c1) from t1")
+ checkAnswer(df, Seq(Row(1.0)))
+ }
+ }
+ }
+
+ test("sha2 function") {
+ withTable("t1") {
+ sql("create table t1 using parquet as select 'spark' as c1, '3.x' as version")
+ val functions =
+ """
+ |select
+ | sha2(concat(c1, version), 0) as sha0,
+ | sha2(concat(c1, version), 256) as sha256,
+ | sha2(concat(c1, version), 224) as sha224,
+ | sha2(concat(c1, version), 384) as sha384,
+ | sha2(concat(c1, version), 512) as sha512
+ |from t1
+ |""".stripMargin
+ val df = sql(functions)
+ checkAnswer(
+ df,
+ Seq(
+ Row(
+ "562d20689257f3f3a04ee9afb86d0ece2af106cf6c6e5e7d266043088ce5fbc0",
+ "562d20689257f3f3a04ee9afb86d0ece2af106cf6c6e5e7d266043088ce5fbc0",
+ "d0c8e9ccd5c7b3fdbacd2cfd6b4d65ca8489983b5e8c7c64cd77b634",
+ "77c1199808053619c29e9af2656e1ad2614772f6ea605d5757894d6aec2dfaf34ff6fd662def3b79e429e9ae5ecbfed1",
+ "c4e27d35517ca62243c1f322d7922dac175830be4668e8a1cf3befdcd287bb5b6f8c5f041c9d89e4609c8cfa242008c7c7133af1685f57bac9052c1212f1d089")))
+ }
+ }
+
+ test("spark hash function") {
+ withTable("t1") {
+ sql("create table t1 using parquet as select array(1, 2) as arr")
+ val functions =
+ """
+ |select hash(arr) from t1
+ |""".stripMargin
+ val df = sql(functions)
+ checkAnswer(df, Seq(Row(-222940379)))
+ }
+ }
+
+ test("expm1 function") {
+ withTable("t1") {
+ sql("create table t1(c1 double) using parquet")
+ sql("insert into t1 values(0.0), (1.1), (2.2)")
+ val df = sql("select expm1(c1) from t1")
+ checkAnswer(df, Seq(Row(0.0), Row(2.0041660239464334), Row(8.025013499434122)))
+ }
+ }
+
+ test("factorial function") {
+ withTable("t1") {
+ sql("create table t1(c1 int) using parquet")
+ sql("insert into t1 values(5)")
+ val df = sql("select factorial(c1) from t1")
+ checkAnswer(df, Seq(Row(120)))
+ }
+ }
+
+ test("hex function") {
+ withTable("t1") {
+ sql("create table t1(c1 int, c2 string) using parquet")
+ sql("insert into t1 values(17, 'Spark SQL')")
+ val df = sql("select hex(c1), hex(c2) from t1")
+ checkAnswer(df, Seq(Row("11", "537061726B2053514C")))
+ }
+ }
+
+ test("stddev_samp function with UDAF fallback") {
+ withSQLConf("spark.auron.udafFallback.enable" -> "true") {
+ withTable("t1") {
+ sql("create table t1(c1 double) using parquet")
+ sql("insert into t1 values(10.0), (20.0), (30.0), (31.0), (null)")
+ val df = sql("select stddev_samp(c1) from t1")
+ checkAnswer(df, Seq(Row(9.844626283748239)))
+ }
+ }
+ }
+
+ test("regexp_extract function with UDF failback") {
+ withTable("t1") {
+ sql("create table t1(c1 string) using parquet")
+ sql("insert into t1 values('Auron Spark SQL')")
+ val df = sql("select regexp_extract(c1, '^A(.*)L$', 1) from t1")
+ checkAnswer(df, Seq(Row("uron Spark SQ")))
+ }
+ }
+}
diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronQuerySuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronQuerySuite.scala
new file mode 100644
index 000000000..5b9787d36
--- /dev/null
+++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronQuerySuite.scala
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.Row
+
+import org.apache.auron.util.AuronTestUtils
+
+class AuronQuerySuite
+ extends org.apache.spark.sql.QueryTest
+ with BaseAuronSQLSuite
+ with AuronSQLTestHelper {
+ import testImplicits._
+
+ test("test partition path has url encoded character") {
+ withTable("t1") {
+ sql(
+ "create table t1 using parquet PARTITIONED BY (part) as select 1 as c1, 2 as c2, 'test test' as part")
+ val df = sql("select * from t1")
+ checkAnswer(df, Seq(Row(1, 2, "test test")))
+ }
+ }
+
+ test("empty output in bnlj") {
+ withTable("t1", "t2") {
+ sql("create table t1 using parquet as select 1 as c1, 2 as c2")
+ sql("create table t2 using parquet as select 1 as c1, 3 as c3")
+ val df = sql("select 1 from t1 left join t2")
+ checkAnswer(df, Seq(Row(1)))
+ }
+ }
+
+ test("test filter with year function") {
+ withTable("t1") {
+ sql("create table t1 using parquet as select '2024-12-18' as event_time")
+ checkAnswer(
+ sql("""
+ |select year, count(*)
+ |from (select event_time, year(event_time) as year from t1) t
+ |where year <= 2024
+ |group by year
+ |""".stripMargin),
+ Seq(Row(2024, 1)))
+ }
+ }
+
+ test("test select multiple spark ext functions with the same signature") {
+ withTable("t1") {
+ sql("create table t1 using parquet as select '2024-12-18' as event_time")
+ checkAnswer(sql("select year(event_time), month(event_time) from t1"), Seq(Row(2024, 12)))
+ }
+ }
+
+ test("test parquet/orc format table with complex data type") {
+ def createTableStatement(format: String): String = {
+ s"""create table test_with_complex_type(
+ |id bigint comment 'pk',
+ |m map<string, string> comment 'test read map type',
+ |l array<string> comment 'test read list type',
+ |s string comment 'string type'
+ |) USING $format
+ |""".stripMargin
+ }
+ Seq("parquet", "orc").foreach(format =>
+ withTable("test_with_complex_type") {
+ sql(createTableStatement(format))
+ sql(
+ "insert into test_with_complex_type select 1 as id, map('zero', '0', 'one', '1') as m, array('test','auron') as l, 'auron' as s")
+ checkAnswer(
+ sql("select id,l,m from test_with_complex_type"),
+ Seq(Row(1, ArrayBuffer("test", "auron"), Map("one" -> "1", "zero" -> "0"))))
+ })
+ }
+
+ test("binary type in range partitioning") {
+ withTable("t1", "t2") {
+ sql("create table t1(c1 binary, c2 int) using parquet")
+ sql("insert into t1 values (cast('test1' as binary), 1), (cast('test2' as binary), 2)")
+ val df = sql("select c2 from t1 order by c1")
+ checkAnswer(df, Seq(Row(1), Row(2)))
+ }
+ }
+
+ test("repartition over MapType") {
+ withTable("t_map") {
+ sql("create table t_map using parquet as select map('a', '1', 'b', '2') as data_map")
+ val df = sql("SELECT /*+ repartition(10) */ data_map FROM t_map")
+ checkAnswer(df, Seq(Row(Map("a" -> "1", "b" -> "2"))))
+ }
+ }
+
+ test("repartition over MapType with ArrayType") {
+ withTable("t_map_struct") {
+ sql(
+ "create table t_map_struct using parquet as select named_struct('m', map('x', '1')) as data_struct")
+ val df = sql("SELECT /*+ repartition(10) */ data_struct FROM t_map_struct")
+ checkAnswer(df, Seq(Row(Row(Map("x" -> "1")))))
+ }
+ }
+
+ test("repartition over ArrayType with MapType") {
+ withTable("t_array_map") {
+ sql("""
+ |create table t_array_map using parquet as
+ |select array(map('k1', 1, 'k2', 2), map('k3', 3)) as array_of_map
+ |""".stripMargin)
+ val df = sql("SELECT /*+ repartition(10) */ array_of_map FROM t_array_map")
+ checkAnswer(df, Seq(Row(Seq(Map("k1" -> 1, "k2" -> 2), Map("k3" -> 3)))))
+ }
+ }
+
+ test("repartition over StructType with MapType") {
+ withTable("t_struct_map") {
+ sql("""
+ |create table t_struct_map using parquet as
+ |select named_struct('id', 101, 'metrics', map('ctr', 0.123d, 'cvr', 0.045d)) as user_metrics
+ |""".stripMargin)
+ val df = sql("SELECT /*+ repartition(10) */ user_metrics FROM t_struct_map")
+ checkAnswer(df, Seq(Row(Row(101, Map("ctr" -> 0.123, "cvr" -> 0.045)))))
+ }
+ }
+
+ test("repartition over MapType with StructType") {
+ withTable("t_map_struct_value") {
+ sql("""
+ |create table t_map_struct_value using parquet as
+ |select map(
+ | 'item1', named_struct('count', 3, 'score', 4.5d),
+ | 'item2', named_struct('count', 7, 'score', 9.1d)
+ |) as map_struct_value
+ |""".stripMargin)
+ val df = sql("SELECT /*+ repartition(10) */ map_struct_value FROM t_map_struct_value")
+ checkAnswer(df, Seq(Row(Map("item1" -> Row(3, 4.5), "item2" -> Row(7, 9.1)))))
+ }
+ }
+
+ test("repartition over nested MapType") {
+ withTable("t_nested_map") {
+ sql("""
+ |create table t_nested_map using parquet as
+ |select map(
+ | 'outer1', map('inner1', 10, 'inner2', 20),
+ | 'outer2', map('inner3', 30)
+ |) as nested_map
+ |""".stripMargin)
+ val df = sql("SELECT /*+ repartition(10) */ nested_map FROM t_nested_map")
+ checkAnswer(
+ df,
+ Seq(Row(
+ Map("outer1" -> Map("inner1" -> 10, "inner2" -> 20), "outer2" -> Map("inner3" -> 30)))))
+ }
+ }
+
+ test("repartition over ArrayType of StructType with MapType") {
+ withTable("t_array_struct_map") {
+ sql("""
+ |create table t_array_struct_map using parquet as
+ |select array(
+ | named_struct('name', 'user1', 'features', map('f1', 1.0d, 'f2', 2.0d)),
+ | named_struct('name', 'user2', 'features', map('f3', 3.5d))
+ |) as user_feature_array
+ |""".stripMargin)
+ val df = sql("SELECT /*+ repartition(10) */ user_feature_array FROM t_array_struct_map")
+ checkAnswer(
+ df,
+ Seq(
+ Row(
+ Seq(Row("user1", Map("f1" -> 1.0f, "f2" -> 2.0f)), Row("user2", Map("f3" -> 3.5f))))))
+ }
+ }
+
+ test("log function with negative input") {
+ withTable("t1") {
+ sql("create table t1 using parquet as select -1 as c1")
+ val df = sql("select ln(c1) from t1")
+ checkAnswer(df, Seq(Row(null)))
+ }
+ }
+
+ test("floor function with long input") {
+ withTable("t1") {
+ sql("create table t1 using parquet as select 1L as c1, 2.2 as c2")
+ val df = sql("select floor(c1), floor(c2) from t1")
+ checkAnswer(df, Seq(Row(1, 2)))
+ }
+ }
+
+ test("SPARK-32234 read ORC table with column names all starting with '_col'") {
+ withTable("test_hive_orc_impl") {
+ spark.sql(s"""
+ | CREATE TABLE test_hive_orc_impl
+ | (_col1 INT, _col2 STRING, _col3 INT)
+ | USING ORC
+ """.stripMargin)
+ spark.sql(s"""
+ | INSERT INTO
+ | test_hive_orc_impl
+ | VALUES(9, '12', 2020)
+ """.stripMargin)
+
+ val df = spark.sql("SELECT _col2 FROM test_hive_orc_impl")
+ checkAnswer(df, Row("12"))
+ }
+ }
+
+ test("SPARK-32864: Support ORC forced positional evolution") {
+ if (AuronTestUtils.isSparkV32OrGreater) {
+ Seq(true, false).foreach { forcePositionalEvolution =>
+ withEnvConf(
+ AuronConf.ORC_FORCE_POSITIONAL_EVOLUTION.key -> forcePositionalEvolution.toString) {
+ withTempPath { f =>
+ val path = f.getCanonicalPath
+ Seq[(Integer, Integer)]((1, 2), (3, 4), (5, 6), (null, null))
+ .toDF("c1", "c2")
+ .write
+ .orc(path)
+ val correctAnswer = Seq(Row(1, 2), Row(3, 4), Row(5, 6), Row(null, null))
+ checkAnswer(spark.read.orc(path), correctAnswer)
+
+ withTable("t") {
+ sql(s"CREATE EXTERNAL TABLE t(c3 INT, c2 INT) USING ORC LOCATION '$path'")
+
+ val expected = if (forcePositionalEvolution) {
+ correctAnswer
+ } else {
+ Seq(Row(null, 2), Row(null, 4), Row(null, 6), Row(null, null))
+ }
+
+ checkAnswer(spark.table("t"), expected)
+ }
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-32864: Support ORC forced positional evolution with partitioned table") {
+ if (AuronTestUtils.isSparkV32OrGreater) {
+ Seq(true, false).foreach { forcePositionalEvolution =>
+ withEnvConf(
+ AuronConf.ORC_FORCE_POSITIONAL_EVOLUTION.key -> forcePositionalEvolution.toString) {
+ withTempPath { f =>
+ val path = f.getCanonicalPath
+ Seq[(Integer, Integer, Integer)]((1, 2, 1), (3, 4, 2), (5, 6, 3), (null, null, 4))
+ .toDF("c1", "c2", "p")
+ .write
+ .partitionBy("p")
+ .orc(path)
+ val correctAnswer = Seq(Row(1, 2, 1), Row(3, 4, 2), Row(5, 6, 3), Row(null, null, 4))
+ checkAnswer(spark.read.orc(path), correctAnswer)
+
+ withTable("t") {
+ sql(s"""
+ |CREATE TABLE t(c3 INT, c2 INT)
+ |USING ORC
+ |PARTITIONED BY (p int)
+ |LOCATION '$path'
+ |""".stripMargin)
+ sql("MSCK REPAIR TABLE t")
+ val expected = if (forcePositionalEvolution) {
+ correctAnswer
+ } else {
+ Seq(Row(null, 2, 1), Row(null, 4, 2), Row(null, 6, 3), Row(null, null, 4))
+ }
+
+ checkAnswer(spark.table("t"), expected)
+ }
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronSQLTestHelper.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronSQLTestHelper.scala
new file mode 100644
index 000000000..f37474a06
--- /dev/null
+++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronSQLTestHelper.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron
+
+import org.apache.spark.SparkEnv
+
+trait AuronSQLTestHelper {
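+ /**
+  * Temporarily applies the given entries to the live SparkEnv conf, runs `f`, and then
+  * restores every key to its previous value (or removes it if it was previously unset).
+  */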
+ def withEnvConf(pairs: (String, String)*)(f: => Unit): Unit = {
+ val env = SparkEnv.get
+ if (env == null) {
+ throw new IllegalStateException("SparkEnv is not initialized")
+ }
+ val conf = env.conf
+ val (keys, values) = pairs.unzip
+ val currentValues = keys.map { key =>
+ if (conf.contains(key)) {
+ Some(conf.get(key))
+ } else {
+ None
+ }
+ }
+ keys.zip(values).foreach { case (k, v) =>
+ conf.set(k, v)
+ }
+ try f
+ finally {
+ keys.zip(currentValues).foreach {
+ case (key, Some(value)) => conf.set(key, value)
+ case (key, None) => conf.remove(key)
+ }
+ }
+ }
+}
diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/BaseAuronSQLSuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/BaseAuronSQLSuite.scala
new file mode 100644
index 000000000..6e30c960f
--- /dev/null
+++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/BaseAuronSQLSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+trait BaseAuronSQLSuite extends SharedSparkSession {
+
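+ // Shared session config for the Spark 4.0 suites: register the Auron session extension,
+ // route shuffles through the native AuronShuffleManager, and enable Auron.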
+ override protected def sparkConf: SparkConf = {
+ super.sparkConf
+ .set("spark.sql.extensions", "org.apache.spark.sql.auron.AuronSparkSessionExtension")
+ .set(
+ "spark.shuffle.manager",
+ "org.apache.spark.sql.execution.auron.shuffle.AuronShuffleManager")
+ .set("spark.memory.offHeap.enabled", "false")
+ .set("spark.auron.enable", "true")
+ }
+
+}
diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala
new file mode 100644
index 000000000..d5a1a60c1
--- /dev/null
+++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType}
+
+import org.apache.auron.protobuf.ScalarFunction
+
+class NativeConvertersSuite extends QueryTest with BaseAuronSQLSuite with AuronSQLTestHelper {
+
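+ // Asserts that casting a string literal to `targetType` produces a native TryCast whose
+ // child is a Trim scalar function applied to the literal input.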
+ private def assertTrimmedCast(rawValue: String, targetType: DataType): Unit = {
+ val expr = Cast(Literal.create(rawValue, StringType), targetType)
+ val nativeExpr = NativeConverters.convertExpr(expr)
+
+ assert(nativeExpr.hasTryCast)
+ val childExpr = nativeExpr.getTryCast.getExpr
+ assert(childExpr.hasScalarFunction)
+ val scalarFn = childExpr.getScalarFunction
+ assert(scalarFn.getFun == ScalarFunction.Trim)
+ assert(scalarFn.getArgsCount == 1 && scalarFn.getArgs(0).hasLiteral)
+ }
+
+ private def assertNonTrimmedCast(rawValue: String, targetType: DataType): Unit = {
+ val expr = Cast(Literal.create(rawValue, StringType), targetType)
+ val nativeExpr = NativeConverters.convertExpr(expr)
+
+ assert(nativeExpr.hasTryCast)
+ val childExpr = nativeExpr.getTryCast.getExpr
+ assert(!childExpr.hasScalarFunction)
+ assert(childExpr.hasLiteral)
+ }
+
+ test("cast from string to numeric adds trim wrapper before native cast when enabled") {
+ withSQLConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "true") {
+ assertTrimmedCast(" 42 ", IntegerType)
+ }
+ }
+
+ test("cast from string to boolean adds trim wrapper before native cast when enabled") {
+ withSQLConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "true") {
+ assertTrimmedCast(" true ", BooleanType)
+ }
+ }
+
+ test("cast trim disabled via auron conf") {
+ withEnvConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "false") {
+ assertNonTrimmedCast(" 42 ", IntegerType)
+ }
+ }
+
+ test("cast trim disabled via auron conf for boolean cast") {
+ withEnvConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "false") {
+ assertNonTrimmedCast(" true ", BooleanType)
+ }
+ }
+
+ test("cast with non-string child remains unchanged") {
+ val expr = Cast(Literal(1.5), IntegerType)
+ val nativeExpr = NativeConverters.convertExpr(expr)
+
+ assert(nativeExpr.hasTryCast)
+ val childExpr = nativeExpr.getTryCast.getExpr
+ assert(!childExpr.hasScalarFunction)
+ assert(childExpr.hasLiteral)
+ }
+}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
index 556ac7f31..58a9eccd9 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
@@ -413,13 +413,13 @@ object AuronConverters extends Logging {
getShuffleOrigin(exec))
}
- @sparkver(" 3.2 / 3.3 / 3.4 / 3.5")
+ @sparkver(" 3.2 / 3.3 / 3.4 / 3.5 / 4.0")
def getIsSkewJoinFromSHJ(exec: ShuffledHashJoinExec): Boolean = exec.isSkewJoin
@sparkver("3.0 / 3.1")
def getIsSkewJoinFromSHJ(exec: ShuffledHashJoinExec): Boolean = false
- @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5 / 4.0")
def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = Some(exec.shuffleOrigin)
@sparkver("3.0")
@@ -1094,7 +1094,7 @@ object AuronConverters extends Logging {
rddPartitioner = None,
rddDependencies = Nil,
false,
- (_partition, _taskContext) => {
+ (_, _) => {
val nativeEmptyExec = EmptyPartitionsExecNode
.newBuilder()
.setNumPartitions(outputPartitioning.numPartitions)
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala
index 988875adc..67a813242 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala
@@ -34,6 +34,8 @@ import org.apache.spark.sql.types.TimestampType
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.auron.sparkver
+
class AuronColumnarArray(data: AuronColumnVector, offset: Int, length: Int) extends ArrayData {
override def numElements: Int = length
@@ -154,4 +156,9 @@ class AuronColumnarArray(data: AuronColumnVector, offset: Int, length: Int) exte
override def setNullAt(ordinal: Int): Unit = {
throw new UnsupportedOperationException
}
+
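+ // Spark 4.0 adds a VariantVal accessor to the row/array interfaces; Auron's columnar
+ // vectors do not support variant values, so this override always throws.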
+ @sparkver("4.0")
+ override def getVariant(i: Int): org.apache.spark.unsafe.types.VariantVal = {
+ throw new UnsupportedOperationException
+ }
}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala
index 6b24e0f55..65ee91610 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala
@@ -37,6 +37,8 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.auron.sparkver
+
class AuronColumnarBatchRow(columns: Array[AuronColumnVector], var rowId: Int = 0)
extends InternalRow {
override def numFields: Int = columns.length
@@ -133,4 +135,9 @@ class AuronColumnarBatchRow(columns: Array[AuronColumnVector], var rowId: Int =
override def setNullAt(ordinal: Int): Unit = {
throw new UnsupportedOperationException
}
+
+ @sparkver("4.0")
+ override def getVariant(i: Int): org.apache.spark.unsafe.types.VariantVal = {
+ throw new UnsupportedOperationException
+ }
}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala
index 34e7a717d..cfb5debeb 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala
@@ -37,6 +37,8 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.auron.sparkver
+
class AuronColumnarStruct(data: AuronColumnVector, rowId: Int) extends InternalRow {
override def numFields: Int = data.dataType.asInstanceOf[StructType].size
@@ -143,4 +145,9 @@ class AuronColumnarStruct(data: AuronColumnVector, rowId: Int) extends InternalR
override def setNullAt(ordinal: Int): Unit = {
throw new UnsupportedOperationException
}
+
+ @sparkver("4.0")
+ override def getVariant(i: Int): org.apache.spark.unsafe.types.VariantVal = {
+ throw new UnsupportedOperationException
+ }
}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala
index 7dafd564a..408aa68e5 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala
@@ -65,7 +65,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.BinaryType
-import org.apache.auron.{protobuf => pb}
+import org.apache.auron.{protobuf => pb, sparkver}
abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val child: SparkPlan)
extends BroadcastExchangeLike
@@ -137,9 +137,21 @@ abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi
dummyBroadcasted.asInstanceOf[Broadcast[T]]
}
- def doExecuteBroadcastNative[T](): broadcast.Broadcast[T] = {
+ @sparkver("3.0 / 3.1 / 3.2 / 3.3 / 3.4 / 3.5")
+ def getBroadcastTimeout: Long = {
val conf = SparkSession.getActiveSession.map(_.sqlContext.conf).orNull
- val timeout: Long = conf.broadcastTimeout
+ conf.broadcastTimeout
+ }
+
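+ // For Spark 4.0 the broadcast timeout is read from the session's runtime conf instead of
+ // SQLContext.conf, falling back to 300 seconds when no session is active.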
+ @sparkver("4.0")
+ def getBroadcastTimeout: Long = {
+ SparkSession.getActiveSession
+ .map(_.conf.get("spark.sql.broadcastTimeout").toLong)
+ .getOrElse(300L)
+ }
+
+ def doExecuteBroadcastNative[T](): broadcast.Broadcast[T] = {
+ val timeout: Long = getBroadcastTimeout
try {
relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
} catch {
@@ -258,7 +270,31 @@ abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi
lazy val relationFuturePromise: Promise[Broadcast[Any]] = Promise[Broadcast[Any]]()
@transient
- lazy val relationFuture: Future[Broadcast[Any]] = {
+ lazy val relationFuture: Future[Broadcast[Any]] = getRelationFuture
+
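+ // Spark 4.0 variant: capture the SQL thread-locals from `this.session`, broadcast the
+ // natively collected result, and complete `relationFuturePromise` with the broadcast
+ // (or with the failure).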
+ @sparkver("4.0")
+ private def getRelationFuture = {
+ SQLExecution.withThreadLocalCaptured[Broadcast[Any]](
+ this.session.sqlContext.sparkSession,
+ BroadcastExchangeExec.executionContext) {
+ try {
+ sparkContext.setJobGroup(
+ getRunId.toString,
+ s"native broadcast exchange (runId $getRunId)",
+ interruptOnCancel = true)
+ val broadcasted = sparkContext.broadcast(collectNative().asInstanceOf[Any])
+ relationFuturePromise.trySuccess(broadcasted)
+ broadcasted
+ } catch {
+ case e: Throwable =>
+ relationFuturePromise.tryFailure(e)
+ throw e
+ }
+ }
+ }
+
+ @sparkver("3.0 / 3.1 / 3.2 / 3.3 / 3.4 / 3.5")
+ private def getRelationFuture = {
SQLExecution.withThreadLocalCaptured[Broadcast[Any]](
Shims.get.getSqlContext(this).sparkSession,
BroadcastExchangeExec.executionContext) {
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala
index 19ba9bdfb..8717d0bea 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala
@@ -77,7 +77,7 @@ abstract class AuronRssShuffleWriterBase[K, V](metrics: ShuffleWriteMetricsRepor
def rssStop(success: Boolean): Option[MapStatus]
- @sparkver("3.2 / 3.3 / 3.4 / 3.5")
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0")
override def getPartitionLengths(): Array[Long] = rpw.getPartitionLengthMap
override def write(records: Iterator[Product2[K, V]]): Unit = {