From 102daec51e82ce42dac2a018947299cd88d1fee6 Mon Sep 17 00:00:00 2001
From: slfan1989
Date: Sun, 5 Oct 2025 17:16:35 +0800
Subject: [PATCH 1/5] [AURON#1404] Support for Spark 4.0.1 Compatibility in Auron.

---
 pom.xml | 17 +-
 spark-extension-shims-spark4/pom.xml | 98 +++
 .../ForceApplyShuffledHashJoinInjector.java | 43 +
 ...ForceApplyShuffledHashJoinInterceptor.java | 32 +
 .../ValidateSparkPlanApplyInterceptor.java | 34 +
 .../sql/auron/ValidateSparkPlanInjector.java | 59 ++
 .../auron/InterceptedValidateSparkPlan.scala | 79 ++
 .../apache/spark/sql/auron/ShimsImpl.scala | 733 ++++++++++++++++++
 .../auron/plan/ConvertToNativeExec.scala | 28 +
 .../execution/auron/plan/NativeAggExec.scala | 78 ++
 .../plan/NativeBroadcastExchangeExec.scala | 49 ++
 .../auron/plan/NativeExpandExec.scala | 34 +
 .../auron/plan/NativeFilterExec.scala | 30 +
 .../auron/plan/NativeGenerateExec.scala | 36 +
 .../auron/plan/NativeGlobalLimitExec.scala | 29 +
 .../auron/plan/NativeLocalLimitExec.scala | 29 +
 .../auron/plan/NativeOrcScanExec.scala | 26 +
 ...NativeParquetInsertIntoHiveTableExec.scala | 103 +++
 .../auron/plan/NativeParquetScanExec.scala | 26 +
 .../auron/plan/NativeParquetSinkExec.scala | 37 +
 .../plan/NativePartialTakeOrderedExec.scala | 35 +
 .../plan/NativeProjectExecProvider.scala | 46 ++
 .../NativeRenameColumnsExecProvider.scala | 47 ++
 .../plan/NativeShuffleExchangeExec.scala | 197 +++++
 .../execution/auron/plan/NativeSortExec.scala | 37 +
 .../auron/plan/NativeTakeOrderedExec.scala | 37 +
 .../auron/plan/NativeUnionExec.scala | 36 +
 .../auron/plan/NativeWindowExec.scala | 41 +
 .../AuronBlockStoreShuffleReader.scala | 89 +++
 .../shuffle/AuronRssShuffleManagerBase.scala | 120 +++
 .../auron/shuffle/AuronShuffleManager.scala | 134 ++++
 .../auron/shuffle/AuronShuffleWriter.scala | 28 +
 .../auron/plan/NativeBroadcastJoinExec.scala | 96 +++
 .../NativeShuffledHashJoinExecProvider.scala | 82 ++
 .../NativeSortMergeJoinExecProvider.scala | 75 ++
 .../auron/AuronAdaptiveQueryExecSuite.scala | 136 ++++
 .../spark/sql/auron/AuronFunctionSuite.scala | 123 +++
 .../spark/sql/auron/AuronQuerySuite.scala | 289 +++++++
 .../spark/sql/auron/AuronSQLTestHelper.scala | 47 ++
 .../spark/sql/auron/BaseAuronSQLSuite.scala | 34 +
 .../sql/auron/NativeConvertersSuite.scala | 83 ++
 .../spark/sql/auron/AuronConverters.scala | 6 +-
 .../org/apache/spark/sql/auron/Shims.scala | 3 +-
 .../auron/columnar/AuronColumnarArray.scala | 7 +-
 .../columnar/AuronColumnarBatchRow.scala | 7 +-
 .../auron/columnar/AuronColumnarStruct.scala | 7 +-
 .../plan/NativeBroadcastExchangeBase.scala | 8 +-
 47 files changed, 3433 insertions(+), 17 deletions(-)
 create mode 100644 spark-extension-shims-spark4/pom.xml
 create mode 100644 spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInjector.java
 create mode 100644 spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInterceptor.java
 create mode 100644 spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanApplyInterceptor.java
 create mode 100644 spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanInjector.java
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/InterceptedValidateSparkPlan.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeExpandExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeFilterExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGenerateExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGlobalLimitExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeLocalLimitExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcScanExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetInsertIntoHiveTableExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetScanExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetSinkExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativePartialTakeOrderedExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeProjectExecProvider.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeRenameColumnsExecProvider.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronBlockStoreShuffleReader.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleManagerBase.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleManager.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleWriter.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala
 create mode 100644 spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeSortMergeJoinExecProvider.scala
 create mode 100644
spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala create mode 100644 spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala create mode 100644 spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronQuerySuite.scala create mode 100644 spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronSQLTestHelper.scala create mode 100644 spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/BaseAuronSQLSuite.scala create mode 100644 spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala diff --git a/pom.xml b/pom.xml index e8e8333d9..ca89d262a 100644 --- a/pom.xml +++ b/pom.xml @@ -68,7 +68,7 @@ 3.0.0 2.1.1 - 4.8.1 + 4.9.9 -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED @@ -379,7 +379,7 @@ com.diffplug.spotless spotless-maven-plugin - 2.30.0 + 2.45.0 @@ -576,6 +576,17 @@ + + spark-4.0 + + spark-4.0 + spark-extension-shims-spark4 + 3.2.10 + 3.9.9 + 4.0.1 + + + jdk-8 @@ -648,7 +659,7 @@ 2.13 - 2.13.8 + 2.13.13 diff --git a/spark-extension-shims-spark4/pom.xml b/spark-extension-shims-spark4/pom.xml new file mode 100644 index 000000000..2ab23fac8 --- /dev/null +++ b/spark-extension-shims-spark4/pom.xml @@ -0,0 +1,98 @@ + + + 4.0.0 + + org.apache.auron + auron-parent_${scalaVersion} + ${project.version} + ../pom.xml + + + spark-extension-shims-spark4_${scalaVersion} + jar + + http://maven.apache.org + + + UTF-8 + + + + + org.apache.auron + auron-common_${scalaVersion} + ${project.version} + + + org.apache.auron + spark-extension_${scalaVersion} + ${project.version} + + + org.scala-lang.modules + scala-java8-compat_${scalaVersion} + + + org.apache.spark + spark-core_${scalaVersion} + provided + + + org.apache.spark + spark-hive_${scalaVersion} + provided + + + org.apache.spark + spark-sql_${scalaVersion} + provided + + + org.apache.arrow + arrow-c-data + + + org.apache.arrow + arrow-compression + + + org.apache.arrow + arrow-memory-unsafe + + + org.apache.arrow + arrow-vector + + + + net.bytebuddy + byte-buddy + + + net.bytebuddy + byte-buddy-agent + + + + org.apache.auron + auron-common_${scalaVersion} + ${project.version} + test-jar + + + org.apache.spark + spark-core_${scalaVersion} + test-jar + + + org.apache.spark + spark-catalyst_${scalaVersion} + test-jar + + + org.apache.spark + spark-sql_${scalaVersion} + test-jar + + + diff --git a/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInjector.java b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInjector.java new file mode 100644 index 000000000..00352cc41 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInjector.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.auron; + +import static net.bytebuddy.matcher.ElementMatchers.named; + +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.agent.ByteBuddyAgent; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.dynamic.ClassFileLocator; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.implementation.MethodDelegation; +import net.bytebuddy.pool.TypePool; + +public class ForceApplyShuffledHashJoinInjector { + public static void inject() { + ByteBuddyAgent.install(); + ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); + TypeDescription typeDescription = TypePool.Default.of(contextClassLoader) + .describe("org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper") + .resolve(); + new ByteBuddy() + .redefine(typeDescription, ClassFileLocator.ForClassLoader.of(contextClassLoader)) + .method(named("forceApplyShuffledHashJoin")) + .intercept(MethodDelegation.to(ForceApplyShuffledHashJoinInterceptor.class)) + .make() + .load(contextClassLoader, ClassLoadingStrategy.Default.INJECTION); + } +} diff --git a/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInterceptor.java b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInterceptor.java new file mode 100644 index 000000000..8d7d0dd52 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ForceApplyShuffledHashJoinInterceptor.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.auron; + +import net.bytebuddy.implementation.bind.annotation.Argument; +import net.bytebuddy.implementation.bind.annotation.RuntimeType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ForceApplyShuffledHashJoinInterceptor { + private static final Logger logger = LoggerFactory.getLogger(ForceApplyShuffledHashJoinInterceptor.class); + + @RuntimeType + public static Object intercept(@Argument(0) Object conf) { + logger.debug("calling JoinSelectionHelper.forceApplyShuffledHashJoin() intercepted by auron"); + return AuronConf.FORCE_SHUFFLED_HASH_JOIN.booleanConf(); + } +} diff --git a/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanApplyInterceptor.java b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanApplyInterceptor.java new file mode 100644 index 000000000..ea6a6603f --- /dev/null +++ b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanApplyInterceptor.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.auron; + +import net.bytebuddy.implementation.bind.annotation.Argument; +import net.bytebuddy.implementation.bind.annotation.RuntimeType; +import org.apache.spark.sql.execution.SparkPlan; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ValidateSparkPlanApplyInterceptor { + private static final Logger logger = LoggerFactory.getLogger(ValidateSparkPlanApplyInterceptor.class); + + @RuntimeType + public static Object intercept(@Argument(0) Object plan) { + logger.debug("calling ValidateSparkPlan.apply() intercepted by auron"); + InterceptedValidateSparkPlan$.MODULE$.validate((SparkPlan) plan); + return plan; + } +} diff --git a/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanInjector.java b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanInjector.java new file mode 100644 index 000000000..6685b1d7d --- /dev/null +++ b/spark-extension-shims-spark4/src/main/java/org/apache/spark/sql/auron/ValidateSparkPlanInjector.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.auron; + +import static net.bytebuddy.matcher.ElementMatchers.named; + +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.agent.ByteBuddyAgent; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.dynamic.ClassFileLocator; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.implementation.MethodDelegation; +import net.bytebuddy.pool.TypePool; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ValidateSparkPlanInjector { + + private static final Logger logger = LoggerFactory.getLogger(ValidateSparkPlanInjector.class); + private static boolean injected = false; + + public static synchronized void inject() { + if (injected) { + logger.warn("ValidateSparkPlan already injected, skipping."); + return; + } + try { + ByteBuddyAgent.install(); + ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); + TypeDescription typeDescription = TypePool.Default.of(contextClassLoader) + .describe("org.apache.spark.sql.execution.adaptive.ValidateSparkPlan$") + .resolve(); + new ByteBuddy() + .redefine(typeDescription, ClassFileLocator.ForClassLoader.of(contextClassLoader)) + .method(named("apply")) + .intercept(MethodDelegation.to(ValidateSparkPlanApplyInterceptor.class)) + .make() + .load(contextClassLoader, ClassLoadingStrategy.Default.INJECTION); + logger.info("Successfully injected ValidateSparkPlan."); + injected = true; + } catch (TypePool.Resolution.NoSuchTypeException e) { + logger.debug("No such type of ValidateSparkPlan", e); + } + } +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/InterceptedValidateSparkPlan.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/InterceptedValidateSparkPlan.scala new file mode 100644 index 000000000..61480bd3d --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/InterceptedValidateSparkPlan.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.auron + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +object InterceptedValidateSparkPlan extends Logging { + + @sparkver("4.0") + def validate(plan: SparkPlan): Unit = { + import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec + import org.apache.spark.sql.execution.auron.plan.NativeRenameColumnsBase + import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec + import org.apache.spark.sql.execution.joins.auron.plan.NativeBroadcastJoinExec + import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec + import org.apache.spark.sql.catalyst.optimizer.BuildLeft + import org.apache.spark.sql.catalyst.optimizer.BuildRight + + plan match { + case b: BroadcastHashJoinExec => + val (buildPlan, probePlan) = b.buildSide match { + case BuildLeft => (b.left, b.right) + case BuildRight => (b.right, b.left) + } + if (!buildPlan.isInstanceOf[BroadcastQueryStageExec]) { + validate(buildPlan) + } + validate(probePlan) + + case b: NativeBroadcastJoinExec => // same as non-native BHJ + var (buildPlan, probePlan) = b.buildSide match { + case BuildLeft => (b.left, b.right) + case BuildRight => (b.right, b.left) + } + if (buildPlan.isInstanceOf[NativeRenameColumnsBase]) { + buildPlan = buildPlan.children.head + } + if (!buildPlan.isInstanceOf[BroadcastQueryStageExec]) { + validate(buildPlan) + } + validate(probePlan) + + case b: BroadcastNestedLoopJoinExec => + val (buildPlan, probePlan) = b.buildSide match { + case BuildLeft => (b.left, b.right) + case BuildRight => (b.right, b.left) + } + if (!buildPlan.isInstanceOf[BroadcastQueryStageExec]) { + validate(buildPlan) + } + validate(probePlan) + case q: BroadcastQueryStageExec => errorOnInvalidBroadcastQueryStage(q) + case _ => plan.children.foreach(validate) + } + } + + @sparkver("4.0") + private def errorOnInvalidBroadcastQueryStage(plan: SparkPlan): Unit = { + import org.apache.spark.sql.execution.adaptive.InvalidAQEPlanException + throw InvalidAQEPlanException("Invalid broadcast query stage", plan) + } +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala new file mode 100644 index 000000000..5f65c5533 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala @@ -0,0 +1,733 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.auron + +import java.io.File +import java.util.UUID + +import org.apache.commons.lang3.reflect.FieldUtils +import org.apache.spark.OneToOneDependency +import org.apache.spark.ShuffleDependency +import org.apache.spark.SparkEnv +import org.apache.spark.SparkException +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.ShuffleHandle +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.auron.AuronConverters.ForceNativeExecutionWrapperBase +import org.apache.spark.sql.auron.NativeConverters.NativeExprWrapperBase +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Generator +import org.apache.spark.sql.catalyst.expressions.Like +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.expressions.StringSplit +import org.apache.spark.sql.catalyst.expressions.TaggingExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction +import org.apache.spark.sql.catalyst.expressions.aggregate.First +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.CoalescedPartitionSpec +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.PartialMapperPartitionSpec +import org.apache.spark.sql.execution.PartialReducerPartitionSpec +import org.apache.spark.sql.execution.ShuffledRowRDD +import org.apache.spark.sql.execution.ShufflePartitionSpec +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.auron.plan._ +import org.apache.spark.sql.execution.auron.plan.ConvertToNativeExec +import org.apache.spark.sql.execution.auron.plan.NativeAggBase +import org.apache.spark.sql.execution.auron.plan.NativeAggBase.AggExecMode +import org.apache.spark.sql.execution.auron.plan.NativeAggExec +import org.apache.spark.sql.execution.auron.plan.NativeBroadcastExchangeBase +import org.apache.spark.sql.execution.auron.plan.NativeBroadcastExchangeExec +import org.apache.spark.sql.execution.auron.plan.NativeExpandBase +import org.apache.spark.sql.execution.auron.plan.NativeExpandExec +import org.apache.spark.sql.execution.auron.plan.NativeFilterBase +import org.apache.spark.sql.execution.auron.plan.NativeFilterExec +import org.apache.spark.sql.execution.auron.plan.NativeGenerateBase +import org.apache.spark.sql.execution.auron.plan.NativeGenerateExec +import org.apache.spark.sql.execution.auron.plan.NativeGlobalLimitBase +import org.apache.spark.sql.execution.auron.plan.NativeGlobalLimitExec +import 
org.apache.spark.sql.execution.auron.plan.NativeLocalLimitBase +import org.apache.spark.sql.execution.auron.plan.NativeLocalLimitExec +import org.apache.spark.sql.execution.auron.plan.NativeOrcScanExec +import org.apache.spark.sql.execution.auron.plan.NativeParquetInsertIntoHiveTableBase +import org.apache.spark.sql.execution.auron.plan.NativeParquetInsertIntoHiveTableExec +import org.apache.spark.sql.execution.auron.plan.NativeParquetScanBase +import org.apache.spark.sql.execution.auron.plan.NativeParquetScanExec +import org.apache.spark.sql.execution.auron.plan.NativeProjectBase +import org.apache.spark.sql.execution.auron.plan.NativeRenameColumnsBase +import org.apache.spark.sql.execution.auron.plan.NativeShuffleExchangeBase +import org.apache.spark.sql.execution.auron.plan.NativeShuffleExchangeExec +import org.apache.spark.sql.execution.auron.plan.NativeSortBase +import org.apache.spark.sql.execution.auron.plan.NativeSortExec +import org.apache.spark.sql.execution.auron.plan.NativeTakeOrderedBase +import org.apache.spark.sql.execution.auron.plan.NativeTakeOrderedExec +import org.apache.spark.sql.execution.auron.plan.NativeUnionBase +import org.apache.spark.sql.execution.auron.plan.NativeUnionExec +import org.apache.spark.sql.execution.auron.plan.NativeWindowBase +import org.apache.spark.sql.execution.auron.plan.NativeWindowExec +import org.apache.spark.sql.execution.auron.shuffle.{AuronBlockStoreShuffleReaderBase, AuronRssShuffleManagerBase, RssPartitionWriterBase} +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec} +import org.apache.spark.sql.execution.joins.auron.plan.NativeBroadcastJoinExec +import org.apache.spark.sql.execution.joins.auron.plan.NativeShuffledHashJoinExecProvider +import org.apache.spark.sql.execution.joins.auron.plan.NativeSortMergeJoinExecProvider +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StringType +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.FileSegment + +import org.apache.auron.{protobuf => pb, sparkver} + +class ShimsImpl extends Shims with Logging { + + @sparkver("4.0") + override def shimVersion: String = "spark-4.0" + + @sparkver("4.0") + override def initExtension(): Unit = { + ValidateSparkPlanInjector.inject() + + if (AuronConf.FORCE_SHUFFLED_HASH_JOIN.booleanConf()) { + ForceApplyShuffledHashJoinInjector.inject() + } + + // disable MultiCommutativeOp suggested in spark3.4+ + if (shimVersion >= "spark-3.4") { + val confName = "spark.sql.analyzer.canonicalization.multiCommutativeOpMemoryOptThreshold" + SparkEnv.get.conf.set(confName, Int.MaxValue.toString) + } + } + + override def createConvertToNativeExec(child: SparkPlan): ConvertToNativeBase = + ConvertToNativeExec(child) + + override def createNativeAggExec( + execMode: AggExecMode, + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + child: SparkPlan): NativeAggBase = + NativeAggExec( + execMode, + requiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + child) + + 
override def createNativeBroadcastExchangeExec( + mode: BroadcastMode, + child: SparkPlan): NativeBroadcastExchangeBase = + NativeBroadcastExchangeExec(mode, child) + + override def createNativeBroadcastJoinExec( + left: SparkPlan, + right: SparkPlan, + outputPartitioning: Partitioning, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + broadcastSide: BroadcastSide): NativeBroadcastJoinBase = + NativeBroadcastJoinExec( + left, + right, + outputPartitioning, + leftKeys, + rightKeys, + joinType, + broadcastSide) + + override def createNativeSortMergeJoinExec( + left: SparkPlan, + right: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + isSkewJoin: Boolean): NativeSortMergeJoinBase = + NativeSortMergeJoinExecProvider.provide( + left, + right, + leftKeys, + rightKeys, + joinType, + isSkewJoin) + + override def createNativeShuffledHashJoinExec( + left: SparkPlan, + right: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + isSkewJoin: Boolean): SparkPlan = + NativeShuffledHashJoinExecProvider.provide( + left, + right, + leftKeys, + rightKeys, + joinType, + buildSide, + isSkewJoin) + + override def createNativeExpandExec( + projections: Seq[Seq[Expression]], + output: Seq[Attribute], + child: SparkPlan): NativeExpandBase = + NativeExpandExec(projections, output, child) + + override def createNativeFilterExec(condition: Expression, child: SparkPlan): NativeFilterBase = + NativeFilterExec(condition, child) + + override def createNativeGenerateExec( + generator: Generator, + requiredChildOutput: Seq[Attribute], + outer: Boolean, + generatorOutput: Seq[Attribute], + child: SparkPlan): NativeGenerateBase = + NativeGenerateExec(generator, requiredChildOutput, outer, generatorOutput, child) + + override def createNativeGlobalLimitExec(limit: Long, child: SparkPlan): NativeGlobalLimitBase = + NativeGlobalLimitExec(limit, child) + + override def createNativeLocalLimitExec(limit: Long, child: SparkPlan): NativeLocalLimitBase = + NativeLocalLimitExec(limit, child) + + override def createNativeParquetInsertIntoHiveTableExec( + cmd: InsertIntoHiveTable, + child: SparkPlan): NativeParquetInsertIntoHiveTableBase = + NativeParquetInsertIntoHiveTableExec(cmd, child) + + override def createNativeParquetScanExec( + basedFileScan: FileSourceScanExec): NativeParquetScanBase = + NativeParquetScanExec(basedFileScan) + + override def createNativeOrcScanExec(basedFileScan: FileSourceScanExec): NativeOrcScanBase = + NativeOrcScanExec(basedFileScan) + + override def createNativeProjectExec( + projectList: Seq[NamedExpression], + child: SparkPlan): NativeProjectBase = + NativeProjectExecProvider.provide(projectList, child) + + override def createNativeRenameColumnsExec( + child: SparkPlan, + newColumnNames: Seq[String]): NativeRenameColumnsBase = + NativeRenameColumnsExecProvider.provide(child, newColumnNames) + + override def createNativeShuffleExchangeExec( + outputPartitioning: Partitioning, + child: SparkPlan, + shuffleOrigin: Option[Any] = None): NativeShuffleExchangeBase = + NativeShuffleExchangeExec(outputPartitioning, child, shuffleOrigin) + + override def createNativeSortExec( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan): NativeSortBase = + NativeSortExec(sortOrder, global, child) + + override def createNativeTakeOrderedExec( + limit: Long, + sortOrder: Seq[SortOrder], + child: SparkPlan): NativeTakeOrderedBase = + NativeTakeOrderedExec(limit, 
sortOrder, child) + + override def createNativePartialTakeOrderedExec( + limit: Long, + sortOrder: Seq[SortOrder], + child: SparkPlan, + metrics: Map[String, SQLMetric]): NativePartialTakeOrderedBase = + NativePartialTakeOrderedExec(limit, sortOrder, child, metrics) + + override def createNativeUnionExec( + children: Seq[SparkPlan], + output: Seq[Attribute]): NativeUnionBase = + NativeUnionExec(children, output) + + override def createNativeWindowExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + groupLimit: Option[Int], + outputWindowCols: Boolean, + child: SparkPlan): NativeWindowBase = + NativeWindowExec(windowExpression, partitionSpec, orderSpec, groupLimit, child) + + override def createNativeParquetSinkExec( + sparkSession: SparkSession, + table: CatalogTable, + partition: Map[String, Option[String]], + child: SparkPlan, + metrics: Map[String, SQLMetric]): NativeParquetSinkBase = + NativeParquetSinkExec(sparkSession, table, partition, child, metrics) + + override def getUnderlyingBroadcast(plan: SparkPlan): BroadcastExchangeLike = { + plan match { + case exec: BroadcastExchangeLike => exec + case exec: UnaryExecNode => getUnderlyingBroadcast(exec.child) + case exec: BroadcastQueryStageExec => getUnderlyingBroadcast(exec.broadcast) + case exec: ReusedExchangeExec => getUnderlyingBroadcast(exec.child) + } + } + + override def isNative(plan: SparkPlan): Boolean = + plan match { + case _: NativeSupports => true + case plan if isAQEShuffleRead(plan) => isNative(plan.children.head) + case plan: QueryStageExec => isNative(plan.plan) + case plan: ReusedExchangeExec => isNative(plan.child) + case _ => false + } + + override def getUnderlyingNativePlan(plan: SparkPlan): NativeSupports = { + plan match { + case plan: NativeSupports => plan + case plan if isAQEShuffleRead(plan) => getUnderlyingNativePlan(plan.children.head) + case plan: QueryStageExec => getUnderlyingNativePlan(plan.plan) + case plan: ReusedExchangeExec => getUnderlyingNativePlan(plan.child) + case _ => throw new RuntimeException("unreachable: plan is not native") + } + } + + override def executeNative(plan: SparkPlan): NativeRDD = { + plan match { + case plan: NativeSupports => plan.executeNative() + case plan if isAQEShuffleRead(plan) => executeNativeAQEShuffleReader(plan) + case plan: QueryStageExec => executeNative(plan.plan) + case plan: ReusedExchangeExec => executeNative(plan.child) + case _ => + throw new SparkException(s"Underlying plan is not NativeSupports: ${plan}") + } + } + + override def isQueryStageInput(plan: SparkPlan): Boolean = { + plan.isInstanceOf[QueryStageExec] + } + + override def isShuffleQueryStageInput(plan: SparkPlan): Boolean = { + plan.isInstanceOf[ShuffleQueryStageExec] + } + + override def getChildStage(plan: SparkPlan): SparkPlan = + plan.asInstanceOf[QueryStageExec].plan + + override def simpleStringWithNodeId(plan: SparkPlan): String = plan.simpleStringWithNodeId() + + override def setLogicalLink(exec: SparkPlan, basedExec: SparkPlan): SparkPlan = { + basedExec.logicalLink.foreach(logicalLink => exec.setLogicalLink(logicalLink)) + exec + } + + override def getRDDShuffleReadFull(rdd: RDD[_]): Boolean = true + + override def setRDDShuffleReadFull(rdd: RDD[_], shuffleReadFull: Boolean): Unit = {} + + override def createFileSegment( + file: File, + offset: Long, + length: Long, + numRecords: Long): FileSegment = new FileSegment(file, offset, length) + + @sparkver("4.0") + override def commit( + dep: ShuffleDependency[_, _, _], + 
shuffleBlockResolver: IndexShuffleBlockResolver, + tempDataFile: File, + mapId: Long, + partitionLengths: Array[Long], + dataSize: Long, + context: TaskContext): MapStatus = { + + val checksums = Array[Long]() + shuffleBlockResolver.writeMetadataFileAndCommit( + dep.shuffleId, + mapId, + partitionLengths, + checksums, + tempDataFile) + MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, partitionLengths, mapId) + } + + override def getRssPartitionWriter( + handle: ShuffleHandle, + mapId: Int, + metrics: ShuffleWriteMetricsReporter, + numPartitions: Int): Option[RssPartitionWriterBase] = None + + override def getMapStatus( + shuffleServerId: BlockManagerId, + partitionLengthMap: Array[Long], + mapId: Long): MapStatus = + MapStatus.apply(shuffleServerId, partitionLengthMap, mapId) + + override def getShuffleWriteExec( + input: pb.PhysicalPlanNode, + nativeOutputPartitioning: pb.PhysicalRepartition.Builder): pb.PhysicalPlanNode = { + + if (SparkEnv.get.shuffleManager.isInstanceOf[AuronRssShuffleManagerBase]) { + return pb.PhysicalPlanNode + .newBuilder() + .setRssShuffleWriter( + pb.RssShuffleWriterExecNode + .newBuilder() + .setInput(input) + .setOutputPartitioning(nativeOutputPartitioning) + .buildPartial() + ) // shuffleId is not set at the moment, will be set in ShuffleWriteProcessor + .build() + } + + pb.PhysicalPlanNode + .newBuilder() + .setShuffleWriter( + pb.ShuffleWriterExecNode + .newBuilder() + .setInput(input) + .setOutputPartitioning(nativeOutputPartitioning) + .buildPartial() + ) // shuffleId is not set at the moment, will be set in ShuffleWriteProcessor + .build() + } + + override def convertMoreExprWithFallback( + e: Expression, + isPruningExpr: Boolean, + fallback: Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = { + e match { + case StringSplit(str, pat @ Literal(_, StringType), Literal(-1, IntegerType)) + // native StringSplit implementation does not support regex, so only most frequently + // used cases without regex are supported + if Seq(",", ", ", ":", ";", "#", "@", "_", "-", "\\|", "\\.").contains(pat.value) => + val nativePat = pat.value match { + case "\\|" => "|" + case "\\." => "." 
+ case other => other + } + Some( + pb.PhysicalExprNode + .newBuilder() + .setScalarFunction( + pb.PhysicalScalarFunctionNode + .newBuilder() + .setFun(pb.ScalarFunction.SparkExtFunctions) + .setName("StringSplit") + .addArgs(NativeConverters.convertExprWithFallback(str, isPruningExpr, fallback)) + .addArgs(NativeConverters + .convertExprWithFallback(Literal(nativePat), isPruningExpr, fallback)) + .setReturnType(NativeConverters.convertDataType(StringType))) + .build()) + + case e: TaggingExpression => + Some(NativeConverters.convertExprWithFallback(e.child, isPruningExpr, fallback)) + case e => + convertPromotePrecision(e, isPruningExpr, fallback) match { + case Some(v) => return Some(v) + case None => + } + convertBloomFilterMightContain(e, isPruningExpr, fallback) match { + case Some(v) => return Some(v) + case None => + } + None + } + } + + override def getLikeEscapeChar(expr: Expression): Char = { + expr.asInstanceOf[Like].escapeChar + } + + override def convertMoreAggregateExpr(e: AggregateExpression): Option[pb.PhysicalExprNode] = { + assert(getAggregateExpressionFilter(e).isEmpty) + + e.aggregateFunction match { + case First(child, ignoresNull) => + val aggExpr = pb.PhysicalAggExprNode + .newBuilder() + .setReturnType(NativeConverters.convertDataType(e.dataType)) + .setAggFunction(if (ignoresNull) { + pb.AggFunction.FIRST_IGNORES_NULL + } else { + pb.AggFunction.FIRST + }) + .addChildren(NativeConverters.convertExpr(child)) + Some(pb.PhysicalExprNode.newBuilder().setAggExpr(aggExpr).build()) + + case agg => + convertBloomFilterAgg(agg) match { + case Some(aggExpr) => + return Some( + pb.PhysicalExprNode + .newBuilder() + .setAggExpr(aggExpr) + .build()) + case None => + } + None + } + } + + override def getAggregateExpressionFilter(expr: Expression): Option[Expression] = { + expr.asInstanceOf[AggregateExpression].filter + } + + @sparkver("4.0") + private def isAQEShuffleRead(exec: SparkPlan): Boolean = { + import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec + exec.isInstanceOf[AQEShuffleReadExec] + } + + @sparkver("4.0") + private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = { + import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec + import org.apache.spark.sql.execution.CoalescedMapperPartitionSpec + + exec match { + case AQEShuffleReadExec(child, _) if isNative(child) => + val shuffledRDD = exec.execute().asInstanceOf[ShuffledRowRDD] + val shuffleHandle = shuffledRDD.dependency.shuffleHandle + + val inputRDD = executeNative(child) + val nativeShuffle = getUnderlyingNativePlan(child).asInstanceOf[NativeShuffleExchangeExec] + val nativeSchema: pb.Schema = nativeShuffle.nativeSchema + + val requiredMetrics = nativeShuffle.readMetrics ++ + nativeShuffle.metrics.filterKeys(_ == "shuffle_read_total_time") + val metrics = MetricNode( + requiredMetrics, + inputRDD.metrics :: Nil, + Some({ + case ("output_rows", v) => + val tempMetrics = TaskContext.get.taskMetrics().createTempShuffleReadMetrics() + new SQLShuffleReadMetricsReporter(tempMetrics, requiredMetrics).incRecordsRead(v) + TaskContext.get().taskMetrics().mergeShuffleReadMetrics() + case ("elapsed_compute", v) => requiredMetrics("shuffle_read_total_time") += v + case _ => + })) + + new NativeRDD( + shuffledRDD.sparkContext, + metrics, + shuffledRDD.partitions, + shuffledRDD.partitioner, + new OneToOneDependency(shuffledRDD) :: Nil, + true, + (partition, taskContext) => { + + // use reflection to get partitionSpec because ShuffledRowRDDPartition is private + val sqlMetricsReporter = 
taskContext.taskMetrics().createTempShuffleReadMetrics() + val spec = FieldUtils + .readDeclaredField(partition, "spec", true) + .asInstanceOf[ShufflePartitionSpec] + val reader = spec match { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => + SparkEnv.get.shuffleManager.getReader( + shuffleHandle, + startReducerIndex, + endReducerIndex, + taskContext, + sqlMetricsReporter) + + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) => + SparkEnv.get.shuffleManager.getReader( + shuffleHandle, + startMapIndex, + endMapIndex, + 0, + numReducers, + taskContext, + sqlMetricsReporter) + + case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) => + SparkEnv.get.shuffleManager.getReader( + shuffleHandle, + startMapIndex, + endMapIndex, + reducerIndex, + reducerIndex + 1, + taskContext, + sqlMetricsReporter) + + case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) => + SparkEnv.get.shuffleManager.getReader( + shuffleHandle, + mapIndex, + mapIndex + 1, + startReducerIndex, + endReducerIndex, + taskContext, + sqlMetricsReporter) + } + + // store fetch iterator in jni resource before native compute + val jniResourceId = s"NativeShuffleReadExec:${UUID.randomUUID().toString}" + JniBridge.resourcesMap.put( + jniResourceId, + () => { + reader.asInstanceOf[AuronBlockStoreShuffleReaderBase[_, _]].readIpc() + }) + + pb.PhysicalPlanNode + .newBuilder() + .setIpcReader( + pb.IpcReaderExecNode + .newBuilder() + .setSchema(nativeSchema) + .setNumPartitions(shuffledRDD.getNumPartitions) + .setIpcProviderResourceId(jniResourceId) + .build()) + .build() + }) + } + } + + override def convertMoreSparkPlan(exec: SparkPlan): Option[SparkPlan] = { + exec match { + case exec if isAQEShuffleRead(exec) && isNative(exec) => + Some(ForceNativeExecutionWrapper(AuronConverters.addRenameColumnsExec(exec))) + case _: ReusedExchangeExec if isNative(exec) => + Some(ForceNativeExecutionWrapper(AuronConverters.addRenameColumnsExec(exec))) + case _ => None + } + } + + @sparkver("4.0") + override def getSqlContext( + sparkPlan: org.apache.spark.sql.execution.SparkPlan): org.apache.spark.sql.SQLContext = + sparkPlan.session.sqlContext + + override def createNativeExprWrapper( + nativeExpr: pb.PhysicalExprNode, + dataType: DataType, + nullable: Boolean): Expression = { + NativeExprWrapper(nativeExpr, dataType, nullable) + } + + @sparkver("4.0") + override def getPartitionedFile( + partitionValues: InternalRow, + filePath: String, + offset: Long, + size: Long): PartitionedFile = { + import org.apache.hadoop.fs.Path + import org.apache.spark.paths.SparkPath + PartitionedFile(partitionValues, SparkPath.fromPath(new Path(filePath)), offset, size) + } + + @sparkver("4.0") + override def getMinPartitionNum(sparkSession: SparkSession): Int = + sparkSession.sessionState.conf.filesMinPartitionNum + .getOrElse(sparkSession.sparkContext.defaultParallelism) + + @sparkver("4.0") + private def convertPromotePrecision( + e: Expression, + isPruningExpr: Boolean, + fallback: Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = None + + @sparkver("4.0") + private def convertBloomFilterAgg(agg: AggregateFunction): Option[pb.PhysicalAggExprNode] = { + import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate + agg match { + case BloomFilterAggregate(child, estimatedNumItemsExpression, numBitsExpression, _, _) => + // ensure numBits is a power of 2 + val estimatedNumItems = + 
estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue() + val numBits = numBitsExpression.eval().asInstanceOf[Number].longValue() + val numBitsNextPowerOf2 = numBits match { + case 1 => 1L + case n => Integer.highestOneBit(n.toInt - 1) << 1 + } + Some( + pb.PhysicalAggExprNode + .newBuilder() + .setReturnType(NativeConverters.convertDataType(agg.dataType)) + .setAggFunction(pb.AggFunction.BLOOM_FILTER) + .addChildren(NativeConverters.convertExpr(child)) + .addChildren(NativeConverters.convertExpr(Literal(estimatedNumItems))) + .addChildren(NativeConverters.convertExpr(Literal(numBitsNextPowerOf2))) + .build()) + case _ => None + } + } + + @sparkver("4.0") + private def convertBloomFilterMightContain( + e: Expression, + isPruningExpr: Boolean, + fallback: Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = { + import org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain + e match { + case e: BloomFilterMightContain => + val uuid = UUID.randomUUID().toString + Some(NativeConverters.buildExprNode { + _.setBloomFilterMightContainExpr( + pb.BloomFilterMightContainExprNode + .newBuilder() + .setUuid(uuid) + .setBloomFilterExpr(NativeConverters + .convertExprWithFallback(e.bloomFilterExpression, isPruningExpr, fallback)) + .setValueExpr(NativeConverters + .convertExprWithFallback(e.valueExpression, isPruningExpr, fallback))) + }) + case _ => None + } + } + + @sparkver("4.0") + override def getAdaptiveInputPlan(exec: AdaptiveSparkPlanExec): SparkPlan = { + exec.inputPlan + } +} + +case class ForceNativeExecutionWrapper(override val child: SparkPlan) + extends ForceNativeExecutionWrapperBase(child) { + + @sparkver("4.0") + override def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} + +case class NativeExprWrapper( + nativeExpr: pb.PhysicalExprNode, + override val dataType: DataType, + override val nullable: Boolean) + extends NativeExprWrapperBase(nativeExpr, dataType, nullable) { + + @sparkver("4.0") + override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy() +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeExec.scala new file mode 100644 index 000000000..6ffe0a836 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeExec.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class ConvertToNativeExec(override val child: SparkPlan) extends ConvertToNativeBase(child) { + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala new file mode 100644 index 000000000..bde4bbd9a --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.ExprId +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.Final +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.execution.auron.plan.NativeAggBase.AggExecMode +import org.apache.spark.sql.types.BinaryType + +import org.apache.auron.sparkver + +case class NativeAggExec( + execMode: AggExecMode, + theRequiredChildDistributionExpressions: Option[Seq[Expression]], + override val groupingExpressions: Seq[NamedExpression], + override val aggregateExpressions: Seq[AggregateExpression], + override val aggregateAttributes: Seq[Attribute], + theInitialInputBufferOffset: Int, + override val child: SparkPlan) + extends NativeAggBase( + execMode, + theRequiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + theInitialInputBufferOffset, + child) + with BaseAggregateExec { + + @sparkver("4.0") + override val requiredChildDistributionExpressions: Option[Seq[Expression]] = + theRequiredChildDistributionExpressions + + @sparkver("4.0") + override val initialInputBufferOffset: Int = theInitialInputBufferOffset + + override def output: Seq[Attribute] = + if (aggregateExpressions.map(_.mode).contains(Final)) { + groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + } else { + groupingExpressions.map(_.toAttribute) :+ + AttributeReference(NativeAggBase.AGG_BUF_COLUMN_NAME, BinaryType, nullable = false)( + 
ExprId.apply(NativeAggBase.AGG_BUF_COLUMN_EXPR_ID)) + } + + @sparkver("4.0") + override def isStreaming: Boolean = false + + @sparkver("4.0") + override def numShufflePartitions: Option[Int] = None + + override def resultExpressions: Seq[NamedExpression] = output + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeExec.scala new file mode 100644 index 000000000..d9fccfe43 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeExec.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import java.util.UUID + +import org.apache.spark.broadcast +import org.apache.spark.sql.auron.NativeSupports +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class NativeBroadcastExchangeExec(mode: BroadcastMode, override val child: SparkPlan) + extends NativeBroadcastExchangeBase(mode, child) + with NativeSupports { + + override val getRunId: UUID = UUID.randomUUID() + + override def runtimeStatistics: Statistics = { + val dataSize = metrics("dataSize").value + val rowCount = metrics("numOutputRows").value + Statistics(dataSize, Some(rowCount)) + } + + @transient + override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = { + relationFuturePromise.future + } + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeExpandExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeExpandExec.scala new file mode 100644 index 000000000..002249154 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeExpandExec.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class NativeExpandExec( + projections: Seq[Seq[Expression]], + override val output: Seq[Attribute], + override val child: SparkPlan) + extends NativeExpandBase(projections, output, child) { + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeFilterExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeFilterExec.scala new file mode 100644 index 000000000..0f72420fb --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeFilterExec.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class NativeFilterExec(condition: Expression, override val child: SparkPlan) + extends NativeFilterBase(condition, child) { + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGenerateExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGenerateExec.scala new file mode 100644 index 000000000..8c9c02c5e --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGenerateExec.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Generator +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class NativeGenerateExec( + generator: Generator, + requiredChildOutput: Seq[Attribute], + outer: Boolean, + generatorOutput: Seq[Attribute], + override val child: SparkPlan) + extends NativeGenerateBase(generator, requiredChildOutput, outer, generatorOutput, child) { + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGlobalLimitExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGlobalLimitExec.scala new file mode 100644 index 000000000..1dea592e2 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGlobalLimitExec.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class NativeGlobalLimitExec(limit: Long, override val child: SparkPlan) + extends NativeGlobalLimitBase(limit, child) { + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeLocalLimitExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeLocalLimitExec.scala new file mode 100644 index 000000000..db6f5a0f6 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeLocalLimitExec.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class NativeLocalLimitExec(limit: Long, override val child: SparkPlan) + extends NativeLocalLimitBase(limit, child) { + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcScanExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcScanExec.scala new file mode 100644 index 000000000..06a63f6e9 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcScanExec.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.execution.FileSourceScanExec + +case class NativeOrcScanExec(basedFileScan: FileSourceScanExec) + extends NativeOrcScanBase(basedFileScan) { + + override def simpleString(maxFields: Int): String = + s"$nodeName (${basedFileScan.simpleString(maxFields)})" +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetInsertIntoHiveTableExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetInsertIntoHiveTableExec.scala new file mode 100644 index 000000000..3c5b7be4c --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetInsertIntoHiveTableExec.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.Row +import org.apache.spark.sql.auron.Shims +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable + +import org.apache.auron.sparkver + +case class NativeParquetInsertIntoHiveTableExec( + cmd: InsertIntoHiveTable, + override val child: SparkPlan) + extends NativeParquetInsertIntoHiveTableBase(cmd, child) { + + @sparkver("4.0") + override protected def getInsertIntoHiveTableCommand( + table: CatalogTable, + partition: Map[String, Option[String]], + query: LogicalPlan, + overwrite: Boolean, + ifPartitionNotExists: Boolean, + outputColumnNames: Seq[String], + metrics: Map[String, SQLMetric]): InsertIntoHiveTable = { + new AuronInsertIntoHiveTable40( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames, + metrics) + } + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + @sparkver("4.0") + class AuronInsertIntoHiveTable40( + table: CatalogTable, + partition: Map[String, Option[String]], + query: LogicalPlan, + overwrite: Boolean, + ifPartitionNotExists: Boolean, + outputColumnNames: Seq[String], + outerMetrics: Map[String, SQLMetric]) + extends { + private val insertIntoHiveTable = InsertIntoHiveTable( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames) + private val initPartitionColumns = insertIntoHiveTable.partitionColumns + private val initBucketSpec = insertIntoHiveTable.bucketSpec + private val initOptions = insertIntoHiveTable.options + private val initFileFormat = insertIntoHiveTable.fileFormat + private val initHiveTmpPath = insertIntoHiveTable.hiveTmpPath + + } + with InsertIntoHiveTable( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames, + initPartitionColumns, + initBucketSpec, + initOptions, + initFileFormat, + initHiveTmpPath) { + + override lazy val metrics: Map[String, SQLMetric] = outerMetrics + + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { + val nativeParquetSink = + Shims.get.createNativeParquetSinkExec(sparkSession, table, partition, child, metrics) + super.run(sparkSession, nativeParquetSink) + } + } +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetScanExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetScanExec.scala new file mode 100644 index 000000000..a116a5430 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetScanExec.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.execution.FileSourceScanExec + +case class NativeParquetScanExec(basedFileScan: FileSourceScanExec) + extends NativeParquetScanBase(basedFileScan) { + + override def simpleString(maxFields: Int): String = + s"$nodeName (${basedFileScan.simpleString(maxFields)})" +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetSinkExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetSinkExec.scala new file mode 100644 index 000000000..cc44a9dbd --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeParquetSinkExec.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric + +import org.apache.auron.sparkver + +case class NativeParquetSinkExec( + sparkSession: SparkSession, + table: CatalogTable, + partition: Map[String, Option[String]], + override val child: SparkPlan, + override val metrics: Map[String, SQLMetric]) + extends NativeParquetSinkBase(sparkSession, table, partition, child, metrics) { + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativePartialTakeOrderedExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativePartialTakeOrderedExec.scala new file mode 100644 index 000000000..60b151206 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativePartialTakeOrderedExec.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric + +import org.apache.auron.sparkver + +case class NativePartialTakeOrderedExec( + limit: Long, + sortOrder: Seq[SortOrder], + override val child: SparkPlan, + override val metrics: Map[String, SQLMetric]) + extends NativePartialTakeOrderedBase(limit, sortOrder, child, metrics) { + + @sparkver("4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeProjectExecProvider.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeProjectExecProvider.scala new file mode 100644 index 000000000..3274fd9eb --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeProjectExecProvider.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case object NativeProjectExecProvider { + @sparkver("4.0") + def provide(projectList: Seq[NamedExpression], child: SparkPlan): NativeProjectBase = { + import org.apache.spark.sql.execution.OrderPreservingUnaryExecNode + import org.apache.spark.sql.execution.PartitioningPreservingUnaryExecNode + + case class NativeProjectExec(projectList: Seq[NamedExpression], override val child: SparkPlan) + extends NativeProjectBase(projectList, child) + with PartitioningPreservingUnaryExecNode + with OrderPreservingUnaryExecNode { + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + override protected def outputExpressions = projectList + + override protected def orderingExpressions = child.outputOrdering + + override def nodeName: String = "NativeProject" + } + NativeProjectExec(projectList, child) + } +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeRenameColumnsExecProvider.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeRenameColumnsExecProvider.scala new file mode 100644 index 000000000..eabb98727 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeRenameColumnsExecProvider.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case object NativeRenameColumnsExecProvider { + @sparkver("4.0") + def provide(child: SparkPlan, renamedColumnNames: Seq[String]): NativeRenameColumnsBase = { + import org.apache.spark.sql.catalyst.expressions.NamedExpression + import org.apache.spark.sql.catalyst.expressions.SortOrder + import org.apache.spark.sql.execution.OrderPreservingUnaryExecNode + import org.apache.spark.sql.execution.PartitioningPreservingUnaryExecNode + + case class NativeRenameColumnsExec( + override val child: SparkPlan, + renamedColumnNames: Seq[String]) + extends NativeRenameColumnsBase(child, renamedColumnNames) + with PartitioningPreservingUnaryExecNode + with OrderPreservingUnaryExecNode { + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + override protected def outputExpressions: Seq[NamedExpression] = output + + override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering + } + NativeRenameColumnsExec(child, renamedColumnNames) + } +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala new file mode 100644 index 000000000..e6061adc0 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.auron.plan + +import scala.collection.mutable +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future + +import org.apache.spark._ +import org.apache.spark.rdd.MapPartitionsRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter +import org.apache.spark.shuffle.ShuffleWriteProcessor +import org.apache.spark.sql.auron.NativeHelper +import org.apache.spark.sql.auron.NativeRDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.auron.shuffle.AuronRssShuffleWriterBase +import org.apache.spark.sql.execution.auron.shuffle.AuronShuffleWriterBase +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter +import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter + +import org.apache.auron.sparkver + +case class NativeShuffleExchangeExec( + override val outputPartitioning: Partitioning, + override val child: SparkPlan, + _shuffleOrigin: Option[Any] = None) + extends NativeShuffleExchangeBase(outputPartitioning, child) { + + // NOTE: coordinator can be null after serialization/deserialization, + // e.g. it can be null on the Executor side + lazy val writeMetrics: Map[String, SQLMetric] = (mutable.LinkedHashMap[String, SQLMetric]() ++ + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) ++ + mutable.LinkedHashMap( + NativeHelper + .getDefaultNativeMetrics(sparkContext) + .filterKeys(Set( + "stage_id", + "mem_spill_count", + "mem_spill_size", + "mem_spill_iotime", + "disk_spill_size", + "disk_spill_iotime", + "shuffle_write_total_time", + "shuffle_read_total_time")) + .toSeq: _*)).toMap + + lazy val readMetrics: Map[String, SQLMetric] = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + + override lazy val metrics: Map[String, SQLMetric] = + (mutable.LinkedHashMap[String, SQLMetric]() ++ + readMetrics ++ + writeMetrics ++ + Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions"))).toMap + + // 'mapOutputStatisticsFuture' is only needed when enable AQE. + @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { + if (inputRDD.getNumPartitions == 0) { + Future.successful(null) + } else { + sparkContext + .submitMapStage(shuffleDependency) + .map(stat => new MapOutputStatistics(stat.shuffleId, stat.bytesByPartitionId)) + } + } + + override def numMappers: Int = shuffleDependency.rdd.getNumPartitions + + override def numPartitions: Int = shuffleDependency.partitioner.numPartitions + + override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = { + new ShuffledRowRDD(shuffleDependency, readMetrics, partitionSpecs) + } + + override def runtimeStatistics: Statistics = { + val dataSize = metrics("dataSize").value + val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value + Statistics(dataSize, Some(rowCount)) + } + + /** + * Caches the created ShuffleRowRDD so we can reuse that. 
+ */ + private var cachedShuffleRDD: ShuffledRowRDD = _ + + override protected def doExecuteNonNative(): RDD[InternalRow] = { + // Returns the same ShuffleRowRDD if this plan is used by multiple plans. + if (cachedShuffleRDD == null) { + cachedShuffleRDD = new ShuffledRowRDD(shuffleDependency, readMetrics) + } + cachedShuffleRDD + } + + override def createNativeShuffleWriteProcessor( + metrics: Map[String, SQLMetric], + numPartitions: Int): ShuffleWriteProcessor = { + + new ShuffleWriteProcessor { + override protected def createMetricsReporter( + context: TaskContext): ShuffleWriteMetricsReporter = { + new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics) + } + + override def write( + inputs: Iterator[_], + dep: ShuffleDependency[_, _, _], + mapId: Long, + mapIndex: Int, + context: TaskContext): MapStatus = super.write(inputs, dep, mapId, mapIndex, context) + + def write( + rdd: RDD[_], + dep: ShuffleDependency[_, _, _], + mapId: Long, + context: TaskContext, + partition: Partition): MapStatus = { + + val writer = SparkEnv.get.shuffleManager.getWriter( + dep.shuffleHandle, + mapId, + context, + createMetricsReporter(context)) + + writer match { + case writer: AuronRssShuffleWriterBase[_, _] => + writer.nativeRssShuffleWrite( + rdd.asInstanceOf[MapPartitionsRDD[_, _]].prev.asInstanceOf[NativeRDD], + dep, + mapId.toInt, + context, + partition, + numPartitions) + + case writer: AuronShuffleWriterBase[_, _] => + writer.nativeShuffleWrite( + rdd.asInstanceOf[MapPartitionsRDD[_, _]].prev.asInstanceOf[NativeRDD], + dep, + mapId.toInt, + context, + partition) + } + writer.stop(true).get + } + } + } + + // for databricks testing + val causedBroadcastJoinBuildOOM = false + + @sparkver("3.5 / 4.0") + override def advisoryPartitionSize: Option[Long] = None + + // If users specify the num partitions via APIs like `repartition`, we shouldn't change it. + // For `SinglePartition`, it requires exactly one partition and we can't change it either. + @sparkver("3.0") + override def canChangeNumPartitions: Boolean = + outputPartitioning != SinglePartition + + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5 / 4.0") + override def shuffleOrigin = { + import org.apache.spark.sql.execution.exchange.ShuffleOrigin; + _shuffleOrigin.get.asInstanceOf[ShuffleOrigin] + } + + @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + @sparkver("3.0 / 3.1") + override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = + copy(child = newChildren.head) + + override def shuffleId: Int = ??? +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala new file mode 100644 index 000000000..26809d116 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class NativeSortExec( + sortOrder: Seq[SortOrder], + global: Boolean, + override val child: SparkPlan) + extends NativeSortBase(sortOrder, global, child) { + + @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + @sparkver("3.0 / 3.1") + override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = + copy(child = newChildren.head) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala new file mode 100644 index 000000000..3b820e88b --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class NativeTakeOrderedExec( + limit: Long, + sortOrder: Seq[SortOrder], + override val child: SparkPlan) + extends NativeTakeOrderedBase(limit, sortOrder, child) { + + @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + @sparkver("3.0 / 3.1") + override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = + copy(child = newChildren.head) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala new file mode 100644 index 000000000..130697b78 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class NativeUnionExec( + override val children: Seq[SparkPlan], + override val output: Seq[Attribute]) + extends NativeUnionBase(children, output) { + + @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") + override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = + copy(children = newChildren) + + @sparkver("3.0 / 3.1") + override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = + copy(children = newChildren) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala new file mode 100644 index 000000000..d14647c0b --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.auron.sparkver + +case class NativeWindowExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + groupLimit: Option[Int], + override val child: SparkPlan) + extends NativeWindowBase(windowExpression, partitionSpec, orderSpec, groupLimit, child) { + + @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + @sparkver("3.0 / 3.1") + override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = + copy(child = newChildren.head) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronBlockStoreShuffleReader.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronBlockStoreShuffleReader.scala new file mode 100644 index 000000000..270d17b39 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronBlockStoreShuffleReader.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.auron.shuffle + +import java.io.InputStream + +import org.apache.spark.{MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator} + +import org.apache.auron.sparkver + +class AuronBlockStoreShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + context: TaskContext, + readMetrics: ShuffleReadMetricsReporter, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + shouldBatchFetch: Boolean = false) + extends AuronBlockStoreShuffleReaderBase[K, C](handle, context) + with Logging { + + override def readBlocks(): Iterator[InputStream] = { + @sparkver("4.0") + def fetchIterator = new ShuffleBlockFetcherIterator( + context, + blockManager.blockStoreClient, + blockManager, + mapOutputTracker, + blocksByAddress, + (_, inputStream) => inputStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + false, // checksums not supported + "ChecksumAlgorithmsNotSupported", + readMetrics, + fetchContinuousBlocksInBatch).toCompletionIterator.map(_._2) + + fetchIterator + } + + private def fetchContinuousBlocksInBatch: Boolean = { + val conf = SparkEnv.get.conf + val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects + val compressed = conf.get(config.SHUFFLE_COMPRESS) + val codecConcatenation = if (compressed) { + CompressionCodec.supportsConcatenationOfSerializedStreams( + CompressionCodec.createCodec(conf)) + } else { + true + } + val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL) + + val doBatchFetch = shouldBatchFetch && serializerRelocatable && + (!compressed || codecConcatenation) && !useOldFetchProtocol + if (shouldBatchFetch && !doBatchFetch) { + logDebug( + "The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. " + + s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + + s"$useOldFetchProtocol.") + } + doBatchFetch + } +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleManagerBase.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleManagerBase.scala new file mode 100644 index 000000000..ab9aabc3a --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleManagerBase.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.shuffle + +import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle._ +import org.apache.spark.sql.execution.auron.shuffle.AuronShuffleDependency.isArrowShuffle + +import org.apache.auron.sparkver + +abstract class AuronRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManager with Logging { + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle + + override def unregisterShuffle(shuffleId: Int): Boolean + + def getAuronRssShuffleReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): AuronRssShuffleReaderBase[K, C] + + def getAuronRssShuffleReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): AuronRssShuffleReaderBase[K, C] + + def getRssShuffleReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] + + def getRssShuffleReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] + + def getAuronRssShuffleWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): AuronRssShuffleWriterBase[K, V] + + def getRssShuffleWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] + + @sparkver("4.0") + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + + if (isArrowShuffle(handle)) { + getAuronRssShuffleReader( + handle, + startMapIndex, + endMapIndex, + startPartition, + endPartition, + context, + metrics) + } else { + getRssShuffleReader( + handle, + startMapIndex, + endMapIndex, + startPartition, + endPartition, + context, + metrics) + } + } + + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + + if (isArrowShuffle(handle)) { + getAuronRssShuffleWriter(handle, mapId, context, metrics) + } else { + getRssShuffleWriter(handle, mapId, context, metrics) + } + } +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleManager.scala 
b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleManager.scala new file mode 100644 index 000000000..852136213 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleManager.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.shuffle + +import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch +import org.apache.spark.sql.execution.auron.shuffle.AuronShuffleDependency.isArrowShuffle + +import org.apache.auron.sparkver + +class AuronShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + val sortShuffleManager = new SortShuffleManager(conf) + + // disable other off-heap memory usages + System.setProperty("spark.memory.offHeap.enabled", "false") + System.setProperty("io.netty.maxDirectMemory", "0") + System.setProperty("io.netty.noPreferDirect", "true") + System.setProperty("io.netty.noUnsafe", "true") + + if (!conf.getBoolean("spark.shuffle.spill", defaultValue = true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } + + override val shuffleBlockResolver: ShuffleBlockResolver = + sortShuffleManager.shuffleBlockResolver + + /** + * (override) Obtains a [[ShuffleHandle]] to pass to tasks. 
+ */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + sortShuffleManager.registerShuffle(shuffleId, dependency) + } + + @sparkver("4.0") + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + + if (isArrowShuffle(handle)) { + val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]] + + @sparkver("4.0") + def shuffleMergeFinalized = baseShuffleHandle.dependency.isShuffleMergeFinalizedMarked + + val (blocksByAddress, canEnableBatchFetch) = + if (shuffleMergeFinalized) { + val res = SparkEnv.get.mapOutputTracker.getPushBasedShuffleMapSizesByExecutorId( + handle.shuffleId, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + (res.iter, res.enableBatchFetch) + } else { + val address = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + (address, true) + } + + new AuronBlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress.map(tup => (tup._1, tup._2.toSeq)), + context, + metrics, + SparkEnv.get.blockManager, + SparkEnv.get.mapOutputTracker, + shouldBatchFetch = + canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)) + } else { + sortShuffleManager.getReader( + handle, + startMapIndex, + endMapIndex, + startPartition, + endPartition, + context, + metrics) + } + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + + if (isArrowShuffle(handle)) { + new AuronShuffleWriter(metrics) + } else { + sortShuffleManager.getWriter(handle, mapId, context, metrics) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + sortShuffleManager.unregisterShuffle(shuffleId) + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + shuffleBlockResolver.stop() + } +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleWriter.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleWriter.scala new file mode 100644 index 000000000..2394ba132 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronShuffleWriter.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.auron.shuffle + +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter + +import org.apache.auron.sparkver + +class AuronShuffleWriter[K, V](metrics: ShuffleWriteMetricsReporter) + extends AuronShuffleWriterBase[K, V](metrics) { + + @sparkver("4.0") + override def getPartitionLengths(): Array[Long] = partitionLengths +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala new file mode 100644 index 000000000..85d2564b6 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.joins.auron.plan + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.auron.plan.BroadcastLeft +import org.apache.spark.sql.execution.auron.plan.BroadcastRight +import org.apache.spark.sql.execution.auron.plan.BroadcastSide +import org.apache.spark.sql.execution.auron.plan.NativeBroadcastJoinBase +import org.apache.spark.sql.execution.joins.HashJoin + +import org.apache.auron.sparkver + +case class NativeBroadcastJoinExec( + override val left: SparkPlan, + override val right: SparkPlan, + override val outputPartitioning: Partitioning, + override val leftKeys: Seq[Expression], + override val rightKeys: Seq[Expression], + override val joinType: JoinType, + broadcastSide: BroadcastSide) + extends NativeBroadcastJoinBase( + left, + right, + outputPartitioning, + leftKeys, + rightKeys, + joinType, + broadcastSide) + with HashJoin { + + override val condition: Option[Expression] = None + + @sparkver("4.0") + override def buildSide: org.apache.spark.sql.catalyst.optimizer.BuildSide = + broadcastSide match { + case BroadcastLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft + case BroadcastRight => org.apache.spark.sql.catalyst.optimizer.BuildRight + } + + @sparkver("4.0") + override def requiredChildDistribution = { + import org.apache.spark.sql.catalyst.plans.physical.BroadcastDistribution + import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution + import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode + + def mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAware = false) + broadcastSide match { + case BroadcastLeft => + BroadcastDistribution(mode) :: 
UnspecifiedDistribution :: Nil + case BroadcastRight => + UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil + } + } + + override def rewriteKeyExprToLong(exprs: Seq[Expression]): Seq[Expression] = + HashJoin.rewriteKeyExpr(exprs) + + @sparkver("4.0") + override def supportCodegen: Boolean = false + + @sparkver("4.0") + override def inputRDDs() = { + throw new NotImplementedError("NativeBroadcastJoin does not support codegen") + } + + @sparkver("4.0") + override protected def prepareRelation( + ctx: org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext) + : org.apache.spark.sql.execution.joins.HashedRelationInfo = { + throw new NotImplementedError("NativeBroadcastJoin does not support codegen") + } + + @sparkver("4.0") + override protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): SparkPlan = + copy(left = newLeft, right = newRight) +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala new file mode 100644 index 000000000..d24f2d6d9 --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.sql.execution.joins.auron.plan + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.auron.plan.BuildSide +import org.apache.spark.sql.execution.auron.plan.NativeShuffledHashJoinBase +import org.apache.spark.sql.execution.joins.HashJoin + +import org.apache.auron.sparkver + +case object NativeShuffledHashJoinExecProvider { + + @sparkver("4.0") + def provide( + left: SparkPlan, + right: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + isSkewJoin: Boolean): NativeShuffledHashJoinBase = { + + import org.apache.spark.rdd.RDD + import org.apache.spark.sql.catalyst.InternalRow + import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext + + case class NativeShuffledHashJoinExec( + override val left: SparkPlan, + override val right: SparkPlan, + override val leftKeys: Seq[Expression], + override val rightKeys: Seq[Expression], + override val joinType: JoinType, + buildSide: BuildSide, + skewJoin: Boolean) + extends NativeShuffledHashJoinBase(left, right, leftKeys, rightKeys, joinType, buildSide) + with org.apache.spark.sql.execution.joins.ShuffledJoin { + + override def condition: Option[Expression] = None + + override def isSkewJoin: Boolean = false + + override def supportCodegen: Boolean = false + + override def rewriteKeyExprToLong(exprs: Seq[Expression]): Seq[Expression] = + HashJoin.rewriteKeyExpr(exprs) + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + throw new NotImplementedError("NativeShuffledHashJoin does not support codegen") + } + + override protected def doProduce(ctx: CodegenContext): String = { + throw new NotImplementedError("NativeShuffledHashJoin does not support codegen") + } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): SparkPlan = + copy(left = newLeft, right = newRight) + + override def nodeName: String = + "NativeShuffledHashJoin" + (if (skewJoin) "(skew=true)" else "") + } + NativeShuffledHashJoinExec(left, right, leftKeys, rightKeys, joinType, buildSide, isSkewJoin) + } +} diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeSortMergeJoinExecProvider.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeSortMergeJoinExecProvider.scala new file mode 100644 index 000000000..e59ca8f6e --- /dev/null +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeSortMergeJoinExecProvider.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.sql.execution.joins.auron.plan + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.auron.plan.NativeSortMergeJoinBase + +import org.apache.auron.sparkver + +case object NativeSortMergeJoinExecProvider { + + @sparkver("4.0") + def provide( + left: SparkPlan, + right: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + skewJoin: Boolean): NativeSortMergeJoinBase = { + + import org.apache.spark.rdd.RDD + import org.apache.spark.sql.catalyst.InternalRow + import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext + + case class NativeSortMergeJoinExec( + override val left: SparkPlan, + override val right: SparkPlan, + override val leftKeys: Seq[Expression], + override val rightKeys: Seq[Expression], + override val joinType: JoinType, + skewJoin: Boolean) + extends NativeSortMergeJoinBase(left, right, leftKeys, rightKeys, joinType) + with org.apache.spark.sql.execution.joins.ShuffledJoin { + + override def condition: Option[Expression] = None + + override def isSkewJoin: Boolean = false + + override def supportCodegen: Boolean = false + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + throw new NotImplementedError("NativeSortMergeJoin does not support codegen") + } + + override protected def doProduce(ctx: CodegenContext): String = { + throw new NotImplementedError("NativeSortMergeJoin does not support codegen") + } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): SparkPlan = + copy(left = newLeft, right = newRight) + + override def nodeName: String = + "NativeSortMergeJoin" + (if (skewJoin) "(skew=true)" else "") + } + NativeSortMergeJoinExec(left, right, leftKeys, rightKeys, joinType, skewJoin) + } +} diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala new file mode 100644 index 000000000..c27f47a50 --- /dev/null +++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.sql.auron + +import org.apache.auron.sparkverEnableMembers + +@sparkverEnableMembers("3.5") +class AuronAdaptiveQueryExecSuite + extends org.apache.spark.sql.QueryTest + with BaseAuronSQLSuite + with org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper { + + import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} + import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, SparkPlan} + import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec} + import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec + import org.apache.spark.sql.execution.exchange.Exchange + import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates} + import org.apache.spark.sql.internal.SQLConf + import org.apache.spark.sql.test.SQLTestData.TestData + import testImplicits._ + + test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") { + withTempView("v") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") { + + spark.sparkContext + .parallelize((1 to 10).map(i => TestData(if (i > 4) 5 else i, i.toString)), 3) + .toDF("c1", "c2") + .createOrReplaceTempView("v") + + def checkPartitionNumber( + query: String, + skewedPartitionNumber: Int, + totalNumber: Int): Unit = { + val (_, adaptive) = runAdaptiveAndVerifyResult(query) + val read = collect(adaptive) { case read: AQEShuffleReadExec => + read + } + assert(read.size == 1) + assert( + read.head.partitionSpecs.count(_.isInstanceOf[PartialReducerPartitionSpec]) == + skewedPartitionNumber) + assert(read.head.partitionSpecs.size == totalNumber) + } + + withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { + // partition size [0, 75, 45, 68, 34] + checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 2, 4) + checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 0, 3) + } + + // no skewed partition should be optimized + withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10000") { + checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 0, 1) + } + } + } + } + + private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = { + var finalPlanCnt = 0 + var hasMetricsEvent = false + val listener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case SparkListenerSQLAdaptiveExecutionUpdate(_, _, sparkPlanInfo) => + if (sparkPlanInfo.simpleString.startsWith("AdaptiveSparkPlan isFinalPlan=true")) { + finalPlanCnt += 1 + } + case _: SparkListenerSQLAdaptiveSQLMetricUpdates => + hasMetricsEvent = true + case _ => // ignore other events + } + } + } + spark.sparkContext.addSparkListener(listener) + + val dfAdaptive = sql(query) + val planBefore = dfAdaptive.queryExecution.executedPlan + assert(planBefore.toString.startsWith("AdaptiveSparkPlan isFinalPlan=false")) + val result = dfAdaptive.collect() + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val df = sql(query) + checkAnswer(df, result) + } + val planAfter = dfAdaptive.queryExecution.executedPlan + assert(planAfter.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true")) + val adaptivePlan = 
planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan + + spark.sparkContext.listenerBus.waitUntilEmpty() + // AQE will post `SparkListenerSQLAdaptiveExecutionUpdate` twice in case of subqueries that + // exist out of query stages. + val expectedFinalPlanCnt = adaptivePlan.find(_.subqueries.nonEmpty).map(_ => 2).getOrElse(1) + assert(finalPlanCnt == expectedFinalPlanCnt) + spark.sparkContext.removeSparkListener(listener) + + val expectedMetrics = findInMemoryTable(planAfter).nonEmpty || + subqueriesAll(planAfter).nonEmpty + assert(hasMetricsEvent == expectedMetrics) + + val exchanges = adaptivePlan.collect { case e: Exchange => + e + } + assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.") + (dfAdaptive.queryExecution.sparkPlan, adaptivePlan) + } + + private def findInMemoryTable(plan: SparkPlan): Seq[InMemoryTableScanExec] = { + collect(plan) { + case c: InMemoryTableScanExec + if c.relation.cachedPlan.isInstanceOf[AdaptiveSparkPlanExec] => + c + } + } +} diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala new file mode 100644 index 000000000..2f7e57074 --- /dev/null +++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.auron + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + +import org.apache.auron.util.AuronTestUtils + +class AuronFunctionSuite + extends org.apache.spark.sql.QueryTest + with BaseAuronSQLSuite + with AdaptiveSparkPlanHelper { + + test("sum function with float input") { + if (AuronTestUtils.isSparkV31OrGreater) { + withTable("t1") { + sql("create table t1 using parquet as select 1.0f as c1") + val df = sql("select sum(c1) from t1") + checkAnswer(df, Seq(Row(1.0))) + } + } + } + + test("sha2 function") { + withTable("t1") { + sql("create table t1 using parquet as select 'spark' as c1, '3.x' as version") + val functions = + """ + |select + | sha2(concat(c1, version), 256) as sha0, + | sha2(concat(c1, version), 256) as sha256, + | sha2(concat(c1, version), 224) as sha224, + | sha2(concat(c1, version), 384) as sha384, + | sha2(concat(c1, version), 512) as sha512 + |from t1 + |""".stripMargin + val df = sql(functions) + checkAnswer( + df, + Seq( + Row( + "562d20689257f3f3a04ee9afb86d0ece2af106cf6c6e5e7d266043088ce5fbc0", + "562d20689257f3f3a04ee9afb86d0ece2af106cf6c6e5e7d266043088ce5fbc0", + "d0c8e9ccd5c7b3fdbacd2cfd6b4d65ca8489983b5e8c7c64cd77b634", + "77c1199808053619c29e9af2656e1ad2614772f6ea605d5757894d6aec2dfaf34ff6fd662def3b79e429e9ae5ecbfed1", + "c4e27d35517ca62243c1f322d7922dac175830be4668e8a1cf3befdcd287bb5b6f8c5f041c9d89e4609c8cfa242008c7c7133af1685f57bac9052c1212f1d089"))) + } + } + + test("spark hash function") { + withTable("t1") { + sql("create table t1 using parquet as select array(1, 2) as arr") + val functions = + """ + |select hash(arr) from t1 + |""".stripMargin + val df = sql(functions) + checkAnswer(df, Seq(Row(-222940379))) + } + } + + test("expm1 function") { + withTable("t1") { + sql("create table t1(c1 double) using parquet") + sql("insert into t1 values(0.0), (1.1), (2.2)") + val df = sql("select expm1(c1) from t1") + checkAnswer(df, Seq(Row(0.0), Row(2.0041660239464334), Row(8.025013499434122))) + } + } + + test("factorial function") { + withTable("t1") { + sql("create table t1(c1 int) using parquet") + sql("insert into t1 values(5)") + val df = sql("select factorial(c1) from t1") + checkAnswer(df, Seq(Row(120))) + } + } + + test("hex function") { + withTable("t1") { + sql("create table t1(c1 int, c2 string) using parquet") + sql("insert into t1 values(17, 'Spark SQL')") + val df = sql("select hex(c1), hex(c2) from t1") + checkAnswer(df, Seq(Row("11", "537061726B2053514C"))) + } + } + + test("stddev_samp function with UDAF fallback") { + withSQLConf("spark.auron.udafFallback.enable" -> "true") { + withTable("t1") { + sql("create table t1(c1 double) using parquet") + sql("insert into t1 values(10.0), (20.0), (30.0), (31.0), (null)") + val df = sql("select stddev_samp(c1) from t1") + checkAnswer(df, Seq(Row(9.844626283748239))) + } + } + } + + test("regexp_extract function with UDF fallback") { + withTable("t1") { + sql("create table t1(c1 string) using parquet") + sql("insert into t1 values('Auron Spark SQL')") + val df = sql("select regexp_extract(c1, '^A(.*)L$', 1) from t1") + checkAnswer(df, Seq(Row("uron Spark SQ"))) + } + } +} diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronQuerySuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronQuerySuite.scala new file mode 100644 index 000000000..5b9787d36 --- /dev/null +++
b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronQuerySuite.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.auron + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.Row + +import org.apache.auron.util.AuronTestUtils + +class AuronQuerySuite + extends org.apache.spark.sql.QueryTest + with BaseAuronSQLSuite + with AuronSQLTestHelper { + import testImplicits._ + + test("test partition path has url encoded character") { + withTable("t1") { + sql( + "create table t1 using parquet PARTITIONED BY (part) as select 1 as c1, 2 as c2, 'test test' as part") + val df = sql("select * from t1") + checkAnswer(df, Seq(Row(1, 2, "test test"))) + } + } + + test("empty output in bnlj") { + withTable("t1", "t2") { + sql("create table t1 using parquet as select 1 as c1, 2 as c2") + sql("create table t2 using parquet as select 1 as c1, 3 as c3") + val df = sql("select 1 from t1 left join t2") + checkAnswer(df, Seq(Row(1))) + } + } + + test("test filter with year function") { + withTable("t1") { + sql("create table t1 using parquet as select '2024-12-18' as event_time") + checkAnswer( + sql(""" + |select year, count(*) + |from (select event_time, year(event_time) as year from t1) t + |where year <= 2024 + |group by year + |""".stripMargin), + Seq(Row(2024, 1))) + } + } + + test("test select multiple spark ext functions with the same signature") { + withTable("t1") { + sql("create table t1 using parquet as select '2024-12-18' as event_time") + checkAnswer(sql("select year(event_time), month(event_time) from t1"), Seq(Row(2024, 12))) + } + } + + test("test parquet/orc format table with complex data type") { + def createTableStatement(format: String): String = { + s"""create table test_with_complex_type( + |id bigint comment 'pk', + |m map<string, string> comment 'test read map type', + |l array<string> comment 'test read list type', + |s string comment 'string type' + |) USING $format + |""".stripMargin + } + Seq("parquet", "orc").foreach(format => + withTable("test_with_complex_type") { + sql(createTableStatement(format)) + sql( + "insert into test_with_complex_type select 1 as id, map('zero', '0', 'one', '1') as m, array('test','auron') as l, 'auron' as s") + checkAnswer( + sql("select id,l,m from test_with_complex_type"), + Seq(Row(1, ArrayBuffer("test", "auron"), Map("one" -> "1", "zero" -> "0")))) + }) + } + + test("binary type in range partitioning") { + withTable("t1", "t2") { + sql("create table t1(c1 binary, c2 int) using parquet") + sql("insert into t1 values (cast('test1' as binary), 1), (cast('test2' as binary), 2)") + val df = sql("select c2 from t1 order by c1") + checkAnswer(df, Seq(Row(1), Row(2))) + } + } + + test("repartition over MapType") { + withTable("t_map")
{ + sql("create table t_map using parquet as select map('a', '1', 'b', '2') as data_map") + val df = sql("SELECT /*+ repartition(10) */ data_map FROM t_map") + checkAnswer(df, Seq(Row(Map("a" -> "1", "b" -> "2")))) + } + } + + test("repartition over MapType with ArrayType") { + withTable("t_map_struct") { + sql( + "create table t_map_struct using parquet as select named_struct('m', map('x', '1')) as data_struct") + val df = sql("SELECT /*+ repartition(10) */ data_struct FROM t_map_struct") + checkAnswer(df, Seq(Row(Row(Map("x" -> "1"))))) + } + } + + test("repartition over ArrayType with MapType") { + withTable("t_array_map") { + sql(""" + |create table t_array_map using parquet as + |select array(map('k1', 1, 'k2', 2), map('k3', 3)) as array_of_map + |""".stripMargin) + val df = sql("SELECT /*+ repartition(10) */ array_of_map FROM t_array_map") + checkAnswer(df, Seq(Row(Seq(Map("k1" -> 1, "k2" -> 2), Map("k3" -> 3))))) + } + } + + test("repartition over StructType with MapType") { + withTable("t_struct_map") { + sql(""" + |create table t_struct_map using parquet as + |select named_struct('id', 101, 'metrics', map('ctr', 0.123d, 'cvr', 0.045d)) as user_metrics + |""".stripMargin) + val df = sql("SELECT /*+ repartition(10) */ user_metrics FROM t_struct_map") + checkAnswer(df, Seq(Row(Row(101, Map("ctr" -> 0.123, "cvr" -> 0.045))))) + } + } + + test("repartition over MapType with StructType") { + withTable("t_map_struct_value") { + sql(""" + |create table t_map_struct_value using parquet as + |select map( + | 'item1', named_struct('count', 3, 'score', 4.5d), + | 'item2', named_struct('count', 7, 'score', 9.1d) + |) as map_struct_value + |""".stripMargin) + val df = sql("SELECT /*+ repartition(10) */ map_struct_value FROM t_map_struct_value") + checkAnswer(df, Seq(Row(Map("item1" -> Row(3, 4.5), "item2" -> Row(7, 9.1))))) + } + } + + test("repartition over nested MapType") { + withTable("t_nested_map") { + sql(""" + |create table t_nested_map using parquet as + |select map( + | 'outer1', map('inner1', 10, 'inner2', 20), + | 'outer2', map('inner3', 30) + |) as nested_map + |""".stripMargin) + val df = sql("SELECT /*+ repartition(10) */ nested_map FROM t_nested_map") + checkAnswer( + df, + Seq(Row( + Map("outer1" -> Map("inner1" -> 10, "inner2" -> 20), "outer2" -> Map("inner3" -> 30))))) + } + } + + test("repartition over ArrayType of StructType with MapType") { + withTable("t_array_struct_map") { + sql(""" + |create table t_array_struct_map using parquet as + |select array( + | named_struct('name', 'user1', 'features', map('f1', 1.0d, 'f2', 2.0d)), + | named_struct('name', 'user2', 'features', map('f3', 3.5d)) + |) as user_feature_array + |""".stripMargin) + val df = sql("SELECT /*+ repartition(10) */ user_feature_array FROM t_array_struct_map") + checkAnswer( + df, + Seq( + Row( + Seq(Row("user1", Map("f1" -> 1.0f, "f2" -> 2.0f)), Row("user2", Map("f3" -> 3.5f)))))) + } + } + + test("log function with negative input") { + withTable("t1") { + sql("create table t1 using parquet as select -1 as c1") + val df = sql("select ln(c1) from t1") + checkAnswer(df, Seq(Row(null))) + } + } + + test("floor function with long input") { + withTable("t1") { + sql("create table t1 using parquet as select 1L as c1, 2.2 as c2") + val df = sql("select floor(c1), floor(c2) from t1") + checkAnswer(df, Seq(Row(1, 2))) + } + } + + test("SPARK-32234 read ORC table with column names all starting with '_col'") { + withTable("test_hive_orc_impl") { + spark.sql(s""" + | CREATE TABLE test_hive_orc_impl + | (_col1 INT, 
_col2 STRING, _col3 INT) + | USING ORC + """.stripMargin) + spark.sql(s""" + | INSERT INTO + | test_hive_orc_impl + | VALUES(9, '12', 2020) + """.stripMargin) + + val df = spark.sql("SELECT _col2 FROM test_hive_orc_impl") + checkAnswer(df, Row("12")) + } + } + + test("SPARK-32864: Support ORC forced positional evolution") { + if (AuronTestUtils.isSparkV32OrGreater) { + Seq(true, false).foreach { forcePositionalEvolution => + withEnvConf( + AuronConf.ORC_FORCE_POSITIONAL_EVOLUTION.key -> forcePositionalEvolution.toString) { + withTempPath { f => + val path = f.getCanonicalPath + Seq[(Integer, Integer)]((1, 2), (3, 4), (5, 6), (null, null)) + .toDF("c1", "c2") + .write + .orc(path) + val correctAnswer = Seq(Row(1, 2), Row(3, 4), Row(5, 6), Row(null, null)) + checkAnswer(spark.read.orc(path), correctAnswer) + + withTable("t") { + sql(s"CREATE EXTERNAL TABLE t(c3 INT, c2 INT) USING ORC LOCATION '$path'") + + val expected = if (forcePositionalEvolution) { + correctAnswer + } else { + Seq(Row(null, 2), Row(null, 4), Row(null, 6), Row(null, null)) + } + + checkAnswer(spark.table("t"), expected) + } + } + } + } + } + } + + test("SPARK-32864: Support ORC forced positional evolution with partitioned table") { + if (AuronTestUtils.isSparkV32OrGreater) { + Seq(true, false).foreach { forcePositionalEvolution => + withEnvConf( + AuronConf.ORC_FORCE_POSITIONAL_EVOLUTION.key -> forcePositionalEvolution.toString) { + withTempPath { f => + val path = f.getCanonicalPath + Seq[(Integer, Integer, Integer)]((1, 2, 1), (3, 4, 2), (5, 6, 3), (null, null, 4)) + .toDF("c1", "c2", "p") + .write + .partitionBy("p") + .orc(path) + val correctAnswer = Seq(Row(1, 2, 1), Row(3, 4, 2), Row(5, 6, 3), Row(null, null, 4)) + checkAnswer(spark.read.orc(path), correctAnswer) + + withTable("t") { + sql(s""" + |CREATE TABLE t(c3 INT, c2 INT) + |USING ORC + |PARTITIONED BY (p int) + |LOCATION '$path' + |""".stripMargin) + sql("MSCK REPAIR TABLE t") + val expected = if (forcePositionalEvolution) { + correctAnswer + } else { + Seq(Row(null, 2, 1), Row(null, 4, 2), Row(null, 6, 3), Row(null, null, 4)) + } + + checkAnswer(spark.table("t"), expected) + } + } + } + } + } + } +} diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronSQLTestHelper.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronSQLTestHelper.scala new file mode 100644 index 000000000..f37474a06 --- /dev/null +++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronSQLTestHelper.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.auron + +import org.apache.spark.SparkEnv + +trait AuronSQLTestHelper { + def withEnvConf(pairs: (String, String)*)(f: => Unit): Unit = { + val env = SparkEnv.get + if (env == null) { + throw new IllegalStateException("SparkEnv is not initialized") + } + val conf = env.conf + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.get(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => + conf.set(k, v) + } + try f + finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.set(key, value) + case (key, None) => conf.remove(key) + } + } + } +} diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/BaseAuronSQLSuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/BaseAuronSQLSuite.scala new file mode 100644 index 000000000..6e30c960f --- /dev/null +++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/BaseAuronSQLSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.auron + +import org.apache.spark.SparkConf +import org.apache.spark.sql.test.SharedSparkSession + +trait BaseAuronSQLSuite extends SharedSparkSession { + + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.sql.extensions", "org.apache.spark.sql.auron.AuronSparkSessionExtension") + .set( + "spark.shuffle.manager", + "org.apache.spark.sql.execution.auron.shuffle.AuronShuffleManager") + .set("spark.memory.offHeap.enabled", "false") + .set("spark.auron.enable", "true") + } + +} diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala new file mode 100644 index 000000000..d5a1a60c1 --- /dev/null +++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.auron + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType} + +import org.apache.auron.protobuf.ScalarFunction + +class NativeConvertersSuite extends QueryTest with BaseAuronSQLSuite with AuronSQLTestHelper { + + private def assertTrimmedCast(rawValue: String, targetType: DataType): Unit = { + val expr = Cast(Literal.create(rawValue, StringType), targetType) + val nativeExpr = NativeConverters.convertExpr(expr) + + assert(nativeExpr.hasTryCast) + val childExpr = nativeExpr.getTryCast.getExpr + assert(childExpr.hasScalarFunction) + val scalarFn = childExpr.getScalarFunction + assert(scalarFn.getFun == ScalarFunction.Trim) + assert(scalarFn.getArgsCount == 1 && scalarFn.getArgs(0).hasLiteral) + } + + private def assertNonTrimmedCast(rawValue: String, targetType: DataType): Unit = { + val expr = Cast(Literal.create(rawValue, StringType), targetType) + val nativeExpr = NativeConverters.convertExpr(expr) + + assert(nativeExpr.hasTryCast) + val childExpr = nativeExpr.getTryCast.getExpr + assert(!childExpr.hasScalarFunction) + assert(childExpr.hasLiteral) + } + + test("cast from string to numeric adds trim wrapper before native cast when enabled") { + withSQLConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "true") { + assertTrimmedCast(" 42 ", IntegerType) + } + } + + test("cast from string to boolean adds trim wrapper before native cast when enabled") { + withSQLConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "true") { + assertTrimmedCast(" true ", BooleanType) + } + } + + test("cast trim disabled via auron conf") { + withEnvConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "false") { + assertNonTrimmedCast(" 42 ", IntegerType) + } + } + + test("cast trim disabled via auron conf for boolean cast") { + withEnvConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "false") { + assertNonTrimmedCast(" true ", BooleanType) + } + } + + test("cast with non-string child remains unchanged") { + val expr = Cast(Literal(1.5), IntegerType) + val nativeExpr = NativeConverters.convertExpr(expr) + + assert(nativeExpr.hasTryCast) + val childExpr = nativeExpr.getTryCast.getExpr + assert(!childExpr.hasScalarFunction) + assert(childExpr.hasLiteral) + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala index dc2f8c727..fad4beac7 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala @@ -331,13 +331,13 @@ object AuronConverters extends Logging { getShuffleOrigin(exec)) } - @sparkver(" 3.2 / 3.3 / 3.4 / 3.5") + @sparkver(" 3.2 / 3.3 / 3.4 / 3.5 / 4.0") def getIsSkewJoinFromSHJ(exec: ShuffledHashJoinExec): Boolean = exec.isSkewJoin @sparkver("3.0 / 3.1") def getIsSkewJoinFromSHJ(exec: ShuffledHashJoinExec): Boolean = false - @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5 / 4.0") def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = Some(exec.shuffleOrigin) @sparkver("3.0") @@ -1016,7 +1016,7 @@ object AuronConverters extends Logging { rddPartitioner = None, rddDependencies = Nil, false, - (_partition, _taskContext) => { + (_, 
_) => { val nativeEmptyExec = EmptyPartitionsExecNode .newBuilder() .setNumPartitions(outputPartitioning.numPartitions) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala index fbac6a929..375d3c51e 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala @@ -27,7 +27,6 @@ import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute @@ -240,7 +239,7 @@ abstract class Shims { def convertMoreSparkPlan(exec: SparkPlan): Option[SparkPlan] - def getSqlContext(sparkPlan: SparkPlan): SQLContext + def getSqlContext(sparkPlan: SparkPlan): org.apache.spark.sql.SQLContext def createNativeExprWrapper( nativeExpr: pb.PhysicalExprNode, diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala index 988875adc..6a66d964e 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.LongType import org.apache.spark.sql.types.ShortType import org.apache.spark.sql.types.TimestampType -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} class AuronColumnarArray(data: AuronColumnVector, offset: Int, length: Int) extends ArrayData { override def numElements: Int = length @@ -154,4 +153,8 @@ class AuronColumnarArray(data: AuronColumnVector, offset: Int, length: Int) exte override def setNullAt(ordinal: Int): Unit = { throw new UnsupportedOperationException } + + override def getVariant(i: Int): VariantVal = { + throw new UnsupportedOperationException + } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala index 6b24e0f55..b6c032a0d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala @@ -34,8 +34,7 @@ import org.apache.spark.sql.types.MapType import org.apache.spark.sql.types.ShortType import org.apache.spark.sql.types.StringType import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} class AuronColumnarBatchRow(columns: Array[AuronColumnVector], var rowId: Int = 0) extends InternalRow { @@ -133,4 +132,8 @@ class AuronColumnarBatchRow(columns: Array[AuronColumnVector], var rowId: Int = override def 
setNullAt(ordinal: Int): Unit = { throw new UnsupportedOperationException } + + override def getVariant(i: Int): VariantVal = { + throw new UnsupportedOperationException + } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala index 34e7a717d..255eacc54 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala @@ -34,8 +34,7 @@ import org.apache.spark.sql.types.MapType import org.apache.spark.sql.types.ShortType import org.apache.spark.sql.types.StringType import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} class AuronColumnarStruct(data: AuronColumnVector, rowId: Int) extends InternalRow { override def numFields: Int = data.dataType.asInstanceOf[StructType].size @@ -143,4 +142,8 @@ class AuronColumnarStruct(data: AuronColumnVector, rowId: Int) extends InternalR override def setNullAt(ordinal: Int): Unit = { throw new UnsupportedOperationException } + + override def getVariant(i: Int): VariantVal = { + throw new UnsupportedOperationException + } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala index 7dafd564a..a5ebb50cc 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala @@ -138,8 +138,10 @@ abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi } def doExecuteBroadcastNative[T](): broadcast.Broadcast[T] = { - val conf = SparkSession.getActiveSession.map(_.sqlContext.conf).orNull - val timeout: Long = conf.broadcastTimeout + SparkSession.getActiveSession.map(_.conf).orNull + val timeout: Long = SparkSession.getActiveSession + .map(s => s.conf.get("spark.sql.broadcastTimeout", "300").toLong) + .getOrElse(300L) try { relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]] } catch { @@ -260,7 +262,7 @@ abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi @transient lazy val relationFuture: Future[Broadcast[Any]] = { SQLExecution.withThreadLocalCaptured[Broadcast[Any]]( - Shims.get.getSqlContext(this).sparkSession, + this.session.sqlContext.sparkSession, BroadcastExchangeExec.executionContext) { try { sparkContext.setJobGroup( From 99989919a3487e23b5ce9126c785c4ba2d0a2b02 Mon Sep 17 00:00:00 2001 From: slfan1989 Date: Sun, 5 Oct 2025 23:28:51 +0800 Subject: [PATCH 2/5] [AURON#1404] Support for Spark 4.0.1 Compatibility in Auron. 
--- pom.xml | 61 ++++++++++++++++--- .../apache/spark/sql/auron/ShimsImpl.scala | 6 +- .../plan/NativeShuffleExchangeExec.scala | 35 ++++------- .../org/apache/spark/sql/auron/Shims.scala | 3 +- .../auron/columnar/AuronColumnarArray.scala | 8 ++- .../columnar/AuronColumnarBatchRow.scala | 8 ++- .../auron/columnar/AuronColumnarStruct.scala | 8 ++- .../plan/NativeBroadcastExchangeBase.scala | 46 ++++++++++++-- .../shuffle/AuronRssShuffleWriterBase.scala | 2 +- 9 files changed, 127 insertions(+), 50 deletions(-) diff --git a/pom.xml b/pom.xml index ca89d262a..ef838577c 100644 --- a/pom.xml +++ b/pom.xml @@ -68,7 +68,6 @@ 3.0.0 2.1.1 - 4.9.9 -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED @@ -379,7 +378,7 @@ com.diffplug.spotless spotless-maven-plugin - 2.45.0 + ${spotless.plugin.version} @@ -516,7 +515,6 @@ spark-3.0 spark-extension-shims-spark3 3.0.8 - 3.0.0 3.0.3 @@ -527,7 +525,6 @@ spark-3.1 spark-extension-shims-spark3 3.2.9 - 3.0.0 3.1.3 @@ -538,7 +535,6 @@ spark-3.2 spark-extension-shims-spark3 3.2.9 - 3.0.0 3.2.4 @@ -549,7 +545,6 @@ spark-3.3 spark-extension-shims-spark3 3.2.9 - 3.0.0 3.3.4 @@ -560,7 +555,6 @@ spark-3.4 spark-extension-shims-spark3 3.2.9 - 3.0.0 3.4.4 @@ -571,7 +565,6 @@ spark-3.5 spark-extension-shims-spark3 3.2.9 - 3.0.0 3.5.7 @@ -585,6 +578,36 @@ 3.9.9 4.0.1 + + + + org.apache.maven.plugins + maven-enforcer-plugin + 3.4.1 + + + enforce-java-scala-version + + enforce + + + + + [17,) + Spark 4.0 requires JDK 17 or higher to compile! + + + scalaVersion + ^2.13.* + Spark 4.0 requires Scala 2.13 and is not compatible with Scala 2.12! + + + + + + + + @@ -594,6 +617,9 @@ 8 + 2.30.0 + 4.8.8 + 3.0.0 @@ -604,6 +630,9 @@ 11 + 2.30.0 + 4.8.8 + 3.0.0 @@ -614,6 +643,22 @@ 17 + 2.45.0 + 4.9.9 + 3.9.9 + + + + + jdk-21 + + 21 + + + 21 + 2.45.0 + 4.9.9 + 3.9.9 diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala index 5f65c5533..a10d0cdae 100644 --- a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.shuffle.ShuffleWriteMetricsReporter -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.auron.AuronConverters.ForceNativeExecutionWrapperBase import org.apache.spark.sql.auron.NativeConverters.NativeExprWrapperBase import org.apache.spark.sql.catalyst.InternalRow @@ -626,9 +626,7 @@ class ShimsImpl extends Shims with Logging { } @sparkver("4.0") - override def getSqlContext( - sparkPlan: org.apache.spark.sql.execution.SparkPlan): org.apache.spark.sql.SQLContext = - sparkPlan.session.sqlContext + override def getSqlContext(sparkPlan: SparkPlan): SQLContext = ??? 
override def createNativeExprWrapper( nativeExpr: pb.PhysicalExprNode, diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala index e6061adc0..3a7988aa5 100644 --- a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffleExchangeExec.scala @@ -129,14 +129,12 @@ case class NativeShuffleExchangeExec( dep: ShuffleDependency[_, _, _], mapId: Long, mapIndex: Int, - context: TaskContext): MapStatus = super.write(inputs, dep, mapId, mapIndex, context) + context: TaskContext): MapStatus = { - def write( - rdd: RDD[_], - dep: ShuffleDependency[_, _, _], - mapId: Long, - context: TaskContext, - partition: Partition): MapStatus = { + // [SPARK-44605][CORE] Refined the internal ShuffleWriteProcessor API. + // The restructured write method no longer receives a Partition, so derive it from the shuffle dependency's RDD and the map index. + val rdd = dep.rdd + val partition = rdd.partitions(mapIndex) val writer = SparkEnv.get.shuffleManager.getWriter( dep.shuffleHandle, @@ -167,31 +165,20 @@ case class NativeShuffleExchangeExec( } } - // for databricks testing - val causedBroadcastJoinBuildOOM = false - - @sparkver("3.5 / 4.0") + @sparkver("4.0") override def advisoryPartitionSize: Option[Long] = None - // If users specify the num partitions via APIs like `repartition`, we shouldn't change it. - // For `SinglePartition`, it requires exactly one partition and we can't change it either. - @sparkver("3.0") - override def canChangeNumPartitions: Boolean = - outputPartitioning != SinglePartition - - @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5 / 4.0") + @sparkver("4.0") override def shuffleOrigin = { import org.apache.spark.sql.execution.exchange.ShuffleOrigin; _shuffleOrigin.get.asInstanceOf[ShuffleOrigin] } - @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") + @sparkver("4.0") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @sparkver("3.0 / 3.1") - override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = - copy(child = newChildren.head) - - override def shuffleId: Int = ???
+ override def shuffleId: Int = { + shuffleDependency.shuffleId; + } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala index 375d3c51e..fbac6a929 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala @@ -27,6 +27,7 @@ import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute @@ -239,7 +240,7 @@ abstract class Shims { def convertMoreSparkPlan(exec: SparkPlan): Option[SparkPlan] - def getSqlContext(sparkPlan: SparkPlan): org.apache.spark.sql.SQLContext + def getSqlContext(sparkPlan: SparkPlan): SQLContext def createNativeExprWrapper( nativeExpr: pb.PhysicalExprNode, diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala index 6a66d964e..67a813242 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarArray.scala @@ -31,7 +31,10 @@ import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.LongType import org.apache.spark.sql.types.ShortType import org.apache.spark.sql.types.TimestampType -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.auron.sparkver class AuronColumnarArray(data: AuronColumnVector, offset: Int, length: Int) extends ArrayData { override def numElements: Int = length @@ -154,7 +157,8 @@ class AuronColumnarArray(data: AuronColumnVector, offset: Int, length: Int) exte throw new UnsupportedOperationException } - override def getVariant(i: Int): VariantVal = { + @sparkver("4.0") + override def getVariant(i: Int): org.apache.spark.unsafe.types.VariantVal = { throw new UnsupportedOperationException } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala index b6c032a0d..65ee91610 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarBatchRow.scala @@ -34,7 +34,10 @@ import org.apache.spark.sql.types.MapType import org.apache.spark.sql.types.ShortType import org.apache.spark.sql.types.StringType import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.auron.sparkver class AuronColumnarBatchRow(columns: Array[AuronColumnVector], var rowId: Int = 0) extends InternalRow { @@ -133,7 +136,8 @@ class 
AuronColumnarBatchRow(columns: Array[AuronColumnVector], var rowId: Int = throw new UnsupportedOperationException } - override def getVariant(i: Int): VariantVal = { + @sparkver("4.0") + override def getVariant(i: Int): org.apache.spark.unsafe.types.VariantVal = { throw new UnsupportedOperationException } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala index 255eacc54..cfb5debeb 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/columnar/AuronColumnarStruct.scala @@ -34,7 +34,10 @@ import org.apache.spark.sql.types.MapType import org.apache.spark.sql.types.ShortType import org.apache.spark.sql.types.StringType import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.auron.sparkver class AuronColumnarStruct(data: AuronColumnVector, rowId: Int) extends InternalRow { override def numFields: Int = data.dataType.asInstanceOf[StructType].size @@ -143,7 +146,8 @@ class AuronColumnarStruct(data: AuronColumnVector, rowId: Int) extends InternalR throw new UnsupportedOperationException } - override def getVariant(i: Int): VariantVal = { + @sparkver("4.0") + override def getVariant(i: Int): org.apache.spark.unsafe.types.VariantVal = { throw new UnsupportedOperationException } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala index a5ebb50cc..408aa68e5 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastExchangeBase.scala @@ -65,7 +65,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.BinaryType -import org.apache.auron.{protobuf => pb} +import org.apache.auron.{protobuf => pb, sparkver} abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val child: SparkPlan) extends BroadcastExchangeLike @@ -137,11 +137,21 @@ abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi dummyBroadcasted.asInstanceOf[Broadcast[T]] } - def doExecuteBroadcastNative[T](): broadcast.Broadcast[T] = { - SparkSession.getActiveSession.map(_.conf).orNull - val timeout: Long = SparkSession.getActiveSession - .map(s => s.conf.get("spark.sql.broadcastTimeout", "300").toLong) + @sparkver("3.0 / 3.1 / 3.2 / 3.3 / 3.4 / 3.5") + def getBroadcastTimeout: Long = { + val conf = SparkSession.getActiveSession.map(_.sqlContext.conf).orNull + conf.broadcastTimeout + } + + @sparkver("4.0") + def getBroadcastTimeout: Long = { + SparkSession.getActiveSession + .map(_.conf.get("spark.sql.broadcastTimeout").toLong) .getOrElse(300L) + } + + def doExecuteBroadcastNative[T](): broadcast.Broadcast[T] = { + val timeout: Long = getBroadcastTimeout try { relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]] } catch { @@ -260,7 +270,10 @@ abstract class 
NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi lazy val relationFuturePromise: Promise[Broadcast[Any]] = Promise[Broadcast[Any]]() @transient - lazy val relationFuture: Future[Broadcast[Any]] = { + lazy val relationFuture: Future[Broadcast[Any]] = getRelationFuture + + @sparkver("4.0") + private def getRelationFuture = { SQLExecution.withThreadLocalCaptured[Broadcast[Any]]( this.session.sqlContext.sparkSession, BroadcastExchangeExec.executionContext) { @@ -280,6 +293,27 @@ abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi } } + @sparkver("3.0 / 3.1 / 3.2 / 3.3 / 3.4 / 3.5") + private def getRelationFuture = { + SQLExecution.withThreadLocalCaptured[Broadcast[Any]]( + Shims.get.getSqlContext(this).sparkSession, + BroadcastExchangeExec.executionContext) { + try { + sparkContext.setJobGroup( + getRunId.toString, + s"native broadcast exchange (runId $getRunId)", + interruptOnCancel = true) + val broadcasted = sparkContext.broadcast(collectNative().asInstanceOf[Any]) + relationFuturePromise.trySuccess(broadcasted) + broadcasted + } catch { + case e: Throwable => + relationFuturePromise.tryFailure(e) + throw e + } + } + } + override protected def doCanonicalize(): SparkPlan = Shims.get.createNativeBroadcastExchangeExec(mode.canonicalized, child.canonicalized) } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala index 19ba9bdfb..8717d0bea 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala @@ -77,7 +77,7 @@ abstract class AuronRssShuffleWriterBase[K, V](metrics: ShuffleWriteMetricsRepor def rssStop(success: Boolean): Option[MapStatus] - @sparkver("3.2 / 3.3 / 3.4 / 3.5") + @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") override def getPartitionLengths(): Array[Long] = rpw.getPartitionLengthMap override def write(records: Iterator[Product2[K, V]]): Unit = { From ec39f8c1d6c69b34e386b60bf0c4b554d26c14f8 Mon Sep 17 00:00:00 2001 From: slfan1989 Date: Sun, 5 Oct 2025 23:36:04 +0800 Subject: [PATCH 3/5] [AURON#1404] Support for Spark 4.0.1 Compatibility in Auron. 
--- spark-extension-shims-spark4/pom.xml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/spark-extension-shims-spark4/pom.xml b/spark-extension-shims-spark4/pom.xml index 2ab23fac8..d70209fc9 100644 --- a/spark-extension-shims-spark4/pom.xml +++ b/spark-extension-shims-spark4/pom.xml @@ -1,4 +1,20 @@ + 4.0.0 From 1147f492532961e0d0234c51004f2c465dc47c3f Mon Sep 17 00:00:00 2001 From: slfan1989 <55643692+slfan1989@users.noreply.github.com> Date: Wed, 8 Oct 2025 16:48:49 +0800 Subject: [PATCH 4/5] Update spark-extension-shims-spark4/pom.xml Co-authored-by: cxzl25 <3898450+cxzl25@users.noreply.github.com> --- spark-extension-shims-spark4/pom.xml | 1 - 1 file changed, 1 deletion(-) diff --git a/spark-extension-shims-spark4/pom.xml b/spark-extension-shims-spark4/pom.xml index d70209fc9..61bde2cbf 100644 --- a/spark-extension-shims-spark4/pom.xml +++ b/spark-extension-shims-spark4/pom.xml @@ -27,7 +27,6 @@ spark-extension-shims-spark4_${scalaVersion} jar - http://maven.apache.org UTF-8 From d149380e9930b57e3469122fcfb611780beaa6b2 Mon Sep 17 00:00:00 2001 From: slfan1989 Date: Fri, 10 Oct 2025 12:09:51 +0800 Subject: [PATCH 5/5] [AURON#1404] Support for Spark 4.0.1 Compatibility in Auron. --- auron-build.sh | 4 ++-- spark-extension-shims-spark4/pom.xml | 1 - .../spark/sql/execution/auron/plan/NativeSortExec.scala | 6 +----- .../sql/execution/auron/plan/NativeTakeOrderedExec.scala | 6 +----- .../spark/sql/execution/auron/plan/NativeUnionExec.scala | 6 +----- .../spark/sql/execution/auron/plan/NativeWindowExec.scala | 6 +----- .../spark/sql/auron/AuronAdaptiveQueryExecSuite.scala | 2 +- 7 files changed, 7 insertions(+), 24 deletions(-) diff --git a/auron-build.sh b/auron-build.sh index 3e7b321b6..70ed2aed2 100755 --- a/auron-build.sh +++ b/auron-build.sh @@ -102,10 +102,10 @@ while [[ $# -gt 0 ]]; do SPARK_VER="$2" if [ "$SPARK_VER" = "3.0" ] || [ "$SPARK_VER" = "3.1" ] \ || [ "$SPARK_VER" = "3.2" ] || [ "$SPARK_VER" = "3.3" ] \ - || [ "$SPARK_VER" = "3.4" ] || [ "$SPARK_VER" = "3.5" ]; then + || [ "$SPARK_VER" = "3.4" ] || [ "$SPARK_VER" = "3.5" ] || [ "$SPARK_VER" = "4.0" ]; then echo "Building for Spark $SPARK_VER" else - echo "ERROR: Invalid Spark version: $SPARK_VER. The currently supported versions are: 3.0 / 3.1 / 3.2 / 3.3 / 3.4 / 3.5." + echo "ERROR: Invalid Spark version: $SPARK_VER. The currently supported versions are: 3.0 / 3.1 / 3.2 / 3.3 / 3.4 / 3.5 / 4.0." 
exit 1 fi shift 2 diff --git a/spark-extension-shims-spark4/pom.xml b/spark-extension-shims-spark4/pom.xml index 61bde2cbf..633cf7cf6 100644 --- a/spark-extension-shims-spark4/pom.xml +++ b/spark-extension-shims-spark4/pom.xml @@ -27,7 +27,6 @@ spark-extension-shims-spark4_${scalaVersion} jar - UTF-8 diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala index 26809d116..b532aac17 100644 --- a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeSortExec.scala @@ -27,11 +27,7 @@ case class NativeSortExec( override val child: SparkPlan) extends NativeSortBase(sortOrder, global, child) { - @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") + @sparkver("4.0") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - - @sparkver("3.0 / 3.1") - override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = - copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala index 3b820e88b..395c1527c 100644 --- a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala @@ -27,11 +27,7 @@ case class NativeTakeOrderedExec( override val child: SparkPlan) extends NativeTakeOrderedBase(limit, sortOrder, child) { - @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") + @sparkver("4.0") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - - @sparkver("3.0 / 3.1") - override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = - copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala index 130697b78..30e092cd5 100644 --- a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala +++ b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeUnionExec.scala @@ -26,11 +26,7 @@ case class NativeUnionExec( override val output: Seq[Attribute]) extends NativeUnionBase(children, output) { - @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") + @sparkver("4.0") override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = copy(children = newChildren) - - @sparkver("3.0 / 3.1") - override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = - copy(children = newChildren) } diff --git a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala index d14647c0b..c26a00f84 100644 --- a/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala +++ 
b/spark-extension-shims-spark4/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowExec.scala @@ -31,11 +31,7 @@ case class NativeWindowExec( override val child: SparkPlan) extends NativeWindowBase(windowExpression, partitionSpec, orderSpec, groupLimit, child) { - @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0") + @sparkver("4.0") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - - @sparkver("3.0 / 3.1") - override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = - copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala index c27f47a50..96e39b3e1 100644 --- a/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala +++ b/spark-extension-shims-spark4/src/test/scala/org/apache/spark/sql/auron/AuronAdaptiveQueryExecSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.auron import org.apache.auron.sparkverEnableMembers -@sparkverEnableMembers("3.5") +@sparkverEnableMembers("4.0") class AuronAdaptiveQueryExecSuite extends org.apache.spark.sql.QueryTest with BaseAuronSQLSuite