spark SortMergeJoinExec 源码
spark SortMergeJoinExec 代码
文件路径:/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.joins
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.util.collection.BitSet
/**
 * Performs a sort merge join of two child relations.
 */
case class SortMergeJoinExec(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan,
    isSkewJoin: Boolean = false) extends ShuffledJoin {
  /**
   * SQL metrics exposed by this operator: the number of output rows and the spill size.
   */
  override lazy val metrics = {
    val outputRowCount = SQLMetrics.createMetric(sparkContext, "number of output rows")
    val spilledBytes = SQLMetrics.createSizeMetric(sparkContext, "spill size")
    Map("numOutputRows" -> outputRowCount, "spillSize" -> spilledBytes)
  }
  /**
   * Ordering of the rows produced by this join; it depends on the join type.
   *
   * Throws an `IllegalArgumentException` for a join type this operator does not handle.
   */
  override def outputOrdering: Seq[SortOrder] = {
    // Key orderings of each side, computed only by the branches that need them.
    def leftOrdering: Seq[SortOrder] = getKeyOrdering(leftKeys, left.outputOrdering)
    def rightOrdering: Seq[SortOrder] = getKeyOrdering(rightKeys, right.outputOrdering)

    joinType match {
      case _: InnerLike =>
        // Inner joins keep the ordering of both sides: each left key ordering is extended
        // with the expressions of the corresponding right key ordering.
        val pairedOrderings = leftOrdering.zip(rightOrdering)
        pairedOrderings.map { case (leftOrder, rightOrder) =>
          val merged = ExpressionSet(leftOrder.sameOrderExpressions ++ rightOrder.children)
          SortOrder(leftOrder.child, Ascending, merged.toSeq)
        }
      // Left outer and left-existence joins are ordered by the left side's join keys.
      case LeftOuter | LeftExistence(_) => leftOrdering
      // Right outer joins are ordered by the right side's join keys.
      case RightOuter => rightOrdering
      // Both streams may contain null rows, so no ordering can be reported.
      case FullOuter => Nil
      case x =>
        throw new IllegalArgumentException(
          s"${getClass.getSimpleName} should not take $x as the JoinType")
    }
  }
  /**
   * Computes the output ordering contributed by one side (left or right) of the join.
   *
   * When `childOutputOrdering` does not satisfy the ordering required on `keys`, the required
   * ordering itself is returned. Otherwise the child needs no extra sort, and the result is the
   * required ordering enriched with the additional "sameOrderExpressions" taken from the
   * child's own output ordering.
   */
  private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder])
    : Seq[SortOrder] = {
    val required = requiredOrders(keys)
    val childAlreadySorted = SortOrder.orderingSatisfies(childOutputOrdering, required)
    if (!childAlreadySorted) {
      required
    } else {
      keys.zip(childOutputOrdering).map { case (joinKey, existingOrder) =>
        // Everything the child orders by at this position, except the join key itself.
        val aliases = ExpressionSet(existingOrder.children) - joinKey
        SortOrder(joinKey, Ascending, aliases.toSeq)
      }
    }
  }
  /** Each child must be sorted ascending on its own join keys (left first, then right). */
  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
    val leftOrders = requiredOrders(leftKeys)
    val rightOrders = requiredOrders(rightKeys)
    Seq(leftOrders, rightOrders)
  }
  /** Wraps every key in an ascending [[SortOrder]]. */
  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
    // Ascending is mandatory: it has to agree with the `keyOrdering` defined in `doExecute()`.
    for (key <- keys) yield SortOrder(key, Ascending)
  }
  /** Builds an unsafe projection evaluating `leftKeys` against the left child's output. */
  private def createLeftKeyGenerator(): Projection = {
    val leftSchema = left.output
    UnsafeProjection.create(leftKeys, leftSchema)
  }
  /** Builds an unsafe projection evaluating `rightKeys` against the right child's output. */
  private def createRightKeyGenerator(): Projection = {
    val rightSchema = right.output
    UnsafeProjection.create(rightKeys, rightSchema)
  }
  /** Spill threshold for the buffer of matched rows, read from the session configuration. */
  private def getSpillThreshold: Int = conf.sortMergeJoinExecBufferSpillThreshold
  // True only for left-existence joins without an extra join condition: in that case a single
  // matched row is all that gets buffered, which avoids buffering unnecessary rows.
  private val onlyBufferFirstMatchedRow = joinType match {
    case LeftExistence(_) => condition.isEmpty
    case _ => false
  }
  /**
   * In-memory threshold for the buffer of matched rows: 1 when only the first matched row is
   * buffered, otherwise the value from the session configuration.
   */
  private def getInMemoryThreshold: Int =
    if (onlyBufferFirstMatchedRow) 1 else conf.sortMergeJoinExecBufferInMemoryThreshold
  /**
   * Executes the join by zipping the partitions of the left and right children and, for each
   * pair of partitions, building a row iterator specialised for `joinType`.
   *
   * Every produced row increments the `numOutputRows` metric. The `spillSize` metric is handed
   * to the scanners of all join types except full outer.
   *
   * Throws an `IllegalArgumentException` (inside the partition function) when `joinType` is not
   * one of the handled join types.
   */
  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    val spillSize = longMetric("spillSize")
    // Resolve the thresholds once, outside the per-partition closure.
    val spillThreshold = getSpillThreshold
    val inMemoryThreshold = getInMemoryThreshold
    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
      // Extra (non-equi) join condition evaluated on the joined row; always true when the
      // join has no such condition.
      val boundCondition: (InternalRow) => Boolean = {
        condition.map { cond =>
          Predicate.create(cond, left.output ++ right.output).eval _
        }.getOrElse {
          (r: InternalRow) => true
        }
      }
      // An ordering that can be used to compare keys from both sides.
      val keyOrdering = RowOrdering.createNaturalAscendingOrdering(leftKeys.map(_.dataType))
      // Projects a joined row onto this operator's output schema.
      val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
      joinType match {
        // Inner-like joins: emit every (left row, buffered right match) pair that passes the
        // join condition.
        case _: InnerLike =>
          new RowIterator {
            private[this] var currentLeftRow: InternalRow = _
            private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
            // Null means there is nothing (left) to iterate; `advanceNext` then returns false.
            private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
            private[this] val smjScanner = new SortMergeJoinScanner(
              createLeftKeyGenerator(),
              createRightKeyGenerator(),
              keyOrdering,
              RowIterator.fromScala(leftIter),
              RowIterator.fromScala(rightIter),
              inMemoryThreshold,
              spillThreshold,
              spillSize,
              cleanupResources
            )
            private[this] val joinRow = new JoinedRow
            // Prime the iterator with the first group of matches at construction time.
            if (smjScanner.findNextInnerJoinRows()) {
              currentRightMatches = smjScanner.getBufferedMatches
              currentLeftRow = smjScanner.getStreamedRow
              rightMatchesIterator = currentRightMatches.generateIterator()
            }
            override def advanceNext(): Boolean = {
              while (rightMatchesIterator != null) {
                if (!rightMatchesIterator.hasNext) {
                  // Current matches are exhausted: move on to the next left row with matches.
                  if (smjScanner.findNextInnerJoinRows()) {
                    currentRightMatches = smjScanner.getBufferedMatches
                    currentLeftRow = smjScanner.getStreamedRow
                    rightMatchesIterator = currentRightMatches.generateIterator()
                  } else {
                    // No more matches at all: drop the references and finish.
                    currentRightMatches = null
                    currentLeftRow = null
                    rightMatchesIterator = null
                    return false
                  }
                }
                joinRow(currentLeftRow, rightMatchesIterator.next())
                if (boundCondition(joinRow)) {
                  numOutputRows += 1
                  return true
                }
              }
              false
            }
            override def getRow: InternalRow = resultProj(joinRow)
          }.toScala
        // Left outer join: the left side is streamed and the right side is buffered.
        case LeftOuter =>
          val smjScanner = new SortMergeJoinScanner(
            streamedKeyGenerator = createLeftKeyGenerator(),
            bufferedKeyGenerator = createRightKeyGenerator(),
            keyOrdering,
            streamedIter = RowIterator.fromScala(leftIter),
            bufferedIter = RowIterator.fromScala(rightIter),
            inMemoryThreshold,
            spillThreshold,
            spillSize,
            cleanupResources
          )
          // Row with one slot per right output column, handed to the outer iterator.
          val rightNullRow = new GenericInternalRow(right.output.length)
          new LeftOuterIterator(
            smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala
        // Right outer join: mirror image of left outer, the right side is streamed.
        case RightOuter =>
          val smjScanner = new SortMergeJoinScanner(
            streamedKeyGenerator = createRightKeyGenerator(),
            bufferedKeyGenerator = createLeftKeyGenerator(),
            keyOrdering,
            streamedIter = RowIterator.fromScala(rightIter),
            bufferedIter = RowIterator.fromScala(leftIter),
            inMemoryThreshold,
            spillThreshold,
            spillSize,
            cleanupResources
          )
          // Row with one slot per left output column, handed to the outer iterator.
          val leftNullRow = new GenericInternalRow(left.output.length)
          new RightOuterIterator(
            smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
        // Full outer join: uses a dedicated scanner that receives the join condition and the
        // null rows of both sides (no thresholds or spill metric are passed to it).
        case FullOuter =>
          val leftNullRow = new GenericInternalRow(left.output.length)
          val rightNullRow = new GenericInternalRow(right.output.length)
          val smjScanner = new SortMergeFullOuterJoinScanner(
            leftKeyGenerator = createLeftKeyGenerator(),
            rightKeyGenerator = createRightKeyGenerator(),
            keyOrdering,
            leftIter = RowIterator.fromScala(leftIter),
            rightIter = RowIterator.fromScala(rightIter),
            boundCondition,
            leftNullRow,
            rightNullRow)
          new FullOuterIterator(
            smjScanner,
            resultProj,
            numOutputRows).toScala
        // Left semi join: emit the left row once, as soon as one buffered match passes the
        // join condition.
        case LeftSemi =>
          new RowIterator {
            private[this] var currentLeftRow: InternalRow = _
            private[this] val smjScanner = new SortMergeJoinScanner(
              createLeftKeyGenerator(),
              createRightKeyGenerator(),
              keyOrdering,
              RowIterator.fromScala(leftIter),
              RowIterator.fromScala(rightIter),
              inMemoryThreshold,
              spillThreshold,
              spillSize,
              cleanupResources,
              onlyBufferFirstMatchedRow
            )
            private[this] val joinRow = new JoinedRow
            override def advanceNext(): Boolean = {
              while (smjScanner.findNextInnerJoinRows()) {
                val currentRightMatches = smjScanner.getBufferedMatches
                currentLeftRow = smjScanner.getStreamedRow
                if (currentRightMatches != null && currentRightMatches.length > 0) {
                  val rightMatchesIterator = currentRightMatches.generateIterator()
                  while (rightMatchesIterator.hasNext) {
                    joinRow(currentLeftRow, rightMatchesIterator.next())
                    if (boundCondition(joinRow)) {
                      numOutputRows += 1
                      return true
                    }
                  }
                }
              }
              false
            }
            // Only the left row is output; the right side never appears in the result.
            override def getRow: InternalRow = currentLeftRow
          }.toScala
        // Left anti join: emit the left row when it has no buffered match, or when none of
        // its buffered matches passes the join condition.
        case LeftAnti =>
          new RowIterator {
            private[this] var currentLeftRow: InternalRow = _
            private[this] val smjScanner = new SortMergeJoinScanner(
              createLeftKeyGenerator(),
              createRightKeyGenerator(),
              keyOrdering,
              RowIterator.fromScala(leftIter),
              RowIterator.fromScala(rightIter),
              inMemoryThreshold,
              spillThreshold,
              spillSize,
              cleanupResources,
              onlyBufferFirstMatchedRow
            )
            private[this] val joinRow = new JoinedRow
            override def advanceNext(): Boolean = {
              while (smjScanner.findNextOuterJoinRows()) {
                currentLeftRow = smjScanner.getStreamedRow
                val currentRightMatches = smjScanner.getBufferedMatches
                // No candidate on the right side at all: the left row is part of the output.
                if (currentRightMatches == null || currentRightMatches.length == 0) {
                  numOutputRows += 1
                  return true
                }
                // Otherwise look for at least one candidate that passes the condition.
                var found = false
                val rightMatchesIterator = currentRightMatches.generateIterator()
                while (!found && rightMatchesIterator.hasNext) {
                  joinRow(currentLeftRow, rightMatchesIterator.next())
                  if (boundCondition(joinRow)) {
                    found = true
                  }
                }
                if (!found) {
                  numOutputRows += 1
                  return true
                }
              }
              false
            }
            override def getRow: InternalRow = currentLeftRow
          }.toScala
        // Existence join: emit every left row followed by a boolean column telling whether a
        // buffered match passing the join condition exists.
        case j: ExistenceJoin =>
          new RowIterator {
            private[this] var currentLeftRow: InternalRow = _
            // Single-column row reused for every output row; holds the "exists" flag.
            private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null))
            private[this] val smjScanner = new SortMergeJoinScanner(
              createLeftKeyGenerator(),
              createRightKeyGenerator(),
              keyOrdering,
              RowIterator.fromScala(leftIter),
              RowIterator.fromScala(rightIter),
              inMemoryThreshold,
              spillThreshold,
              spillSize,
              cleanupResources,
              onlyBufferFirstMatchedRow
            )
            private[this] val joinRow = new JoinedRow
            override def advanceNext(): Boolean = {
              // The body always returns, so this `while` runs at most one iteration per call.
              while (smjScanner.findNextOuterJoinRows()) {
                currentLeftRow = smjScanner.getStreamedRow
                val currentRightMatches = smjScanner.getBufferedMatches
                var found = false
                if (currentRightMatches != null && currentRightMatches.length > 0) {
                  val rightMatchesIterator = currentRightMatches.generateIterator()
                  while (!found && rightMatchesIterator.hasNext) {
                    joinRow(currentLeftRow, rightMatchesIterator.next())
                    if (boundCondition(joinRow)) {
                      found = true
                    }
                  }
                }
                result.setBoolean(0, found)
                numOutputRows += 1
                return true
              }
              false
            }
            // Re-point `joinRow` at (left row, flag row) before projecting to the output.
            override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result))
          }.toScala
        case x =>
          throw new IllegalArgumentException(
            s"SortMergeJoin should not take $x as the JoinType")
      }
    }
  }
  // Decides which child is streamed and which one is buffered during code generation. Only a
  // right outer join streams the right child; every other supported join type streams the left.
  private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = {
    val leftSide = (left, leftKeys)
    val rightSide = (right, rightKeys)
    joinType match {
      case RightOuter =>
        (rightSide, leftSide)
      case _: InnerLike | LeftOuter | FullOuter | LeftExistence(_) =>
        (leftSide, rightSide)
      case x =>
        throw new IllegalArgumentException(
          s"SortMergeJoin.streamedPlan/bufferedPlan should not take $x as the JoinType")
    }
  }
  /** Output attributes of the streamed child. */
  private lazy val streamedOutput: Seq[Attribute] = streamedPlan.output
  /** Output attributes of the buffered child. */
  private lazy val bufferedOutput: Seq[Attribute] = bufferedPlan.output
  /**
   * Whether code generation is supported for this join. Existence and full outer joins are
   * each gated by their own configuration flag; every other join type supports it.
   */
  override def supportCodegen: Boolean = {
    joinType match {
      case _: ExistenceJoin =>
        conf.getConf(SQLConf.ENABLE_EXISTENCE_SORT_MERGE_JOIN_CODEGEN)
      case FullOuter =>
        conf.getConf(SQLConf.ENABLE_FULL_OUTER_SORT_MERGE_JOIN_CODEGEN)
      case _ =>
        true
    }
  }
  /** Input RDDs for code generation: the streamed side first, then the buffered side. */
  override def inputRDDs(): Seq[RDD[InternalRow]] =
    Seq(streamedPlan.execute(), bufferedPlan.execute())
  /**
   * Generates the code evaluating the join `keys` against the row variable `row`.
   *
   * Side effect: resets `ctx.INPUT_ROW` to `row` and `ctx.currentVars` to null before the key
   * expressions, bound to `input`, are turned into code.
   */
  private def createJoinKey(
      ctx: CodegenContext,
      row: String,
      keys: Seq[Expression],
      input: Seq[Attribute]): Seq[ExprCode] = {
    ctx.INPUT_ROW = row
    ctx.currentVars = null
    val boundKeys = bindReferences(keys, input)
    boundKeys.map(boundKey => boundKey.genCode(ctx))
  }
  /**
   * Registers every key variable as buffered state in `ctx`, typed after the left key at the
   * same position, and returns the resulting expression codes.
   */
  private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
    vars.zipWithIndex.map { case (keyCode, position) =>
      val keyType = leftKeys(position).dataType
      ctx.addBufferedState(keyType, "value", keyCode.value)
    }
  }
  /**
   * Generates Java code that compares the keys in `a` with the keys in `b` position by
   * position and stores the outcome in the (already declared) variable `comp`. A key is only
   * compared while all previous keys compared equal.
   */
  private def genComparison(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = {
    val perKeyChecks = a.zip(b).zipWithIndex.map { case ((lhs, rhs), position) =>
      val keyComp = ctx.genComp(leftKeys(position).dataType, lhs.value, rhs.value)
      s"""
         |if (comp == 0) {
         |  comp = $keyComp;
         |}
       """.stripMargin.trim
    }
    s"""
       |comp = 0;
       |${perKeyChecks.mkString("\n")}
     """.stripMargin
  }
  /**
   * Generate a function to scan both sides to find a match, returns:
   * 1. the function name
   * 2. the term for matched one row from streamed side
   * 3. the term for buffered rows from buffered side
   *
   * Besides adding the function, this registers mutable state in `ctx` for the current
   * streamed row, the current buffered row, the copied join keys and the `matches` buffer.
   *
   * Throws an `IllegalArgumentException` when `joinType` is not an inner-like, left/right
   * outer, left semi, left anti or existence join.
   */
  private def genScanner(ctx: CodegenContext): (String, String, String) = {
    // Create class member for next row from both sides.
    // Inline mutable state since not many join operations in a task
    val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", forceInline = true)
    val bufferedRow = ctx.addMutableState("InternalRow", "bufferedRow", forceInline = true)
    // Create variables for join keys from both sides.
    val streamedKeyVars = createJoinKey(ctx, streamedRow, streamedKeys, streamedOutput)
    // Java boolean expression that is true when any streamed key is null.
    val streamedAnyNull = streamedKeyVars.map(_.isNull).mkString(" || ")
    val bufferedKeyTmpVars = createJoinKey(ctx, bufferedRow, bufferedKeys, bufferedOutput)
    // Java boolean expression that is true when any buffered key is null.
    val bufferedAnyNull = bufferedKeyTmpVars.map(_.isNull).mkString(" || ")
    // Copy the buffered key as class members so they could be used in next function call.
    val bufferedKeyVars = copyKeys(ctx, bufferedKeyTmpVars)
    // A list to hold all matched rows from buffered side.
    val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
    val spillThreshold = getSpillThreshold
    val inMemoryThreshold = getInMemoryThreshold
    // Inline mutable state since not many join operations in a task
    val matches = ctx.addMutableState(clsName, "matches",
      v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
    // Copy the streamed keys as class members so they could be used in next function call.
    val matchedKeyVars = copyKeys(ctx, streamedKeyVars)
    // Handle the case when streamed rows has any NULL keys.
    val handleStreamedAnyNull = joinType match {
      case _: InnerLike | LeftSemi =>
        // Skip streamed row.
        s"""
           |$streamedRow = null;
           |continue;
         """.stripMargin
      case LeftOuter | RightOuter | LeftAnti | ExistenceJoin(_) =>
        // Eagerly return streamed row. Only call `matches.clear()` when `matches.isEmpty()` is
        // false, to reduce unnecessary computation.
        s"""
           |if (!$matches.isEmpty()) {
           |  $matches.clear();
           |}
           |return false;
         """.stripMargin
      case x =>
        throw new IllegalArgumentException(
          s"SortMergeJoin.genScanner should not take $x as the JoinType")
    }
    // Handle the case when streamed keys has no match with buffered side.
    val handleStreamedWithoutMatch = joinType match {
      case _: InnerLike | LeftSemi =>
        // Skip streamed row.
        s"$streamedRow = null;"
      case LeftOuter | RightOuter | LeftAnti | ExistenceJoin(_) =>
        // Eagerly return with streamed row.
        "return false;"
      case x =>
        throw new IllegalArgumentException(
          s"SortMergeJoin.genScanner should not take $x as the JoinType")
    }
    // Code adding the current buffered row to `matches`. When only the first matched row is
    // needed, the row is added only while `matches` is still empty.
    val addRowToBuffer =
      if (onlyBufferFirstMatchedRow) {
        s"""
           |if ($matches.isEmpty()) {
           |  $matches.add((UnsafeRow) $bufferedRow);
           |}
         """.stripMargin
      } else {
        s"$matches.add((UnsafeRow) $bufferedRow);"
      }
    // Generate a function to scan both streamed and buffered sides to find a match.
    // Return whether a match is found.
    //
    // `streamedIter`: the iterator for streamed side.
    // `bufferedIter`: the iterator for buffered side.
    // `streamedRow`: the current row from streamed side.
    //                When `streamedIter` is empty, `streamedRow` is null.
    // `matches`: the rows from buffered side already matched with `streamedRow`.
    //            `matches` is buffered and reused for all `streamedRow`s having same join keys.
    //            If there is no match with `streamedRow`, `matches` is empty.
    // `bufferedRow`: the current matched row from buffered side.
    //
    // The function has the following step:
    //  - Step 1: Find the next `streamedRow` with non-null join keys.
    //            For `streamedRow` with null join keys (`handleStreamedAnyNull`):
    //            1. Inner and Left Semi join: skip the row. `matches` will be cleared later when
    //                                         hitting the next `streamedRow` with non-null join
    //                                         keys.
    //            2. Left/Right Outer, Left Anti and Existence join: clear the previous `matches`
    //                                                               if needed, keep the row, and
    //                                                               return false.
    //
    //  - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`.
    //            Clear `matches` if we hit a new `streamedRow`, as we need to find new matches.
    //            Use `bufferedRow` to iterate buffered side to put all matched rows into
    //            `matches` (`addRowToBuffer`). Return true when getting all matched rows.
    //            For `streamedRow` without `matches` (`handleStreamedWithoutMatch`):
    //            1. Inner and Left Semi join: skip the row.
    //            2. Left/Right Outer, Left Anti and Existence join: keep the row and return false
    //                                                               (with `matches` being empty).
    //
    // In the generated code, `comp` holds the result of the last key comparison: a positive
    // value drops the current buffered row, a negative value ends the scan for this streamed
    // row, and zero adds the buffered row to `matches`.
    val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows")
    ctx.addNewFunction(findNextJoinRowsFuncName,
      s"""
         |private boolean $findNextJoinRowsFuncName(
         |    scala.collection.Iterator streamedIter,
         |    scala.collection.Iterator bufferedIter) {
         |  $streamedRow = null;
         |  int comp = 0;
         |  while ($streamedRow == null) {
         |    if (!streamedIter.hasNext()) return false;
         |    $streamedRow = (InternalRow) streamedIter.next();
         |    ${streamedKeyVars.map(_.code).mkString("\n")}
         |    if ($streamedAnyNull) {
         |      $handleStreamedAnyNull
         |    }
         |    if (!$matches.isEmpty()) {
         |      ${genComparison(ctx, streamedKeyVars, matchedKeyVars)}
         |      if (comp == 0) {
         |        return true;
         |      }
         |      $matches.clear();
         |    }
         |
         |    do {
         |      if ($bufferedRow == null) {
         |        if (!bufferedIter.hasNext()) {
         |          ${matchedKeyVars.map(_.code).mkString("\n")}
         |          return !$matches.isEmpty();
         |        }
         |        $bufferedRow = (InternalRow) bufferedIter.next();
         |        ${bufferedKeyTmpVars.map(_.code).mkString("\n")}
         |        if ($bufferedAnyNull) {
         |          $bufferedRow = null;
         |          continue;
         |        }
         |        ${bufferedKeyVars.map(_.code).mkString("\n")}
         |      }
         |      ${genComparison(ctx, streamedKeyVars, bufferedKeyVars)}
         |      if (comp > 0) {
         |        $bufferedRow = null;
         |      } else if (comp < 0) {
         |        if (!$matches.isEmpty()) {
         |          ${matchedKeyVars.map(_.code).mkString("\n")}
         |          return true;
         |        } else {
         |          $handleStreamedWithoutMatch
         |        }
         |      } else {
         |        $addRowToBuffer
         |        $bufferedRow = null;
         |      }
         |    } while ($streamedRow != null);
         |  }
         |  return false; // unreachable
         |}
       """.stripMargin, inlineToOuterClass = true)
    (findNextJoinRowsFuncName, streamedRow, matches)
  }
  /**
   * Builds, for every output column of the streamed plan, the code that loads the column from
   * `streamedRow` together with the Java declaration of the variables receiving it.
   *
   * Declarations are returned separately from the loading code so that the access can be
   * deferred until after the join condition and performed only once per loop; this is why the
   * codegen of BoundReference cannot be used here.
   *
   * Side effect: sets `ctx.INPUT_ROW` to `streamedRow`.
   *
   * @return the per-column expression codes and, in the same order, their declarations
   */
  private def createStreamedVars(
      ctx: CodegenContext,
      streamedRow: String): (Seq[ExprCode], Seq[String]) = {
    ctx.INPUT_ROW = streamedRow
    val varsWithDecls = streamedPlan.output.zipWithIndex.map { case (attr, ordinal) =>
      val columnType = attr.dataType
      val value = ctx.freshName("value")
      val valueCode = CodeGenerator.getValue(streamedRow, columnType, ordinal.toString)
      val javaType = CodeGenerator.javaType(columnType)
      val defaultValue = CodeGenerator.defaultValue(columnType)
      if (!attr.nullable) {
        // Non-nullable column: a single value variable is enough.
        val load = code"$value = $valueCode;"
        val decl = s"""$javaType $value = $defaultValue;"""
        (ExprCode(load, FalseLiteral, JavaCode.variable(value, columnType)), decl)
      } else {
        // Nullable column: track the null flag next to the value.
        val isNull = ctx.freshName("isNull")
        val load =
          code"""
             |$isNull = $streamedRow.isNullAt($ordinal);
             |$value = $isNull ? $defaultValue : ($valueCode);
           """.stripMargin
        val decl =
          s"""
             |boolean $isNull = false;
             |$javaType $value = $defaultValue;
           """.stripMargin
        val columnVar =
          ExprCode(load, JavaCode.isNullVariable(isNull), JavaCode.variable(value, columnType))
        (columnVar, decl)
      }
    }
    varsWithDecls.unzip
  }
  /**
   * Partitions `variables` by whether the join condition references the matching attribute,
   * and returns the code evaluating the referenced ones (to run before the condition) and the
   * code evaluating the remaining ones (to run after the condition).
   *
   * When only a few columns are used by the condition, this lets the generated code skip
   * loading the other columns for rows that the condition filters out. Without a condition,
   * all variables are evaluated up front and the second string is empty.
   */
  private def splitVarsByCondition(
      attributes: Seq[Attribute],
      variables: Seq[ExprCode]): (String, String) = condition match {
    case None =>
      (evaluateVariables(variables), "")
    case Some(cond) =>
      val condRefs = cond.references
      val (neededByCond, deferred) = attributes.zip(variables).partition { case (attr, _) =>
        condRefs.contains(attr)
      }
      val beforeCond = evaluateVariables(neededByCond.map(_._2))
      val afterCond = evaluateVariables(deferred.map(_._2))
      (beforeCond, afterCond)
  }
  // Always true for this operator.
  // NOTE(review): presumably tells whole-stage codegen that result rows must be copied before
  // being buffered downstream — confirm against the overridden declaration.
  override def needCopyResult: Boolean = true
  /**
   * Returns the current task's context. Invoked from the generated Java class, hence public.
   */
  def getTaskContext(): TaskContext = TaskContext.get()
  /**
   * Generates the Java code that performs this join inside whole-stage code generation.
   *
   * Full outer joins are delegated to `codegenFullOuter`. For every other join type the method
   * registers the two input iterators, generates the scanner function via `genScanner`,
   * prepares the variables of both sides (and the optional condition check), and then picks
   * the join-type-specific loop template. The returned code is prefixed by a run-once block
   * that registers a task completion listener adding the `matches` buffer's spill size to the
   * `spillSize` metric.
   *
   * Throws an `IllegalArgumentException` when `joinType` is not supported here.
   */
  override def doProduce(ctx: CodegenContext): String = {
    // Specialize `doProduce` code for full outer join, because full outer join needs to
    // buffer both sides of join.
    if (joinType == FullOuter) {
      return codegenFullOuter(ctx)
    }
    // Inline mutable state since not many join operations in a task
    // NOTE(review): `inputs[0]`/`inputs[1]` presumably follow the order returned by
    // `inputRDDs()` (streamed first, buffered second) — confirm.
    val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput",
      v => s"$v = inputs[0];", forceInline = true)
    val bufferedInput = ctx.addMutableState("scala.collection.Iterator", "bufferedInput",
      v => s"$v = inputs[1];", forceInline = true)
    val (findNextJoinRowsFuncName, streamedRow, matches) = genScanner(ctx)
    // Create variables for row from both sides.
    val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow)
    val bufferedRow = ctx.freshName("bufferedRow")
    // Only left/right outer joins ask `genOneSideJoinVars` to set default values.
    val setDefaultValue = joinType == LeftOuter || joinType == RightOuter
    val bufferedVars = genOneSideJoinVars(ctx, bufferedRow, bufferedPlan, setDefaultValue)
    // Create variable name for Existence join.
    val existsVar = joinType match {
      case ExistenceJoin(_) => Some(ctx.freshName("exists"))
      case _ => None
    }
    val iterator = ctx.freshName("iterator")
    val numOutput = metricTerm(ctx, "numOutputRows")
    // Variables forming an output row, in the column order of this operator's output.
    val resultVars = joinType match {
      case _: InnerLike | LeftOuter =>
        streamedVars ++ bufferedVars
      case RightOuter =>
        bufferedVars ++ streamedVars
      case LeftSemi | LeftAnti =>
        streamedVars
      case ExistenceJoin(_) =>
        streamedVars ++ Seq(ExprCode.forNonNullValue(
          JavaCode.variable(existsVar.get, BooleanType)))
      case x =>
        throw new IllegalArgumentException(
          s"SortMergeJoin.doProduce should not take $x as the JoinType")
    }
    // Three code fragments: the streamed-side code to run before the loop over matches, the
    // condition check to run per buffered row, and the code loading the streamed columns
    // that the condition does not use.
    val (streamedBeforeLoop, condCheck, loadStreamed) = if (condition.isDefined) {
      // Split the code of creating variables based on whether it's used by condition or not.
      val loaded = ctx.freshName("loaded")
      val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
      val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
      // Generate code for condition
      ctx.currentVars = streamedVars ++ bufferedVars
      val cond = BindReferences.bindReference(
        condition.get, streamedPlan.output ++ bufferedPlan.output).genCode(ctx)
      // Evaluate the columns those used by condition before loop
      val before = joinType match {
        case LeftAnti =>
          // No need to initialize `loaded` variable for Left Anti join.
          streamedBefore.trim
        case _ =>
          s"""
             |boolean $loaded = false;
             |$streamedBefore
         """.stripMargin
      }
      val loadStreamedAfterCondition = joinType match {
        case LeftAnti =>
          // No need to evaluate columns not used by condition from streamed side, as for Left Anti
          // join, streamed row with match is not outputted.
          ""
        case _ =>
          // The `loaded` flag makes sure the remaining streamed columns are loaded only once.
          s"""
             |if (!$loaded) {
             |  $loaded = true;
             |  $streamedAfter
             |}
         """.stripMargin
      }
      val loadBufferedAfterCondition = joinType match {
        case LeftExistence(_) =>
          // No need to evaluate columns not used by condition from buffered side
          ""
        case _ => bufferedAfter
      }
      // Per buffered row: load what the condition needs, `continue` to the next buffered row
      // when the condition is null or false, then load the remaining columns.
      val checking =
        s"""
           |$bufferedBefore
           |if ($bufferedRow != null) {
           |  ${cond.code}
           |  if (${cond.isNull} || !${cond.value}) {
           |    continue;
           |  }
           |}
           |$loadStreamedAfterCondition
           |$loadBufferedAfterCondition
         """.stripMargin
      (before, checking.trim, streamedAfter.trim)
    } else {
      // Without a condition all streamed variables are evaluated before the loop.
      (evaluateVariables(streamedVars), "", "")
    }
    // Code placed before the loop over matches: variable declarations, the streamed-side
    // loads, and the iterator over the buffered matches.
    val beforeLoop =
      s"""
         |${streamedVarDecl.mkString("\n")}
         |${streamedBeforeLoop.trim}
         |scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
       """.stripMargin
    // Code emitting one output row: bump the metric, then pass the row to the consumer.
    val outputRow =
      s"""
         |$numOutput.add(1);
         |${consume(ctx, resultVars)}
       """.stripMargin
    val findNextJoinRows = s"$findNextJoinRowsFuncName($streamedInput, $bufferedInput)"
    val thisPlan = ctx.addReferenceObj("plan", this)
    val eagerCleanup = s"$thisPlan.cleanupResources();"
    // Pick the loop template matching the join type.
    val doJoin = joinType match {
      case _: InnerLike =>
        codegenInner(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, outputRow,
          eagerCleanup)
      case LeftOuter | RightOuter =>
        codegenOuter(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
          ctx.freshName("hasOutputRow"), outputRow, eagerCleanup)
      case LeftSemi =>
        codegenSemi(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
          ctx.freshName("hasOutputRow"), outputRow, eagerCleanup)
      case LeftAnti =>
        codegenAnti(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
          loadStreamed, ctx.freshName("hasMatchedRow"), outputRow, eagerCleanup)
      case ExistenceJoin(_) =>
        codegenExistence(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow,
          condCheck, loadStreamed, existsVar.get, outputRow, eagerCleanup)
      case x =>
        throw new IllegalArgumentException(
          s"SortMergeJoin.doProduce should not take $x as the JoinType")
    }
    // Flag guarding the run-once registration of the metrics hook below.
    val initJoin = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initJoin")
    // Records the spill size of `matches` in the `spillSize` metric when the task completes.
    val addHookToRecordMetrics =
      s"""
         |$thisPlan.getTaskContext().addTaskCompletionListener(
         |  new org.apache.spark.util.TaskCompletionListener() {
         |    @Override
         |    public void onTaskCompletion(org.apache.spark.TaskContext context) {
         |      ${metricTerm(ctx, "spillSize")}.add($matches.spillSize());
         |    }
         |});
       """.stripMargin
    s"""
       |if (!$initJoin) {
       |  $initJoin = true;
       |  $addHookToRecordMetrics
       |}
       |$doJoin
     """.stripMargin
  }
  /**
   * Generates the code for Inner join.
   *
   * The emitted Java keeps looping while `findNextJoinRows` evaluates to true. For every such
   * iteration it runs `beforeLoop`, then walks `matchIterator`, binding each element to a local
   * `InternalRow` named `bufferedRow`, running `conditionCheck` and finally `outputRow`.
   * `shouldStop()` is tested once per outer iteration, after all matches of the current streamed
   * row were processed, and `eagerCleanup` is emitted after the outer loop.
   *
   * @param findNextJoinRows Java boolean expression used as the outer loop condition.
   * @param beforeLoop Java code emitted at the top of every outer iteration, before the matches
   *                   are iterated.
   * @param matchIterator name of the Java variable on which `hasNext()` / `next()` are called.
   * @param bufferedRow name of the Java local that receives each element of `matchIterator`.
   * @param conditionCheck Java code emitted right after `bufferedRow` is assigned.
   *                       NOTE(review): the snippet built in `doProduce` appears to `continue`
   *                       the inner loop when the join condition fails - confirm.
   * @param outputRow Java code that emits one joined row.
   * @param eagerCleanup Java statement emitted once the outer loop has finished.
   * @return the Java source of the inner join loop.
   */
  private def codegenInner(
      findNextJoinRows: String,
      beforeLoop: String,
      matchIterator: String,
      bufferedRow: String,
      conditionCheck: String,
      outputRow: String,
      eagerCleanup: String): String = {
    s"""
       |while ($findNextJoinRows) {
       |  $beforeLoop
       |  while ($matchIterator.hasNext()) {
       |    InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
       |    $conditionCheck
       |    $outputRow
       |  }
       |  if (shouldStop()) return;
       |}
       |$eagerCleanup
     """.stripMargin
  }
  /**
   * Generates the code for Left or Right Outer join.
   *
   * The emitted Java loops while `streamedInput` still has rows. In each iteration it executes
   * `findNextJoinRows` as a plain statement (its result is not inspected here), runs
   * `beforeLoop` and then walks `matchIterator`. The inner loop condition also holds while the
   * `hasOutputRow` flag is still false, so once the iterator is exhausted without any row having
   * been output the body runs one more time with `bufferedRow` set to null; this is how a
   * streamed row without an accepted match is still emitted.
   *
   * @param streamedInput name of the Java iterator variable of the streamed side.
   * @param findNextJoinRows Java expression emitted as a statement at the top of each iteration.
   * @param beforeLoop Java code emitted before the matches are iterated.
   * @param matchIterator name of the Java variable on which `hasNext()` / `next()` are called.
   * @param bufferedRow name of the Java local holding the current buffered row, or null in the
   *                    extra "no match" iteration.
   * @param conditionCheck Java code emitted after `bufferedRow` is assigned and before the flag
   *                       is set. NOTE(review): the snippet built in `doProduce` appears to skip
   *                       the condition when `bufferedRow` is null - confirm.
   * @param hasOutputRow name of the Java boolean local declared by this snippet; it is set to
   *                     true right before `outputRow`.
   * @param outputRow Java code that emits one joined row.
   * @param eagerCleanup Java statement emitted once the outer loop has finished.
   * @return the Java source of the outer join loop.
   */
  private def codegenOuter(
      streamedInput: String,
      findNextJoinRows: String,
      beforeLoop: String,
      matchIterator: String,
      bufferedRow: String,
      conditionCheck: String,
      hasOutputRow: String,
      outputRow: String,
      eagerCleanup: String): String = {
    s"""
       |while ($streamedInput.hasNext()) {
       |  $findNextJoinRows;
       |  $beforeLoop
       |  boolean $hasOutputRow = false;
       |
       |  // the last iteration of this loop is to emit an empty row if there is no matched rows.
       |  while ($matchIterator.hasNext() || !$hasOutputRow) {
       |    InternalRow $bufferedRow = $matchIterator.hasNext() ?
       |      (InternalRow) $matchIterator.next() : null;
       |    $conditionCheck
       |    $hasOutputRow = true;
       |    $outputRow
       |  }
       |  if (shouldStop()) return;
       |}
       |$eagerCleanup
     """.stripMargin
  }
  /**
   * Generates the code for Left Semi join.
   *
   * The emitted Java keeps looping while `findNextJoinRows` evaluates to true. For each
   * iteration the inner loop walks `matchIterator` only until the `hasOutputRow` flag becomes
   * true, i.e. it stops at the first buffered row for which execution gets past
   * `conditionCheck`, so `outputRow` runs at most once per outer iteration.
   *
   * @param findNextJoinRows Java boolean expression used as the outer loop condition.
   * @param beforeLoop Java code emitted before the matches are iterated.
   * @param matchIterator name of the Java variable on which `hasNext()` / `next()` are called.
   * @param bufferedRow name of the Java local that receives each element of `matchIterator`.
   * @param conditionCheck Java code emitted after `bufferedRow` is assigned and before the flag
   *                       is set.
   * @param hasOutputRow name of the Java boolean local declared by this snippet.
   * @param outputRow Java code that emits the streamed row.
   * @param eagerCleanup Java statement emitted once the outer loop has finished.
   * @return the Java source of the semi join loop.
   */
  private def codegenSemi(
      findNextJoinRows: String,
      beforeLoop: String,
      matchIterator: String,
      bufferedRow: String,
      conditionCheck: String,
      hasOutputRow: String,
      outputRow: String,
      eagerCleanup: String): String = {
    s"""
       |while ($findNextJoinRows) {
       |  $beforeLoop
       |  boolean $hasOutputRow = false;
       |
       |  while (!$hasOutputRow && $matchIterator.hasNext()) {
       |    InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
       |    $conditionCheck
       |    $hasOutputRow = true;
       |    $outputRow
       |  }
       |  if (shouldStop()) return;
       |}
       |$eagerCleanup
     """.stripMargin
  }
  /**
   * Generates the code for Left Anti join.
   *
   * The emitted Java loops while `streamedInput` still has rows. In each iteration it executes
   * `findNextJoinRows` as a statement, runs `beforeLoop`, and then walks `matchIterator` until
   * the `hasMatchedRow` flag becomes true. Only when the flag is still false afterwards does it
   * emit `loadStreamed` followed by `outputRow`, so a streamed row is output exactly when no
   * buffered row got past `conditionCheck`.
   *
   * @param streamedInput name of the Java iterator variable of the streamed side.
   * @param findNextJoinRows Java expression emitted as a statement at the top of each iteration.
   * @param beforeLoop Java code emitted before the matches are iterated.
   * @param matchIterator name of the Java variable on which `hasNext()` / `next()` are called.
   * @param bufferedRow name of the Java local that receives each element of `matchIterator`.
   * @param conditionCheck Java code emitted after `bufferedRow` is assigned and before the flag
   *                       is set.
   * @param loadStreamed Java code emitted just before `outputRow`; per the comment in the
   *                     template it loads the streamed values that the join condition did not
   *                     need.
   * @param hasMatchedRow name of the Java boolean local declared by this snippet.
   * @param outputRow Java code that emits the streamed row.
   * @param eagerCleanup Java statement emitted once the outer loop has finished.
   * @return the Java source of the anti join loop.
   */
  private def codegenAnti(
      streamedInput: String,
      findNextJoinRows: String,
      beforeLoop: String,
      matchIterator: String,
      bufferedRow: String,
      conditionCheck: String,
      loadStreamed: String,
      hasMatchedRow: String,
      outputRow: String,
      eagerCleanup: String): String = {
    s"""
       |while ($streamedInput.hasNext()) {
       |  $findNextJoinRows;
       |  $beforeLoop
       |  boolean $hasMatchedRow = false;
       |
       |  while (!$hasMatchedRow && $matchIterator.hasNext()) {
       |    InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
       |    $conditionCheck
       |    $hasMatchedRow = true;
       |  }
       |
       |  if (!$hasMatchedRow) {
       |    // load all values of streamed row, because the values not in join condition are not
       |    // loaded yet.
       |    $loadStreamed
       |    $outputRow
       |  }
       |  if (shouldStop()) return;
       |}
       |$eagerCleanup
     """.stripMargin
  }
  /**
   * Generates the code for Existence join.
   *
   * The emitted Java loops while `streamedInput` still has rows. In each iteration it executes
   * `findNextJoinRows` as a statement, runs `beforeLoop`, and then walks `matchIterator` until
   * the `exists` flag becomes true. Unlike the anti join, `outputRow` is emitted for every
   * streamed row regardless of the flag; `loadStreamed` is emitted first only when the flag is
   * still false.
   *
   * @param streamedInput name of the Java iterator variable of the streamed side.
   * @param findNextJoinRows Java expression emitted as a statement at the top of each iteration.
   * @param beforeLoop Java code emitted before the matches are iterated.
   * @param matchIterator name of the Java variable on which `hasNext()` / `next()` are called.
   * @param bufferedRow name of the Java local that receives each element of `matchIterator`.
   * @param conditionCheck Java code emitted after `bufferedRow` is assigned and before the flag
   *                       is set.
   * @param loadStreamed Java code emitted when no match was found; per the comment in the
   *                     template it loads the streamed values that the join condition did not
   *                     need.
   * @param exists name of the Java boolean local declared by this snippet; it is true after the
   *               inner loop iff execution got past `conditionCheck` for some buffered row.
   * @param outputRow Java code that emits the result row.
   * @param eagerCleanup Java statement emitted once the outer loop has finished.
   * @return the Java source of the existence join loop.
   */
  private def codegenExistence(
      streamedInput: String,
      findNextJoinRows: String,
      beforeLoop: String,
      matchIterator: String,
      bufferedRow: String,
      conditionCheck: String,
      loadStreamed: String,
      exists: String,
      outputRow: String,
      eagerCleanup: String): String = {
    s"""
       |while ($streamedInput.hasNext()) {
       |  $findNextJoinRows;
       |  $beforeLoop
       |  boolean $exists = false;
       |
       |  while (!$exists && $matchIterator.hasNext()) {
       |    InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
       |    $conditionCheck
       |    $exists = true;
       |  }
       |
       |  if (!$exists) {
       |    // load all values of streamed row, because the values not in join condition are not
       |    // loaded yet.
       |    $loadStreamed
       |  }
       |  $outputRow
       |
       |  if (shouldStop()) return;
       |}
       |$eagerCleanup
     """.stripMargin
  }
  /**
   * Generates the code for Full Outer join.
   *
   * Besides returning the Java source of the join loops, this method registers on `ctx`:
   *  - inlined mutable state for the two input iterators (`inputs[0]` and `inputs[1]`), the
   *    pending input row of each side, the current output row of each side, one
   *    `java.util.ArrayList` buffer per side and one `BitSet` per side;
   *  - a function that increments the "numOutputRows" metric and consumes the joined row;
   *  - a function that finds the next group of rows with equal join keys on both sides.
   *
   * The returned code consists of three consecutive loops: the first runs while both sides
   * still have a row available and matches the buffered groups against each other, the second
   * and third output the remaining left, respectively right, rows joined with null.
   *
   * @param ctx the codegen context on which state and helper functions are registered.
   * @return the Java source of the full outer join loops.
   */
  private def codegenFullOuter(ctx: CodegenContext): String = {
    // Inline mutable state since not many join operations in a task.
    // Create class member for input iterator from both sides.
    val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
      v => s"$v = inputs[0];", forceInline = true)
    val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",
      v => s"$v = inputs[1];", forceInline = true)
    // Create class member for next input row from both sides.
    // A non-null value means the row was read from its iterator but not output or buffered yet.
    val leftInputRow = ctx.addMutableState("InternalRow", "leftInputRow", forceInline = true)
    val rightInputRow = ctx.addMutableState("InternalRow", "rightInputRow", forceInline = true)
    // Create variables for join keys from both sides.
    val leftKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
    // Java boolean expression that is true when any left join key is null.
    val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
    val rightKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output)
    // Java boolean expression that is true when any right join key is null.
    val rightAnyNull = rightKeyVars.map(_.isNull).mkString(" || ")
    // Keys of the current matched group: their code is emitted once both sides compare equal,
    // and the rows read afterwards from either side are compared against them.
    val matchedKeyVars = copyKeys(ctx, leftKeyVars)
    // Separate key variables for the rows read inside the two buffering loops.
    val leftMatchedKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
    val rightMatchedKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output)
    // Create class member for next output row from both sides.
    val leftOutputRow = ctx.addMutableState("InternalRow", "leftOutputRow", forceInline = true)
    val rightOutputRow = ctx.addMutableState("InternalRow", "rightOutputRow", forceInline = true)
    // Create class member for buffers of rows with same join keys from both sides.
    val bufferClsName = "java.util.ArrayList<InternalRow>"
    val leftBuffer = ctx.addMutableState(bufferClsName, "leftBuffer",
      v => s"$v = new $bufferClsName();", forceInline = true)
    val rightBuffer = ctx.addMutableState(bufferClsName, "rightBuffer",
      v => s"$v = new $bufferClsName();", forceInline = true)
    // One bit set per side, indexed like the buffer above; a set bit marks a buffered row that
    // was output as part of a matched pair. Both start with a constructor argument of 1 and are
    // re-created by the generated code when a buffer outgrows the bit set's capacity.
    val matchedClsName = classOf[BitSet].getName
    val leftMatched = ctx.addMutableState(matchedClsName, "leftMatched",
      v => s"$v = new $matchedClsName(1);", forceInline = true)
    val rightMatched = ctx.addMutableState(matchedClsName, "rightMatched",
      v => s"$v = new $matchedClsName(1);", forceInline = true)
    // Names of the Java loop counters declared locally in `matchRowsInBuffer`.
    val leftIndex = ctx.freshName("leftIndex")
    val rightIndex = ctx.freshName("rightIndex")
    // Generate code for join condition
    val leftResultVars = genOneSideJoinVars(
      ctx, leftOutputRow, left, setDefaultValue = true)
    val rightResultVars = genOneSideJoinVars(
      ctx, rightOutputRow, right, setDefaultValue = true)
    val resultVars = leftResultVars ++ rightResultVars
    // Only the condition-check part of the returned triple is used here.
    // NOTE(review): `conditionCheck` is followed directly by a `{ ... }` block in
    // `matchRowsInBuffer`, so it is expected to end in a statement head such as `if (...)` -
    // confirm against getJoinCondition.
    val (_, conditionCheck, _) =
      getJoinCondition(ctx, leftResultVars, left, right, Some(rightOutputRow))
    // Generate code for result output in separate function, as we need to output result from
    // multiple places in join code.
    val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
    ctx.addNewFunction(consumeFullOuterJoinRow,
      s"""
         |private void $consumeFullOuterJoinRow() throws java.io.IOException {
         |  ${metricTerm(ctx, "numOutputRows")}.add(1);
         |  ${consume(ctx, resultVars)}
         |}
       """.stripMargin)
    // Handle the case when input row has no match.
    // Each snippet moves the pending input row into the output slot of its side, nulls the
    // output slot of the other side, marks the input row as consumed and outputs the result.
    val outputLeftNoMatch =
      s"""
         |$leftOutputRow = $leftInputRow;
         |$rightOutputRow = null;
         |$leftInputRow = null;
         |$consumeFullOuterJoinRow();
       """.stripMargin
    val outputRightNoMatch =
      s"""
         |$rightOutputRow = $rightInputRow;
         |$leftOutputRow = null;
         |$rightInputRow = null;
         |$consumeFullOuterJoinRow();
       """.stripMargin
    // Generate a function to scan both sides to find rows with matched join keys.
    // The matched rows from both sides are copied in buffers separately. This function assumes
    // either non-empty `leftIter` and `rightIter`, or non-null `leftInputRow` and `rightInputRow`.
    //
    // The function has the following steps:
    //  - Step 1: Find the next `leftInputRow` and `rightInputRow` with non-null join keys.
    //            Output row with null join keys (`outputLeftNoMatch` and `outputRightNoMatch`).
    //
    //  - Step 2: Compare and find next same join keys from between `leftInputRow` and
    //            `rightInputRow`.
    //            Output row with smaller join keys (`outputLeftNoMatch` and `outputRightNoMatch`).
    //
    //  - Step 3: Buffer rows with same join keys from both sides into `leftBuffer` and
    //            `rightBuffer`. Reset bit sets for both buffers accordingly (`leftMatched` and
    //            `rightMatched`).
    //
    // Steps 1 and 2 return from the generated function after outputting a single unmatched row,
    // leaving both buffers empty; the caller's loop then invokes the function again.
    val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows")
    ctx.addNewFunction(findNextJoinRowsFuncName,
      s"""
         |private void $findNextJoinRowsFuncName(
         |    scala.collection.Iterator leftIter,
         |    scala.collection.Iterator rightIter) throws java.io.IOException {
         |  int comp = 0;
         |  $leftBuffer.clear();
         |  $rightBuffer.clear();
         |
         |  if ($leftInputRow == null) {
         |    $leftInputRow = (InternalRow) leftIter.next();
         |  }
         |  if ($rightInputRow == null) {
         |    $rightInputRow = (InternalRow) rightIter.next();
         |  }
         |
         |  ${leftKeyVars.map(_.code).mkString("\n")}
         |  if ($leftAnyNull) {
         |    // The left row join key is null, join it with null row
         |    $outputLeftNoMatch
         |    return;
         |  }
         |
         |  ${rightKeyVars.map(_.code).mkString("\n")}
         |  if ($rightAnyNull) {
         |    // The right row join key is null, join it with null row
         |    $outputRightNoMatch
         |    return;
         |  }
         |
         |  ${genComparison(ctx, leftKeyVars, rightKeyVars)}
         |  if (comp < 0) {
         |    // The left row join key is smaller, join it with null row
         |    $outputLeftNoMatch
         |    return;
         |  } else if (comp > 0) {
         |    // The right row join key is smaller, join it with null row
         |    $outputRightNoMatch
         |    return;
         |  }
         |
         |  ${matchedKeyVars.map(_.code).mkString("\n")}
         |  $leftBuffer.add($leftInputRow.copy());
         |  $rightBuffer.add($rightInputRow.copy());
         |  $leftInputRow = null;
         |  $rightInputRow = null;
         |
         |  // Buffer rows from both sides with same join key
         |  while (leftIter.hasNext()) {
         |    $leftInputRow = (InternalRow) leftIter.next();
         |    ${leftMatchedKeyVars.map(_.code).mkString("\n")}
         |    ${genComparison(ctx, leftMatchedKeyVars, matchedKeyVars)}
         |    if (comp == 0) {
         |
         |      $leftBuffer.add($leftInputRow.copy());
         |      $leftInputRow = null;
         |    } else {
         |      break;
         |    }
         |  }
         |  while (rightIter.hasNext()) {
         |    $rightInputRow = (InternalRow) rightIter.next();
         |    ${rightMatchedKeyVars.map(_.code).mkString("\n")}
         |    ${genComparison(ctx, rightMatchedKeyVars, matchedKeyVars)}
         |    if (comp == 0) {
         |      $rightBuffer.add($rightInputRow.copy());
         |      $rightInputRow = null;
         |    } else {
         |      break;
         |    }
         |  }
         |
         |  // Reset bit sets of buffers accordingly
         |  if ($leftBuffer.size() <= $leftMatched.capacity()) {
         |    $leftMatched.clearUntil($leftBuffer.size());
         |  } else {
         |    $leftMatched = new $matchedClsName($leftBuffer.size());
         |  }
         |  if ($rightBuffer.size() <= $rightMatched.capacity()) {
         |    $rightMatched.clearUntil($rightBuffer.size());
         |  } else {
         |    $rightMatched = new $matchedClsName($rightBuffer.size());
         |  }
         |}
       """.stripMargin)
    // Scan the left and right buffers to find all matched rows.
    // Every (left, right) pair of the two buffers is tried; pairs accepted by `conditionCheck`
    // are output and flagged in both bit sets. A left row that was never flagged is output with
    // a null right row right after its inner loop, and unflagged right rows are output with a
    // null left row in a final pass over the right buffer.
    val matchRowsInBuffer =
      s"""
         |int $leftIndex;
         |int $rightIndex;
         |
         |for ($leftIndex = 0; $leftIndex < $leftBuffer.size(); $leftIndex++) {
         |  $leftOutputRow = (InternalRow) $leftBuffer.get($leftIndex);
         |  for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) {
         |    $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex);
         |    $conditionCheck {
         |      $consumeFullOuterJoinRow();
         |      $leftMatched.set($leftIndex);
         |      $rightMatched.set($rightIndex);
         |    }
         |  }
         |
         |  if (!$leftMatched.get($leftIndex)) {
         |
         |    $rightOutputRow = null;
         |    $consumeFullOuterJoinRow();
         |  }
         |}
         |
         |$leftOutputRow = null;
         |for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) {
         |  if (!$rightMatched.get($rightIndex)) {
         |    // The right row has never matched any left row, join it with null row
         |    $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex);
         |    $consumeFullOuterJoinRow();
         |  }
         |}
       """.stripMargin
    // Main body: merge while both sides can still provide a row, then drain whichever side is
    // left over. `shouldStop()` is checked after every group or drained row.
    s"""
       |while (($leftInputRow != null || $leftInput.hasNext()) &&
       |  ($rightInputRow != null || $rightInput.hasNext())) {
       |  $findNextJoinRowsFuncName($leftInput, $rightInput);
       |  $matchRowsInBuffer
       |  if (shouldStop()) return;
       |}
       |
       |// The right iterator has no more rows, join left row with null
       |while ($leftInputRow != null || $leftInput.hasNext()) {
       |  if ($leftInputRow == null) {
       |    $leftInputRow = (InternalRow) $leftInput.next();
       |  }
       |  $outputLeftNoMatch
       |  if (shouldStop()) return;
       |}
       |
       |// The left iterator has no more rows, join right row with null
       |while ($rightInputRow != null || $rightInput.hasNext()) {
       |  if ($rightInputRow == null) {
       |    $rightInputRow = (InternalRow) $rightInput.next();
       |  }
       |  $outputRightNoMatch
       |  if (shouldStop()) return;
       |}
     """.stripMargin
  }
  /**
   * Returns a copy of this join whose children are replaced by `newLeft` and `newRight`; all
   * other constructor arguments are carried over unchanged.
   */
  override protected def withNewChildrenInternal(
      newLeft: SparkPlan,
      newRight: SparkPlan): SortMergeJoinExec = {
    this.copy(left = newLeft, right = newRight)
  }
}
/**
 * Helper class that is used to implement [[SortMergeJoinExec]].
 *
 * Callers drive the scanner by repeatedly invoking [[findNextInnerJoinRows()]] (inner-like
 * joins) or [[findNextOuterJoinRows()]] (outer-like joins). Whenever such a call returns `true`,
 * [[getStreamedRow]] yields the current row of the streamed input and [[getBufferedMatches]]
 * yields the rows of the buffered input that share its join key (for an outer join this buffer
 * is empty when nothing matched). Both accessors hand out mutable objects that are re-used by
 * subsequent `findNext*JoinRows()` calls.
 *
 * @param streamedKeyGenerator a projection that produces join keys from the streamed input.
 * @param bufferedKeyGenerator a projection that produces join keys from the buffered input.
 * @param keyOrdering an ordering which can be used to compare join keys.
 * @param streamedIter an input whose rows will be streamed.
 * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that
 *                     have the same join key.
 * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by
 *                          internal buffer
 * @param spillThreshold Threshold for number of rows to be spilled by internal buffer
 * @param spillSize metric that is increased by the internal buffer's spill size when the task
 *                  completes
 * @param eagerCleanupResources the eager cleanup function to be invoked when no join row found
 * @param onlyBufferFirstMatch [[bufferMatchingRows]] should buffer only the first matching row
 */
