spark EvaluatePython source code

  • 2022-10-20
  • Views (220)

spark EvaluatePython code

File path: /sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import java.io.OutputStream
import java.nio.charset.StandardCharsets

import scala.collection.JavaConverters._

import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}

import org.apache.spark.{ContextAwareIterator, TaskContext}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object EvaluatePython {

  def needConversionInPython(dt: DataType): Boolean = dt match {
    case DateType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => true
    case _: StructType => true
    case _: UserDefinedType[_] => true
    case ArrayType(elementType, _) => needConversionInPython(elementType)
    case MapType(keyType, valueType, _) =>
      needConversionInPython(keyType) || needConversionInPython(valueType)
    case _ => false
  }

  /**
   * Helper for converting from a Catalyst type to a Java type suitable for Pickle.
   */
  def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
    case (null, _) => null

    case (row: InternalRow, struct: StructType) =>
      val values = new Array[Any](row.numFields)
      var i = 0
      while (i < row.numFields) {
        values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType)
        i += 1
      }
      new GenericRowWithSchema(values, struct)

    case (a: ArrayData, array: ArrayType) =>
      val values = new java.util.ArrayList[Any](a.numElements())
      a.foreach(array.elementType, (_, e) => {
        values.add(toJava(e, array.elementType))
      })
      values

    case (map: MapData, mt: MapType) =>
      val jmap = new java.util.HashMap[Any, Any](map.numElements())
      map.foreach(mt.keyType, mt.valueType, (k, v) => {
        jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType))
      })
      jmap

    case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)

    case (d: Decimal, _) => d.toJavaBigDecimal

    case (s: UTF8String, StringType) => s.toString

    case (other, _) => other
  }

  /**
   * Make a converter that converts `obj` to the type specified by the data type, or returns
   * null if the type of `obj` is unexpected, since Python doesn't enforce types.
   */
  def makeFromJava(dataType: DataType): Any => Any = dataType match {
    case BooleanType => (obj: Any) => nullSafeConvert(obj) {
      case b: Boolean => b
    }

    case ByteType => (obj: Any) => nullSafeConvert(obj) {
      case c: Byte => c
      case c: Short => c.toByte
      case c: Int => c.toByte
      case c: Long => c.toByte
    }

    case ShortType => (obj: Any) => nullSafeConvert(obj) {
      case c: Byte => c.toShort
      case c: Short => c
      case c: Int => c.toShort
      case c: Long => c.toShort
    }

    case IntegerType => (obj: Any) => nullSafeConvert(obj) {
      case c: Byte => c.toInt
      case c: Short => c.toInt
      case c: Int => c
      case c: Long => c.toInt
    }

    case LongType => (obj: Any) => nullSafeConvert(obj) {
      case c: Byte => c.toLong
      case c: Short => c.toLong
      case c: Int => c.toLong
      case c: Long => c
    }

    case FloatType => (obj: Any) => nullSafeConvert(obj) {
      case c: Float => c
      case c: Double => c.toFloat
    }

    case DoubleType => (obj: Any) => nullSafeConvert(obj) {
      case c: Float => c.toDouble
      case c: Double => c
    }

    case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) {
      case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale)
    }

    case DateType => (obj: Any) => nullSafeConvert(obj) {
      case c: Int => c
    }

    case TimestampType | TimestampNTZType | _: DayTimeIntervalType => (obj: Any) =>
      nullSafeConvert(obj) {
        case c: Long => c
        // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
        case c: Int => c.toLong
      }

    case StringType => (obj: Any) => nullSafeConvert(obj) {
      case _ => UTF8String.fromString(obj.toString)
    }

    case BinaryType => (obj: Any) => nullSafeConvert(obj) {
      case c: String => c.getBytes(StandardCharsets.UTF_8)
      case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
    }

    case ArrayType(elementType, _) =>
      val elementFromJava = makeFromJava(elementType)

      (obj: Any) => nullSafeConvert(obj) {
        case c: java.util.List[_] =>
          new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray)
        case c if c.getClass.isArray =>
          new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e)))
      }

    case MapType(keyType, valueType, _) =>
      val keyFromJava = makeFromJava(keyType)
      val valueFromJava = makeFromJava(valueType)

      (obj: Any) => nullSafeConvert(obj) {
        case javaMap: java.util.Map[_, _] =>
          ArrayBasedMapData(
            javaMap,
            (key: Any) => keyFromJava(key),
            (value: Any) => valueFromJava(value))
      }

    case StructType(fields) =>
      val fieldsFromJava = fields.map(f => makeFromJava(f.dataType))

      (obj: Any) => nullSafeConvert(obj) {
        case c if c.getClass.isArray =>
          val array = c.asInstanceOf[Array[_]]
          if (array.length != fields.length) {
            throw new IllegalStateException(
              s"Input row doesn't have expected number of values required by the schema. " +
                s"${fields.length} fields are required while ${array.length} values are provided."
            )
          }

          val row = new GenericInternalRow(fields.length)
          var i = 0
          while (i < fields.length) {
            row(i) = fieldsFromJava(i)(array(i))
            i += 1
          }
          row
      }

    case udt: UserDefinedType[_] => makeFromJava(udt.sqlType)

    case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty)
  }

  private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = {
    if (input == null) {
      null
    } else {
      f.applyOrElse(input, {
        // all other unexpected types should become null, or we will get a runtime exception
        // TODO(davies): we could improve this by trying to cast the object to the expected type
        _: Any => null
      })
    }
  }

  private val module = "pyspark.sql.types"

  /**
   * Pickler for StructType
   */
  private class StructTypePickler extends IObjectPickler {

    private val cls = classOf[StructType]

    def register(): Unit = {
      Pickler.registerCustomPickler(cls, this)
    }

    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
      out.write(Opcodes.GLOBAL)
      out.write(
        (module + "\n" + "_parse_datatype_json_string" + "\n").getBytes(StandardCharsets.UTF_8))
      val schema = obj.asInstanceOf[StructType]
      pickler.save(schema.json)
      out.write(Opcodes.TUPLE1)
      out.write(Opcodes.REDUCE)
    }
  }

  /**
   * Pickler for external row.
   */
  private class RowPickler extends IObjectPickler {

    private val cls = classOf[GenericRowWithSchema]

    // register this with the Pickler
    def register(): Unit = {
      Pickler.registerCustomPickler(this.getClass, this)
      Pickler.registerCustomPickler(cls, this)
    }

    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
      if (obj == this) {
        out.write(Opcodes.GLOBAL)
        out.write(
          (module + "\n" + "_create_row_inbound_converter" + "\n").getBytes(StandardCharsets.UTF_8))
      } else {
        // it will be memoized by the Pickler to save some bytes
        pickler.save(this)
        val row = obj.asInstanceOf[GenericRowWithSchema]
        // the schema should always be the same object for memoization
        pickler.save(row.schema)
        out.write(Opcodes.TUPLE1)
        out.write(Opcodes.REDUCE)

        out.write(Opcodes.MARK)
        var i = 0
        while (i < row.values.length) {
          pickler.save(row.values(i))
          i += 1
        }
        out.write(Opcodes.TUPLE)
        out.write(Opcodes.REDUCE)
      }
    }
  }

  private[this] var registered = false

  /**
   * This should be called before trying to serialize any of the above classes in cluster mode;
   * it should be put in the closure.
   */
  def registerPicklers(): Unit = {
    synchronized {
      if (!registered) {
        SerDeUtil.initialize()
        new StructTypePickler().register()
        new RowPickler().register()
        registered = true
      }
    }
  }

  /**
   * Convert an RDD of Java objects to an RDD of serialized Python objects that is usable by
   * PySpark.
   */
   */
  def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = {
    rdd.mapPartitions { iter =>
      registerPicklers()  // ensure this is called in the executor
      new SerDeUtil.AutoBatchedPickler(new ContextAwareIterator(TaskContext.get, iter))
    }
  }
}
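
Below is a minimal usage sketch (not part of the original file) showing how the two converters complement each other. The object name, schema, and values are made up for illustration, and it assumes Spark's sql/core and catalyst modules are on the classpath:

import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.types._

object EvaluatePythonExample {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("name", StringType)))

    // makeFromJava builds a converter from plain Java/Scala values (e.g. values
    // unpickled from Python) into Catalyst's internal representation.
    val fromJava = EvaluatePython.makeFromJava(schema)
    val internalRow = fromJava(Array[Any](1L, "alice"))

    // toJava goes the other way: Catalyst internal values back to plain Java
    // objects that the Pickler can serialize for PySpark.
    val externalRow = EvaluatePython.toJava(internalRow, schema)
    println(externalRow) // a GenericRowWithSchema holding [1, alice]
  }
}

Note the design choice in nullSafeConvert: unexpected input types are converted to null rather than raising an exception, because values coming back from Python are not type-checked on the JVM side.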

Related information

spark source code directory

Related articles

spark AggregateInPandasExec source code

spark ApplyInPandasWithStatePythonRunner source code

spark ApplyInPandasWithStateWriter source code

spark ArrowEvalPythonExec source code

spark ArrowPythonRunner source code

spark AttachDistributedSequenceExec source code

spark BatchEvalPythonExec source code

spark CoGroupedArrowPythonRunner source code

spark EvalPythonExec source code

spark ExtractPythonUDFs source code
