Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Join, Statistics}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.RDDScanTransformer
import org.apache.spark.sql.execution.adaptive.QueryStageExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.datasources.WriteFilesExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
Expand Down Expand Up @@ -123,7 +125,7 @@ case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil {
}
}

object OffloadJoin {
object OffloadJoin extends Logging {
def getShjBuildSide(shj: ShuffledHashJoinExec): BuildSide = {
val leftBuildable =
BackendsApiManager.getSettings.supportHashBuildJoinTypeOnLeft(shj.joinType)
Expand All @@ -144,33 +146,102 @@ object OffloadJoin {
// User disabled build side re-optimization. Return original build side from vanilla Spark.
return shj.buildSide
}
shj.logicalLink
.flatMap {
case join: Join => Some(getOptimalBuildSide(join))
case _ => None
}
.getOrElse {
// Some shj operators generated in certain Spark tests such as OuterJoinSuite,
// could possibly have no logical link set.
shj.buildSide
}

getOptimalBuildSide(shj)
// Some shj operators generated in certain Spark tests such as OuterJoinSuite,
// could possibly have no logical link set.
.getOrElse(shj.buildSide)
}

def getOptimalBuildSide(join: Join): BuildSide = {
val leftSize = join.left.stats.sizeInBytes
val rightSize = join.right.stats.sizeInBytes
val leftRowCount = join.left.stats.rowCount
val rightRowCount = join.right.stats.rowCount
if (leftSize == rightSize && rightRowCount.isDefined && leftRowCount.isDefined) {
if (rightRowCount.get <= leftRowCount.get) {
return BuildRight
}
return BuildLeft
/**
 * Determine the optimal build side by picking the smaller side.
 *
 * Runtime stats from completed AQE QueryStageExec nodes are preferred. If runtime stats are
 * unavailable, fallback to plan stats via logicalLink.
 *
 * @param join physical join whose build side is being chosen
 * @return Some(BuildRight) / Some(BuildLeft) when a size is resolvable for both children;
 *         None when neither runtime stats nor logical-plan stats are available
 */
def getOptimalBuildSide(join: BaseJoinExec): Option[BuildSide] = {
// Logical Join (when a logicalLink exists) is kept only as the stats fallback below.
val logicalJoin = join.logicalLink.collect { case j: Join => j }
val leftRuntimeStats = findRuntimeStats(join.left)
val rightRuntimeStats = findRuntimeStats(join.right)

// For each side: prefer runtime sizeInBytes/rowCount, else the logical plan's estimate.
val leftSize =
leftRuntimeStats.map(_.sizeInBytes).orElse(logicalJoin.map(_.left.stats.sizeInBytes))
val rightSize =
rightRuntimeStats.map(_.sizeInBytes).orElse(logicalJoin.map(_.right.stats.sizeInBytes))
val leftRowCount =
leftRuntimeStats.flatMap(_.rowCount).orElse(logicalJoin.flatMap(_.left.stats.rowCount))
val rightRowCount =
rightRuntimeStats.flatMap(_.rowCount).orElse(logicalJoin.flatMap(_.right.stats.rowCount))

(leftSize, rightSize) match {
case (Some(l), Some(r)) =>
logDebug(s"Build side decision: leftSize=$l, rightSize=$r")
// Sizes equal: break the tie on row count when both are known (ties favor BuildRight).
if (l == r && rightRowCount.isDefined && leftRowCount.isDefined) {
if (rightRowCount.get <= leftRowCount.get) {
Some(BuildRight)
} else Some(BuildLeft)
} else if (r <= l) {
Some(BuildRight)
} else {
Some(BuildLeft)
}
// No size for one or both sides: let the caller fall back to its own default.
case _ => None
}
// NOTE(review): the three lines below look like leftover removed-side text from the diff
// view (they use `return` and compare the new Option-typed vals as plain values, which
// would not type-check) — confirm they are absent from the actual merged source.
if (rightSize <= leftSize) {
return BuildRight
}

/**
 * Walks the physical plan looking for runtime [[Statistics]].
 *
 * A [[QueryStageExec]] exposes its materialized runtime stats directly; a unary wrapper
 * node (e.g. AQEShuffleRead) delegates to its only child; an inline join node is estimated
 * from the runtime stats of its children. Any other plan shape yields None.
 */
private def findRuntimeStats(plan: SparkPlan): Option[Statistics] = {
  plan match {
    case stage: QueryStageExec =>
      stage.computeStats
    case wrapper if wrapper.children.size == 1 =>
      // Pass-through nodes: recurse into the single child.
      findRuntimeStats(wrapper.children.head)
    case inlineJoin: BaseJoinExec =>
      estimateJoinOutputStats(inlineJoin)
    case _ =>
      None
  }
}

/**
 * Estimate the output Statistics of an inline join from children's runtime stats. The outputRows
 * estimation follows Spark's JoinEstimation.
 *
 * @param join physical join whose output is being estimated
 * @return estimated Statistics when both children have runtime stats with row counts,
 *         None otherwise
 */
private def estimateJoinOutputStats(join: BaseJoinExec): Option[Statistics] = {
val leftStats = findRuntimeStats(join.left)
val rightStats = findRuntimeStats(join.right)
logDebug(s"Estimating ${join.getClass.getSimpleName}(${join.joinType})")

// Only estimate when BOTH children have runtime stats including a row count.
(leftStats, rightStats) match {
case (
Some(Statistics(lBytes, Some(lRows), _, _)),
Some(Statistics(rBytes, Some(rRows), _, _))) =>
// Per-row widths of each side; an output row is modeled as one left row + one right row.
val lAvgRowSize = if (lRows > 0) lBytes.toDouble / lRows.toDouble else 0.0
val rAvgRowSize = if (rRows > 0) rBytes.toDouble / rRows.toDouble else 0.0
val outputRowSize = lAvgRowSize + rAvgRowSize

// Equi-join: bounded by the smaller input; no join keys: cartesian product.
val numInnerJoinedRows =
if (join.leftKeys.nonEmpty) lRows.min(rRows) else lRows * rRows

val (estRows, estBytes) = join.joinType match {
case LeftSemi | LeftAnti =>
// Semi/anti joins emit at most the left side, with left-side row width.
(lRows, lBytes)
case LeftOuter =>
// Every left row survives, so at least lRows.
val rows = lRows.max(numInnerJoinedRows)
(rows, BigInt((rows.toDouble * outputRowSize).toLong))
case RightOuter =>
val rows = rRows.max(numInnerJoinedRows)
(rows, BigInt((rows.toDouble * outputRowSize).toLong))
case FullOuter =>
// Matched rows counted once plus unmatched rows from both sides.
val rows = lRows.max(numInnerJoinedRows) +
rRows.max(numInnerJoinedRows) - numInnerJoinedRows
(rows, BigInt((rows.toDouble * outputRowSize).toLong))
case _ =>
// Inner/cross and any other type fall back to the inner-join estimate.
(numInnerJoinedRows, BigInt((numInnerJoinedRows.toDouble * outputRowSize).toLong))
}
logDebug(s"Estimated join output: rows=$estRows, bytes=$estBytes")
Some(Statistics(estBytes, Some(estRows)))
case _ => None
}
// NOTE(review): the stray `BuildLeft` below looks like leftover removed-side text from the
// diff view (it does not type-check as this method's Option[Statistics] result) — confirm
// it is absent from the actual merged source.
BuildLeft
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.extension.columnar.offload.OffloadJoin

import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}

Expand All @@ -45,17 +44,11 @@ object RewriteJoin extends RewriteSingleNode with JoinSelectionHelper {
if (!rightBuildable) {
return Some(BuildLeft)
}
val side = join.logicalLink
.flatMap {
case join: Join => Some(OffloadJoin.getOptimalBuildSide(join))
case _ => None
}
.getOrElse {
// If smj has no logical link, or its logical link is not a join,
// then we always choose left as build side.
BuildLeft
}
Some(side)
OffloadJoin
.getOptimalBuildSide(join)
// If neither runtime stats nor logical link is available,
// then we always choose left as build side.
.orElse(Some(BuildLeft))
}

override def rewrite(plan: SparkPlan): SparkPlan = plan match {
Expand Down
Loading