private[joins] class SortMergeJoinScanner(
    streamedKeyGenerator: Projection,
    bufferedKeyGenerator: Projection,
    keyOrdering: Ordering[InternalRow],
    streamedIter: RowIterator,
    bufferedIter: RowIterator,
    inMemoryThreshold: Int,
    spillThreshold: Int,
    spillSize: SQLMetric,
    eagerCleanupResources: () => Unit,
    onlyBufferFirstMatch: Boolean = false) {

  // Current row of the streamed side and its join key; both are null once that side is exhausted.
  private[this] var streamedRow: InternalRow = _
  private[this] var streamedRowKey: InternalRow = _

  // Current row of the buffered side and its join key; both are null once that side is exhausted.
  private[this] var bufferedRow: InternalRow = _
  // Note: whenever this is non-null it contains no null columns.
  private[this] var bufferedRowKey: InternalRow = _

  /** Join key shared by the rows held in `bufferedMatches`; null when there is no such group. */
  private[this] var matchJoinKey: InternalRow = _

  /** Rows of the buffered side that match the current group; empty if there are no matches. */
  private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray =
    new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)

  // Report how much the match buffer spilled once the task is done.
  TaskContext.get().addTaskCompletionListener[Unit](_ => spillSize += bufferedMatches.spillSize)

  // Position the buffered side on its first usable row. The streamed side is deliberately left
  // untouched here; it is advanced by the first `findNext*JoinRows()` call.
  advancedBufferedToRowWithNullFreeJoinKey()

  // --- Public methods ---------------------------------------------------------------------------

  def getStreamedRow: InternalRow = streamedRow

  def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches

  /**
   * Moves both inputs forward until a streamed row and at least one buffered row agree on the
   * join key. The eager cleanup callback is invoked when no such rows exist any more.
   * @return true if matching rows have been found and false otherwise. If this returns true, then
   *         [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join
   *         results.
   */
  final def findNextInnerJoinRows(): Boolean = {
    // Rows whose join key contains a null are skipped right away; stop at the first null-free
    // key or at the end of the streamed input.
    var streamedAvailable = advancedStreamed()
    while (streamedAvailable && streamedRowKey.anyNull) {
      streamedAvailable = advancedStreamed()
    }

    val hasMatches = if (streamedRow == null) {
      // The streamed input is exhausted: nothing can match any more.
      forgetCurrentGroup()
      false
    } else if (streamedRowBelongsToCurrentGroup) {
      // Same join key as the previous streamed row: the buffered matches are still valid.
      true
    } else if (bufferedRow == null) {
      // New join key on the streamed side but the buffered input is exhausted.
      forgetCurrentGroup()
      false
    } else {
      // Step whichever side is behind until the keys meet or one input runs dry.
      var cmp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
      var searching = true
      while (searching) {
        if (streamedRowKey.anyNull) {
          advancedStreamed()
        } else {
          assert(!bufferedRowKey.anyNull)
          cmp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
          if (cmp > 0) {
            advancedBufferedToRowWithNullFreeJoinKey()
          } else if (cmp < 0) {
            advancedStreamed()
          }
        }
        searching = streamedRow != null && bufferedRow != null && cmp != 0
      }
      if (streamedRow != null && bufferedRow != null) {
        // Keys are equal: collect every buffered row that carries this key.
        assert(cmp == 0)
        bufferMatchingRows()
        true
      } else {
        // One of the inputs ended before the keys met.
        forgetCurrentGroup()
        false
      }
    }
    if (!hasMatches) {
      eagerCleanupResources()
    }
    hasMatches
  }

  /**
   * Moves the streamed input forward by one row and collects all rows of the buffered input
   * that carry the same join key. The eager cleanup callback is invoked when the streamed input
   * is exhausted.
   * @return true if the streamed iterator returned a row, false otherwise. If this returns true,
   *         then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer
   *         join results.
   */
  final def findNextOuterJoinRows(): Boolean = {
    val streamedAvailable = advancedStreamed()
    if (!streamedAvailable) {
      // The streamed input is exhausted: nothing can match any more.
      forgetCurrentGroup()
      eagerCleanupResources()
    } else if (!streamedRowBelongsToCurrentGroup) {
      // A different join key than the buffered group: start over.
      forgetCurrentGroup()
      if (bufferedRow != null && !streamedRowKey.anyNull) {
        // Walk the buffered side until it reaches or passes the streamed key.
        var cmp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
        while (cmp > 0 && advancedBufferedToRowWithNullFreeJoinKey()) {
          cmp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
        }
        // Equal keys: buffer the group (this also updates matchJoinKey). Otherwise the buffered
        // side is already past the position where a match would have been.
        if (cmp == 0) {
          bufferMatchingRows()
        }
      }
    }
    // A streamed row always yields a result for an outer join, matched or not.
    streamedAvailable
  }

  // --- Private methods --------------------------------------------------------------------------

  /** Drops the current match group: resets its key and empties the match buffer. */
  private def forgetCurrentGroup(): Unit = {
    matchJoinKey = null
    bufferedMatches.clear()
  }

  /** Whether a match group exists and the current streamed row has the same join key. */
  private def streamedRowBelongsToCurrentGroup: Boolean = {
    matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0
  }

  /**
   * Advance the streamed iterator and compute the new row's join key.
   * @return true if the streamed iterator returned a row and false otherwise.
   */
  private def advancedStreamed(): Boolean = {
    val advanced = streamedIter.advanceNext()
    if (advanced) {
      streamedRow = streamedIter.getRow
      streamedRowKey = streamedKeyGenerator(streamedRow)
    } else {
      streamedRow = null
      streamedRowKey = null
    }
    advanced
  }

  /**
   * Advance the buffered iterator until we find a row with join key that does not contain nulls.
   * @return true if the buffered iterator returned a row and false otherwise.
   */
  private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
    var usableRowFound = false
    while (!usableRowFound && bufferedIter.advanceNext()) {
      bufferedRow = bufferedIter.getRow
      bufferedRowKey = bufferedKeyGenerator(bufferedRow)
      usableRowFound = !bufferedRowKey.anyNull
    }
    if (!usableRowFound) {
      bufferedRow = null
      bufferedRowKey = null
    }
    usableRowFound
  }

  /**
   * Called when the streamed and buffered join keys match in order to buffer the matching rows.
   */
  private def bufferMatchingRows(): Unit = {
    assert(streamedRowKey != null)
    assert(!streamedRowKey.anyNull)
    assert(bufferedRowKey != null)
    assert(!bufferedRowKey.anyNull)
    assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
    // The key may come from a mutable projection, so keep a private copy of it.
    matchJoinKey = streamedRowKey.copy()
    bufferedMatches.clear()
    // The current buffered row is known to match, so the body runs at least once.
    var sameKey = true
    while (sameKey) {
      val alreadyHoldsFirstMatch = onlyBufferFirstMatch && !bufferedMatches.isEmpty
      if (!alreadyHoldsFirstMatch) {
        bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow])
      }
      advancedBufferedToRowWithNullFreeJoinKey()
      sameKey = bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0
    }
  }
}
/**
 * Row iterator for a left outer join: rows of the streamed side are placed in the left half of
 * the joined row, rows of the buffered side (or the null row) in the right half.
 */
