diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala index 684fbd36f1ac..cf1c1d38d6c4 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala @@ -25,9 +25,11 @@ import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} -import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{Join, Statistics} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.RDDScanTransformer +import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.datasources.WriteFilesExec import org.apache.spark.sql.execution.datasources.v2.BatchScanExec @@ -123,7 +125,7 @@ case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil { } } -object OffloadJoin { +object OffloadJoin extends Logging { def getShjBuildSide(shj: ShuffledHashJoinExec): BuildSide = { val leftBuildable = BackendsApiManager.getSettings.supportHashBuildJoinTypeOnLeft(shj.joinType) @@ -144,33 +146,102 @@ object OffloadJoin { // User disabled build side re-optimization. Return original build side from vanilla Spark. return shj.buildSide } - shj.logicalLink - .flatMap { - case join: Join => Some(getOptimalBuildSide(join)) - case _ => None - } - .getOrElse { - // Some shj operators generated in certain Spark tests such as OuterJoinSuite, - // could possibly have no logical link set. 
- shj.buildSide - } + + getOptimalBuildSide(shj) + // Some shj operators generated in certain Spark tests such as OuterJoinSuite, + // could possibly have no logical link set. + .getOrElse(shj.buildSide) } - def getOptimalBuildSide(join: Join): BuildSide = { - val leftSize = join.left.stats.sizeInBytes - val rightSize = join.right.stats.sizeInBytes - val leftRowCount = join.left.stats.rowCount - val rightRowCount = join.right.stats.rowCount - if (leftSize == rightSize && rightRowCount.isDefined && leftRowCount.isDefined) { - if (rightRowCount.get <= leftRowCount.get) { - return BuildRight - } - return BuildLeft + /** + * Determine the optimal build side by picking the smaller side. + * + * Runtime stats from completed AQE QueryStageExec nodes are preferred. If runtime stats are + * unavailable, fall back to plan stats via logicalLink. + */ + def getOptimalBuildSide(join: BaseJoinExec): Option[BuildSide] = { + val logicalJoin = join.logicalLink.collect { case j: Join => j } + val leftRuntimeStats = findRuntimeStats(join.left) + val rightRuntimeStats = findRuntimeStats(join.right) + + val leftSize = + leftRuntimeStats.map(_.sizeInBytes).orElse(logicalJoin.map(_.left.stats.sizeInBytes)) + val rightSize = + rightRuntimeStats.map(_.sizeInBytes).orElse(logicalJoin.map(_.right.stats.sizeInBytes)) + val leftRowCount = + leftRuntimeStats.flatMap(_.rowCount).orElse(logicalJoin.flatMap(_.left.stats.rowCount)) + val rightRowCount = + rightRuntimeStats.flatMap(_.rowCount).orElse(logicalJoin.flatMap(_.right.stats.rowCount)) + + (leftSize, rightSize) match { + case (Some(l), Some(r)) => + logDebug(s"Build side decision: leftSize=$l, rightSize=$r") + if (l == r && rightRowCount.isDefined && leftRowCount.isDefined) { + if (rightRowCount.get <= leftRowCount.get) { + Some(BuildRight) + } else Some(BuildLeft) + } else if (r <= l) { + Some(BuildRight) + } else { + Some(BuildLeft) + } + case _ => None } - if (rightSize <= leftSize) { - return BuildRight + } + + /** + * Recursively find 
runtime Statistics from the physical plan. + * - QueryStageExec: directly returns materialized runtime stats + * - Single-child wrappers (e.g. AQEShuffleRead): pass through to the child + * - Join operators: estimate output from children's runtime stats + */ + private def findRuntimeStats(plan: SparkPlan): Option[Statistics] = plan match { + case stage: QueryStageExec => stage.computeStats + case p if p.children.length == 1 => findRuntimeStats(p.children.head) + case j: BaseJoinExec => estimateJoinOutputStats(j) + case _ => None + } + + /** + * Estimate the output Statistics of an inline join from children's runtime stats. The outputRows + * estimation follows Spark's JoinEstimation. + */ + private def estimateJoinOutputStats(join: BaseJoinExec): Option[Statistics] = { + val leftStats = findRuntimeStats(join.left) + val rightStats = findRuntimeStats(join.right) + logDebug(s"Estimating ${join.getClass.getSimpleName}(${join.joinType})") + + (leftStats, rightStats) match { + case ( + Some(Statistics(lBytes, Some(lRows), _, _)), + Some(Statistics(rBytes, Some(rRows), _, _))) => + val lAvgRowSize = if (lRows > 0) lBytes.toDouble / lRows.toDouble else 0.0 + val rAvgRowSize = if (rRows > 0) rBytes.toDouble / rRows.toDouble else 0.0 + val outputRowSize = lAvgRowSize + rAvgRowSize + + val numInnerJoinedRows = + if (join.leftKeys.nonEmpty) lRows.min(rRows) else lRows * rRows + + val (estRows, estBytes) = join.joinType match { + case LeftSemi | LeftAnti => + (lRows, lBytes) + case LeftOuter => + val rows = lRows.max(numInnerJoinedRows) + (rows, BigInt((rows.toDouble * outputRowSize).toLong)) + case RightOuter => + val rows = rRows.max(numInnerJoinedRows) + (rows, BigInt((rows.toDouble * outputRowSize).toLong)) + case FullOuter => + val rows = lRows.max(numInnerJoinedRows) + + rRows.max(numInnerJoinedRows) - numInnerJoinedRows + (rows, BigInt((rows.toDouble * outputRowSize).toLong)) + case _ => + (numInnerJoinedRows, BigInt((numInnerJoinedRows.toDouble * 
outputRowSize).toLong)) + } + logDebug(s"Estimated join output: rows=$estRows, bytes=$estBytes") + Some(Statistics(estBytes, Some(estRows))) + case _ => None } - BuildLeft } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala index 6d456857d734..6594848b8060 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala @@ -20,7 +20,6 @@ import org.apache.gluten.config.GlutenConfig import org.apache.gluten.extension.columnar.offload.OffloadJoin import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper} -import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} @@ -45,17 +44,11 @@ object RewriteJoin extends RewriteSingleNode with JoinSelectionHelper { if (!rightBuildable) { return Some(BuildLeft) } - val side = join.logicalLink - .flatMap { - case join: Join => Some(OffloadJoin.getOptimalBuildSide(join)) - case _ => None - } - .getOrElse { - // If smj has no logical link, or its logical link is not a join, - // then we always choose left as build side. - BuildLeft - } - Some(side) + OffloadJoin + .getOptimalBuildSide(join) + // If neither runtime stats nor logical link is available, + // then we always choose left as build side. + .orElse(Some(BuildLeft)) } override def rewrite(plan: SparkPlan): SparkPlan = plan match {