// spark ExtractPythonUDFs source code
//
//   • 2022-10-20
// spark ExtractPythonUDFs code


/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern._

/**
 * Extracts all the Python UDFs in a logical aggregate that depend on an aggregate expression or
 * grouping key, or that don't depend on any of the above expressions, and evaluates them after
 * the aggregate.
 */
object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
  // A Catalyst `Rule[LogicalPlan]`; the rule's entry point is the `apply` method further down.
  // NOTE(review): no closing `}` for this object is present in this copy of the file.

  /**
   * Returns whether the expression could only be evaluated within aggregate.
   */
  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
    // Visible conditions: `e` is an AggregateExpression, or
    // PythonUDF.isGroupedAggPandasUDF(e) holds.
    e.isInstanceOf[AggregateExpression] ||
      PythonUDF.isGroupedAggPandasUDF(e) ||
      // NOTE(review): the trailing `||` above has no right-hand operand, and the method's closing
      // brace is absent in this copy. `agg` is unused in the visible lines, so the lost operand
      // presumably involves it -- restore from the upstream file before compiling.

  /**
   * Returns whether `expr` contains a sub-expression for which `PythonUDF.isScalarPythonUDF`
   * holds and which either references no attributes or itself contains a sub-expression
   * accepted by `belongAggregate`.
   *
   * @param expr expression tree to search
   * @param agg  aggregate passed through to `belongAggregate`
   */
  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
    expr.exists {
      e => PythonUDF.isScalarPythonUDF(e) &&
        (e.references.isEmpty || e.exists(belongAggregate(_, agg)))
    // NOTE(review): the closing braces of the lambda and of the method are absent in this copy.

  /**
   * Rewrites `agg` into a `Project` placed on top of a copy of `agg` whose aggregate expressions
   * are replaced by the ones collected in `aggExpr`.
   *
   * Expressions for which `hasPythonUdfOverAggregate` holds are rewritten with `transformDown`;
   * every other aggregate expression is kept in the aggregate and projected via `toAttribute`.
   *
   * @param agg the aggregate to rewrite
   * @return the `Project` over the rewritten aggregate
   */
  private def extract(agg: Aggregate): LogicalPlan = {
    // Output list of the new Project.
    val projList = new ArrayBuffer[NamedExpression]()
    // Aggregate expressions of the rewritten Aggregate.
    val aggExpr = new ArrayBuffer[NamedExpression]()
    agg.aggregateExpressions.foreach { expr =>
      if (hasPythonUdfOverAggregate(expr, agg)) {
        // Python UDF can only be evaluated after aggregate
        val newE = expr transformDown {
          case e: Expression if belongAggregate(e, agg) =>
            // Reuse `e` when it is already named; otherwise wrap it in an Alias called "agg".
            val alias = e match {
              case a: NamedExpression => a
              case o => Alias(e, "agg")()
            // NOTE(review): the closing brace of the match above is absent in this copy.
            aggExpr += alias
            // NOTE(review): the value this case should yield (the replacement for `e`) and the
            // closing brace of `transformDown` are absent in this copy.
        projList += newE.asInstanceOf[NamedExpression]
      } else {
        aggExpr += expr
        projList += expr.toAttribute
    // NOTE(review): the closing braces of the if/else and of the foreach are absent in this copy.
    // There is no Python UDF over aggregate expression
    Project(projList.toSeq, agg.copy(aggregateExpressions = aggExpr.toSeq))

  /**
   * Rule entry point: walks `plan` with `transformUpWithPruning`, passing a predicate that
   * requires both the PYTHON_UDF and AGGREGATE patterns, and matches every `Aggregate` that has
   * at least one aggregate expression accepted by `hasPythonUdfOverAggregate`.
   */
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
    _.containsAllPatterns(PYTHON_UDF, AGGREGATE)) {
    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
    // NOTE(review): the body of this case and the closing braces are absent in this copy;
    // presumably it delegates to `extract(agg)` -- confirm against the upstream file.

/**
 * Extracts PythonUDFs in a logical aggregate that are used in grouping keys, and evaluates them
 * before the aggregate.
 *
 * This must be executed after the `ExtractPythonUDFFromAggregate` rule and before
 * `ExtractPythonUDFs`.
 */
object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] {
  // NOTE(review): the body of this helper is absent in this copy; only its signature survives.
  private def hasScalarPythonUDF(e: Expression): Boolean = {

  // Rewrites `agg`: each PythonUDF found in a grouping expression is aliased as
  // "groupingPythonUDF" and added to a Project over `agg.child`, and matching deterministic
  // PythonUDFs in other expressions are swapped for the recorded attribute.
  private def extract(agg: Aggregate): LogicalPlan = {
    // Aliased grouping UDFs to be prepended to the child's output in the new Project.
    val projList = new ArrayBuffer[NamedExpression]()
    // Grouping expressions of the rewritten aggregate.
    val groupingExpr = new ArrayBuffer[Expression]()
    // Maps a canonicalized PythonUDF to the attribute of the alias created for it.
    val attributeMap = mutable.HashMap[PythonUDF, NamedExpression]()

    agg.groupingExpressions.foreach { expr =>
      if (hasScalarPythonUDF(expr)) {
        val newE = expr transformDown {
          case p: PythonUDF =>
            // This is just a sanity check, the rule PullOutNondeterministic should
            // already pull out those nondeterministic expressions.
            assert(p.udfDeterministic, "Non-deterministic PythonUDFs should not appear " +
              "in grouping expression")
            // Key the map by canonicalized form.
            val canonicalized = p.canonicalized.asInstanceOf[PythonUDF]
            if (attributeMap.contains(canonicalized)) {
              // NOTE(review): this branch is empty in this copy; presumably it should yield the
              // attribute already recorded for this UDF -- confirm against the upstream file.
            } else {
              val alias = Alias(p, "groupingPythonUDF")()
              projList += alias
              attributeMap += ((canonicalized, alias.toAttribute))
              // NOTE(review): the value yielded by this branch and the closing braces of the
              // if/else and of `transformDown` are absent in this copy.
        groupingExpr += newE
      } else {
        groupingExpr += expr
    // NOTE(review): the closing braces of the if/else and of the foreach are absent in this copy.
    // NOTE(review): the collection and method call that should precede `{ expr =>` are absent in
    // this copy, so it is not visible here which expressions are being rewritten.
    val aggExpr = { expr =>
      expr.transformUp {
        // PythonUDF over aggregate was pulled out by ExtractPythonUDFFromAggregate.
        // PythonUDF here should be either
        // 1. Argument of an aggregate function.
        //    CheckAnalysis guarantees the arguments are deterministic.
        // 2. PythonUDF in grouping key. Grouping key must be deterministic.
        // 3. PythonUDF not in grouping key. It either has no arguments or has a grouping key
        // in its arguments. Such PythonUDF was pulled out by ExtractPythonUDFFromAggregate, too.
        case p: PythonUDF if p.udfDeterministic =>
          val canonicalized = p.canonicalized.asInstanceOf[PythonUDF]
          // Fall back to the UDF itself when no grouping alias was recorded for it.
          attributeMap.getOrElse(canonicalized, p)
      // NOTE(review): the head of the call taking the named arguments below is absent in this
      // copy (presumably `agg.copy(` -- confirm against the upstream file).
      groupingExpressions = groupingExpr.toSeq,
      aggregateExpressions = aggExpr,
      // New child: the aliased grouping UDFs followed by the original child's output.
      child = Project((projList ++ agg.child.output).toSeq, agg.child))

  // Rule entry point: matches every Aggregate that has a grouping expression accepted by
  // `hasScalarPythonUDF`, considering only subtrees with both PYTHON_UDF and AGGREGATE patterns.
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
    _.containsAllPatterns(PYTHON_UDF, AGGREGATE)) {
    case agg: Aggregate if agg.groupingExpressions.exists(hasScalarPythonUDF(_)) =>
    // NOTE(review): the body of this case and all closing braces, including the object's, are
    // absent in this copy.

/**
 * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
 * alone in a batch.
 *
 * Only extracts the PythonUDFs that could be evaluated in Python (the single child is PythonUDFs
 * or all the children could be evaluated in JVM).
 *
 * This has the limitation that the input to the Python UDF is not allowed to include attributes
 * from multiple child operators.
 */
object ExtractPythonUDFs extends Rule[LogicalPlan] {

  // Alias for the Int used as a UDF eval type.
  private type EvalType = Int
  // Predicate over an eval type.
  private type EvalTypeChecker = EvalType => Boolean

  // NOTE(review): the body of this helper is absent in this copy; only its signature survives.
  private def hasScalarPythonUDF(e: Expression): Boolean = {

  /**
   * Returns whether `e` can be evaluated in Python, decided from its children: a lone PythonUDF
   * child must share `e`'s eval type and itself be evaluable; otherwise no child may satisfy
   * `hasScalarPythonUDF`.
   */
  private def canEvaluateInPython(e: PythonUDF): Boolean = {
    e.children match {
      // single PythonUDF child could be chained and evaluated in Python
      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
      // Python UDF can't be evaluated directly in JVM
      case children => !children.exists(hasScalarPythonUDF)
    // NOTE(review): the closing braces of the match and of the method are absent in this copy.

  /**
   * Collects PythonUDFs from `expressions` that satisfy `PythonUDF.isScalarPythonUDF` and
   * `canEvaluateInPython`, tracking the eval type of the first such UDF visited.
   *
   * @param expressions expressions to search
   * @return the collected UDFs
   */
  private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = {
    // If first UDF is SQL_SCALAR_PANDAS_ITER_UDF, then only return this UDF,
    // otherwise check if subsequent UDFs are of the same type as the first UDF. (since we can only
    // extract UDFs of the same eval type)

    // Eval type of the first matching UDF; None until one has been visited.
    var firstVisitedScalarUDFEvalType: Option[Int] = None

    def canChainUDF(evalType: Int): Boolean = {
      if (evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF) {
        // NOTE(review): this branch is empty in this copy; per the comment at the top of the
        // method an iterator UDF is presumably never chained -- confirm against upstream.
      } else {
        // `.get` relies on the only visible call site: the second case below, whose guard can
        // reach `canChainUDF` only when the first case was skipped because the Option is set.
        evalType == firstVisitedScalarUDFEvalType.get
    // NOTE(review): the closing braces of the if/else and of `canChainUDF` are absent in this copy.

    def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
      // First matching UDF: record its eval type.
      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
        && firstVisitedScalarUDFEvalType.isEmpty =>
        firstVisitedScalarUDFEvalType = Some(udf.evalType)
        // NOTE(review): the value yielded by this case is absent in this copy.
      // Later matching UDF whose eval type is accepted by `canChainUDF`.
      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
        && canChainUDF(udf.evalType) =>
        // NOTE(review): the body of this case is absent in this copy.
      // Anything else: recurse into the children.
      case e => e.children.flatMap(collectEvaluableUDFs)
    // NOTE(review): the closing brace of the match, the expression producing the method's
    // result, and the method's closing brace are absent in this copy.


  /**
   * Rule entry point. A correlated `Subquery` is returned unchanged; any other plan is traversed
   * with `transformUpWithPruning`, calling `extract` on every operator except `BatchEvalPython`
   * and `ArrowEvalPython`, which are returned as they are.
   */
  def apply(plan: LogicalPlan): LogicalPlan = plan match {
    // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
    // eventually. Here we skip subquery, as Python UDF only needs to be extracted once.
    case s: Subquery if s.correlated => plan

    case _ => plan.transformUpWithPruning(
      // All cases must contain pattern PYTHON_UDF. PythonUDFs are member fields of BatchEvalPython
      // and ArrowEvalPython.
      _.containsPattern(PYTHON_UDF)) {
      // A safe guard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and
      // `ArrowEvalPython` in the input plan. However if we hit them, we must skip them, as we can't
      // extract Python UDFs from them.
      case p: BatchEvalPython => p
      case p: ArrowEvalPython => p

      case plan: LogicalPlan => extract(plan)
    // NOTE(review): the closing braces of the partial function and of the outer match are absent
    // in this copy.

  // Branches on `u.deterministic`; its result is used as the `attributeMap` lookup key in
  // `extract` below.
  // NOTE(review): both branch bodies and the closing braces are absent in this copy; judging by
  // the name, the deterministic branch presumably yields the canonicalized UDF and the other
  // branch `u` itself -- confirm against the upstream file.
  private def canonicalizeDeterministic(u: PythonUDF) = {
    if (u.deterministic) {
    } else {

  /**
   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
   */
  private def extract(plan: LogicalPlan): LogicalPlan = {
    // Candidate UDFs: those collected from the operator's expressions whose references are all
    // available in the operator's input.
    val udfs = ExpressionSet(collectEvaluableUDFsFromExpressions(plan.expressions))
      // ignore the PythonUDF that come from second/third aggregate, which is not used
      .filter(udf => udf.references.subsetOf(plan.inputSet))
    if (udfs.isEmpty) {
      // If there aren't any, we are done.
      // NOTE(review): the value yielded by this branch is absent in this copy (presumably the
      // unchanged `plan` -- confirm against the upstream file).
    } else {
      // Maps a UDF to the expression that replaces it in the rewritten operator.
      val attributeMap = mutable.HashMap[PythonUDF, Expression]()
      // Rewrite the child that has the input required for the UDF
      // NOTE(review): the collection and method call that should precede `{ child =>` are absent
      // in this copy.
      val newChildren = { child =>
        // Pick the UDF we are going to evaluate
        val validUdfs = udfs.filter { udf =>
          // Check to make sure that the UDF can be evaluated with only the input of this child.
          // NOTE(review): the predicate body and the rest of this statement are absent in this
          // copy.
        if (validUdfs.nonEmpty) {
            // NOTE(review): the head of the statement ending in this message string is absent in
            // this copy.
            "Can only extract scalar vectorized udf or sql batch udf")

          // One fresh output attribute per UDF, named `pythonUDF<i>` and typed like the UDF.
          // NOTE(review): the indexed collection this partial function is applied to is absent in
          // this copy.
          val resultAttrs = { case (u, i) =>
            AttributeReference(s"pythonUDF$i", u.dataType)()

          // NOTE(review): the right-hand side of `evalTypes` is absent in this copy.
          val evalTypes =
          // All extracted UDFs must share a single eval type.
          if (evalTypes.size != 1) {
            throw new IllegalStateException(
              "Expected udfs have the same evalType but got different evalTypes: " +
              // NOTE(review): the rest of the message expression is absent in this copy.
          val evalType = evalTypes.head
          // Choose the evaluation operator according to the shared eval type.
          val evaluation = evalType match {
            case PythonEvalType.SQL_BATCHED_UDF =>
              BatchEvalPython(validUdfs, resultAttrs, child)
            case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF =>
              ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
            case _ =>
              throw new IllegalStateException("Unexpected UDF evalType")

          // NOTE(review): the right-hand side of `++=` and the value yielded by this branch are
          // absent in this copy.
          attributeMap ++=
        } else {
          // NOTE(review): the else-branch body is absent in this copy.
      // Other cases are disallowed as they are ambiguous or would require a cartesian
      // product.
      // NOTE(review): the comment line above originally ended with a stray `{ udf =>`, apparently
      // the tail of a lost statement iterating over UDFs; the `throw` below uses `udf` from it.
        throw new IllegalStateException(
          s"Invalid PythonUDF $udf, requires attributes from more than one child.")

      // Swap in the new children and replace each PythonUDF that has an entry in `attributeMap`.
      val rewritten = plan.withNewChildren(newChildren).transformExpressions {
        case p: PythonUDF => attributeMap.getOrElse(canonicalizeDeterministic(p), p)

      // extract remaining python UDFs recursively
      val newPlan = extract(rewritten)
      if (newPlan.output != plan.output) {
        // Trim away the new UDF value if it was only used for filtering or something.
        Project(plan.output, newPlan)
      } else {
        // NOTE(review): the else-branch body and all closing braces are absent in this copy.