private class LeftOuterIterator(
    smjScanner: SortMergeJoinScanner,
    rightNullRow: InternalRow,
    boundCondition: InternalRow => Boolean,
    resultProj: InternalRow => InternalRow,
    numOutputRows: SQLMetric)
  extends OneSideOuterIterator(
    smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {

  /** The streamed side is the left side of the output. */
  protected override def setStreamSideOutput(row: InternalRow): Unit = {
    joinedRow.withLeft(row)
  }

  /** The buffered side is the right side of the output. */
  protected override def setBufferedSideOutput(row: InternalRow): Unit = {
    joinedRow.withRight(row)
  }
}
/**
 * Row iterator for a right outer join: rows of the streamed side are placed in the right half
 * of the joined row, rows of the buffered side (or the null row) in the left half.
 */
private class RightOuterIterator(
    smjScanner: SortMergeJoinScanner,
    leftNullRow: InternalRow,
    boundCondition: InternalRow => Boolean,
    resultProj: InternalRow => InternalRow,
    numOutputRows: SQLMetric)
  extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {

  /** The streamed side is the right side of the output. */
  protected override def setStreamSideOutput(row: InternalRow): Unit = {
    joinedRow.withRight(row)
  }

  /** The buffered side is the left side of the output. */
  protected override def setBufferedSideOutput(row: InternalRow): Unit = {
    joinedRow.withLeft(row)
  }
}
/**
 * Common base of [[LeftOuterIterator]] and [[RightOuterIterator]].
 *
 * A [[OneSideOuterIterator]] has a streamed side and a buffered side. Every streamed row
 * produces one or more output rows: one per buffered row that has the same join key and
 * satisfies the bound condition, or a single row whose buffered half is the null row when no
 * buffered row qualifies.
 *
 * In left outer join, the left is the streamed side and the right is the buffered side.
 * In right outer join, the right is the streamed side and the left is the buffered side.
 *
 * @param smjScanner a scanner that streams rows and buffers any matching rows
 * @param bufferedSideNullRow the default row to return when a streamed row has no matches
 * @param boundCondition an additional filter condition for buffered rows
 * @param resultProj how the output should be projected
 * @param numOutputRows an accumulator metric for the number of rows output
 */
private abstract class OneSideOuterIterator(
    smjScanner: SortMergeJoinScanner,
    bufferedSideNullRow: InternalRow,
    boundCondition: InternalRow => Boolean,
    resultProj: InternalRow => InternalRow,
    numOutputRows: SQLMetric) extends RowIterator {

  // Holder of the joined result; the same instance is re-used for every output row.
  protected[this] val joinedRow: JoinedRow = new JoinedRow()

  // Iterator over the scanner's buffered matches for the current streamed row. It is null until
  // first needed and is reset to null whenever the stream advances.
  private[this] var bufferedMatchesIter: Iterator[UnsafeRow] = null

  // The scanner has not been advanced yet, so it must not hold any matches.
  assert(smjScanner.getBufferedMatches.length == 0)

  /** Stores `row` in the half of `joinedRow` that belongs to the streamed side. */
  protected def setStreamSideOutput(row: InternalRow): Unit

  /** Stores `row` in the half of `joinedRow` that belongs to the buffered side. */
  protected def setBufferedSideOutput(row: InternalRow): Unit

  /**
   * Advance to the next row on the stream side and populate the buffer with matches.
   * @return whether there are more rows in the stream to consume.
   */
  private def advanceStream(): Boolean = {
    bufferedMatchesIter = null
    val hasStreamedRow = smjScanner.findNextOuterJoinRows()
    if (hasStreamedRow) {
      setStreamSideOutput(smjScanner.getStreamedRow)
      // Look for the first buffered row accepted by the bound condition, if there are any
      // buffered rows at all; otherwise pad the buffered side with the null row.
      val matched =
        !smjScanner.getBufferedMatches.isEmpty && advanceBufferUntilBoundConditionSatisfied()
      if (!matched) {
        setBufferedSideOutput(bufferedSideNullRow)
      }
    }
    hasStreamedRow
  }

  /**
   * Advance to the next row in the buffer that satisfies the bound condition.
   * @return whether there is such a row in the current buffer.
   */
  private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
    if (bufferedMatchesIter == null) {
      bufferedMatchesIter = smjScanner.getBufferedMatches.generateIterator()
    }
    var accepted = false
    while (!accepted && bufferedMatchesIter.hasNext) {
      setBufferedSideOutput(bufferedMatchesIter.next())
      accepted = boundCondition(joinedRow)
    }
    accepted
  }

  /**
   * Produces the next output row: first from the remaining buffered matches of the current
   * streamed row, otherwise by moving on to the next streamed row. The output metric is
   * increased for every row produced.
   */
  override def advanceNext(): Boolean = {
    val hasRow = if (advanceBufferUntilBoundConditionSatisfied()) true else advanceStream()
    if (hasRow) {
      numOutputRows += 1
    }
    hasRow
  }

  override def getRow: InternalRow = resultProj(joinedRow)
}
/**
 * Scanner that produces the rows of a full outer join from two row iterators.
 *
 * Each successful call to `advanceNext()` points the shared [[JoinedRow]] (see `getJoinedRow()`)
 * at the next output pair. Rows whose keys compare equal under `keyOrdering` are buffered on
 * both sides and paired up subject to `boundCondition`; a row that ends up with no partner is
 * emitted joined with the opposite side's null row.
 *
 * NOTE(review): the merge logic presumably relies on both iterators being sorted by join key;
 * that is not established in this class and should be confirmed against the caller.
 */
private class SortMergeFullOuterJoinScanner(
    leftKeyGenerator: Projection,
    rightKeyGenerator: Projection,
    keyOrdering: Ordering[InternalRow],
    leftIter: RowIterator,
    rightIter: RowIterator,
    boundCondition: InternalRow => Boolean,
    leftNullRow: InternalRow,
    rightNullRow: InternalRow) {

  // Output holder handed out by getJoinedRow(); re-pointed on every successful advance.
  private[this] val joinedRow: JoinedRow = new JoinedRow()

  // Current row and join key of each side; both are null once that side is exhausted.
  private[this] var leftRow: InternalRow = _
  private[this] var leftRowKey: InternalRow = _
  private[this] var rightRow: InternalRow = _
  private[this] var rightRowKey: InternalRow = _

  // Copies of the rows sharing the key currently being matched, and the scan cursors into them.
  private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
  private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
  private[this] var leftIndex: Int = 0
  private[this] var rightIndex: Int = 0

  // Bit i is set once buffered row i has been paired with a row of the other side.
  private[this] var leftMatched: BitSet = new BitSet(1)
  private[this] var rightMatched: BitSet = new BitSet(1)

  // Load the first row of each side so advanceNext() always has current rows to inspect.
  advancedLeft()
  advancedRight()

  // --- Private methods --------------------------------------------------------------------------

  /**
   * Move the left iterator forward and refresh `leftRow` / `leftRowKey`.
   * @return true if the left iterator produced a row, false if it is exhausted (in which case
   *         both fields are set to null).
   */
  private def advancedLeft(): Boolean = {
    val hasNext = leftIter.advanceNext()
    if (hasNext) {
      val next = leftIter.getRow
      leftRow = next
      leftRowKey = leftKeyGenerator(next)
    } else {
      leftRow = null
      leftRowKey = null
    }
    hasNext
  }

  /**
   * Move the right iterator forward and refresh `rightRow` / `rightRowKey`.
   * @return true if the right iterator produced a row, false if it is exhausted (in which case
   *         both fields are set to null).
   */
  private def advancedRight(): Boolean = {
    val hasNext = rightIter.advanceNext()
    if (hasNext) {
      val next = rightIter.getRow
      rightRow = next
      rightRowKey = rightKeyGenerator(next)
    } else {
      rightRow = null
      rightRowKey = null
    }
    hasNext
  }

  /**
   * Return a bit set with the first `numRows` bits cleared: the given one when its capacity
   * suffices, otherwise a newly allocated one.
   */
  private def resetMatched(bits: BitSet, numRows: Int): BitSet = {
    if (numRows > bits.capacity) {
      new BitSet(numRows)
    } else {
      bits.clearUntil(numRows)
      bits
    }
  }

  /**
   * Refill both buffers with copies of the rows whose key compares equal to `matchingKey`,
   * consuming them from the iterators, and reset the scan cursors and matched bits.
   */
  private def findMatchingRows(matchingKey: InternalRow): Unit = {
    leftIndex = 0
    rightIndex = 0
    leftMatches.clear()
    rightMatches.clear()

    while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) {
      leftMatches += leftRow.copy()
      advancedLeft()
    }
    while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) {
      rightMatches += rightRow.copy()
      advancedRight()
    }

    leftMatched = resetMatched(leftMatched, leftMatches.size)
    rightMatched = resetMatched(rightMatched, rightMatches.size)
  }

  /**
   * Resume scanning the buffers for the next output pair and point `joinedRow` at it.
   *
   * The scan first walks every (left, right) combination and yields those that satisfy
   * `boundCondition`. When a left row has gone through the whole right buffer without ever
   * matching, it is yielded with `rightNullRow`. After all left rows are done, every right row
   * that never matched is yielded with `leftNullRow`.
   *
   * @return true if `joinedRow` now holds a new output pair, false if the buffers are exhausted.
   */
  private def scanNextInBuffered(): Boolean = {
    // The buffers are only modified by findMatchingRows, so their sizes are stable here.
    val numLeft = leftMatches.size
    val numRight = rightMatches.size

    while (leftIndex < numLeft) {
      val leftCandidate = leftMatches(leftIndex)
      while (rightIndex < numRight) {
        val rightPos = rightIndex
        joinedRow(leftCandidate, rightMatches(rightPos))
        val satisfied = boundCondition(joinedRow)
        // Step past this right row in either case so the next call resumes after it.
        rightIndex = rightPos + 1
        if (satisfied) {
          leftMatched.set(leftIndex)
          rightMatched.set(rightPos)
          return true
        }
      }
      // Right buffer exhausted for this left row: rewind it and move on to the next left row.
      rightIndex = 0
      val leftPos = leftIndex
      leftIndex = leftPos + 1
      if (!leftMatched.get(leftPos)) {
        // This left row never matched any right row, so pad it with the null row.
        joinedRow(leftCandidate, rightNullRow)
        return true
      }
    }

    while (rightIndex < numRight) {
      val rightPos = rightIndex
      rightIndex = rightPos + 1
      if (!rightMatched.get(rightPos)) {
        // This right row never matched any left row, so pad it with the null row.
        joinedRow(leftNullRow, rightMatches(rightPos))
        return true
      }
    }

    // Nothing left to emit from the current buffers
    false
  }

  /** Emit a copy of the current left row padded with the right null row, then advance left. */
  private def emitLeftOnly(): Unit = {
    joinedRow(leftRow.copy(), rightNullRow)
    advancedLeft()
  }

  /** Emit a copy of the current right row padded with the left null row, then advance right. */
  private def emitRightOnly(): Unit = {
    joinedRow(leftNullRow, rightRow.copy())
    advancedRight()
  }

  // --- Public methods --------------------------------------------------------------------------

  /** The row object that `advanceNext()` re-points at each output pair. */
  def getJoinedRow(): JoinedRow = joinedRow

  /**
   * Advance to the next output pair.
   * @return true if `getJoinedRow()` now holds a new pair, false once both iterators and the
   *         buffers have been fully consumed.
   */
  def advanceNext(): Boolean = {
    // Pending buffered rows take priority over reading more input.
    val hasBuffered = leftIndex < leftMatches.size || rightIndex < rightMatches.size
    if (hasBuffered && scanNextInBuffered()) {
      true
    } else if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) {
      // Left key contains a null, or the right side is exhausted.
      emitLeftOnly()
      true
    } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) {
      // Right key contains a null, or the left side is exhausted.
      emitRightOnly()
      true
    } else if (leftRow != null && rightRow != null) {
      // Both rows are present and neither key contains a null.
      val comp = keyOrdering.compare(leftRowKey, rightRowKey)
      if (comp < 0) {
        emitLeftOnly()
      } else if (comp > 0) {
        emitRightOnly()
      } else {
        // Equal keys: buffer every row with this key on both sides and start scanning them.
        findMatchingRows(leftRowKey.copy())
        scanNextInBuffered()
      }
      true
    } else {
      // Both iterators have been consumed
      false
    }
  }
}
/**
 * [[RowIterator]] adapter over a [[SortMergeFullOuterJoinScanner]]: it delegates advancing to
 * the scanner, counts each produced row in `numRows`, and projects the scanner's joined row
 * through `resultProj`.
 */
private class FullOuterIterator(
    smjScanner: SortMergeFullOuterJoinScanner,
    resultProj: InternalRow => InternalRow,
    numRows: SQLMetric) extends RowIterator {

  // Fetched once; the scanner re-points this same object on every advance.
  private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()

  override def advanceNext(): Boolean = {
    val advanced = smjScanner.advanceNext()
    if (advanced) {
      numRows += 1
    }
    advanced
  }

  override def getRow: InternalRow = {
    resultProj(joinedRow)
  }
}
相关信息
相关文章
spark BroadcastHashJoinExec 源码
                        
                            0
                        
                        
                             赞
                        
                    
                    
                - 所属分类: 前端技术
- 本文标签:
热门推荐
- 
                        2、 - 优质文章
- 
                        3、 gate.io
- 
                        8、 openharmony
- 
                        9、 golang