spark RBackendHandler 源码

  • 2022-10-20
spark RBackendHandler 代码


/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.api.r

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.concurrent.TimeUnit

import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.timeout.ReadTimeoutException

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.api.r.SerDe._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.R._
import org.apache.spark.util.{ThreadUtils, Utils}

/**
 * Handler for RBackend
 * TODO: This is marked as sharable to get a handle to RBackend. Is it safe to re-use
 * this across connections?
 */
private[r] class RBackendHandler(server: RBackend)
  extends SimpleChannelInboundHandler[Array[Byte]] with Logging {

  /**
   * Handles one inbound request frame.
   *
   * The frame is decoded as a header of (isStatic flag, object id, method name, argument
   * count). Requests addressed to the object id "SparkRHandler" are treated as built-in
   * commands ("echo", "stopBackend", "rm"); every other request is passed to
   * `handleMethodCall` after a periodic keep-alive task has been scheduled. The response
   * is accumulated in an in-memory buffer (`bos`).
   *
   * NOTE(review): this listing appears to have lost lines (closing braces, and the
   * statements that would use `objToRemove`, `pingBaos` and `reply`) -- restore them from
   * the upstream file before compiling.
   *
   * @param ctx channel context of the connection the request arrived on
   * @param msg raw bytes of the request
   */
  override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
    // Input side: the request bytes wrapped for structured reads.
    val bis = new ByteArrayInputStream(msg)
    val dis = new DataInputStream(bis)

    // Output side: the reply is built in memory and extracted at the end of the method.
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)

    // First bit is isStatic
    val isStatic = readBoolean(dis)
    val objId = readString(dis)
    val methodName = readString(dis)
    val numArgs = readInt(dis)

    if (objId == "SparkRHandler") {
      // Built-in commands addressed to the handler itself rather than to a tracked object.
      methodName match {
        // This function is for test-purpose only
        case "echo" =>
          val args = readArgs(numArgs, dis)
          assert(numArgs == 1)

          // Status 0, followed by the single argument sent back unchanged.
          writeInt(dos, 0)
          writeObject(dos, args(0), server.jvmObjectTracker)
        case "stopBackend" =>
          // Status 0 with a "void" result.
          // NOTE(review): no statement that actually stops the backend is visible here.
          writeInt(dos, 0)
          writeType(dos, "void")
        case "rm" =>
          try {
            // The payload must be tagged 'c' and is then read as a string.
            val t = readObjectType(dis)
            assert(t == 'c')
            // NOTE(review): objToRemove is read but not used in this listing; the removal
            // call itself seems to be among the missing lines.
            val objToRemove = readString(dis)
            writeInt(dos, 0)
            writeObject(dos, null, server.jvmObjectTracker)
          } catch {
            case e: Exception =>
              // Status -1 followed by a message describing the failure.
              logError(s"Removing $objId failed", e)
              writeInt(dos, -1)
              writeString(dos, s"Removing $objId failed: ${e.getMessage}")
        case _ =>
          // NOTE(review): unlike the error path above, no status code is written before
          // this message in the listing -- confirm against upstream.
          writeString(dos, s"Error: unknown method $methodName")
    } else {
      // To avoid timeouts when reading results in SparkR driver, we will be regularly sending
      // heartbeat responses. We use special code +1 to signal the client that backend is
      // alive and it should continue blocking for result.
      val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread")
      val pingRunner = new Runnable {
        override def run(): Unit = {
          // Each ping is serialized into its own buffer, separate from the reply buffer.
          // NOTE(review): the statement that sends pingBaos is not present in this listing.
          val pingBaos = new ByteArrayOutputStream()
          val pingDaos = new DataOutputStream(pingBaos)
          writeInt(pingDaos, +1)
      // Fall back to a default SparkConf when no SparkEnv has been set.
      val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
      val heartBeatInterval = conf.get(R_HEARTBEAT_INTERVAL)
      val backendConnectionTimeout = conf.get(R_BACKEND_CONNECTION_TIMEOUT)
      // Ping at the configured heartbeat interval, capped at one unit below the connection
      // timeout. The value is passed to the scheduler as seconds.
      val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1)

      execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS)
      handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
      // NOTE(review): no shutdown() call is visible before this wait; awaitTermination on
      // its own does not stop the executor -- confirm against upstream.
      execService.awaitTermination(1, TimeUnit.SECONDS)

    // NOTE(review): reply is computed but the statement writing it to ctx is not visible.
    val reply = bos.toByteArray

  /**
   * Netty callback overridden from the inbound-handler contract.
   *
   * NOTE(review): the body and closing brace of this method are not present in this
   * listing; restore them from the upstream file.
   *
   * @param ctx channel context of the connection
   */
  override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {

  /**
   * Reacts to an exception raised while handling the channel. A `ReadTimeoutException` is
   * logged as a warning and otherwise ignored; every other cause takes the fallback branch.
   *
   * @param ctx   channel context of the connection
   * @param cause the exception that was raised
   */
  override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
    cause match {
      case timeout: ReadTimeoutException =>
        // Do nothing. We don't want to timeout on read
        logWarning("Ignoring read timeout in RBackendHandler")
      case _ =>
        // Close the connection when an exception is raised.
        // NOTE(review): the statements of this branch and the closing braces are not
        // present in this listing; restore them from the upstream file.

  /**
   * Resolves and invokes a method or constructor by reflection and serializes the outcome.
   *
   * On success the status code 0 is written to `dos` followed by the returned (or newly
   * constructed) object; on any `Exception` the status code -1 is written followed by a
   * string rendering of the exception's cause.
   *
   * NOTE(review): this listing appears to have lost lines (the body of the static branch,
   * the arguments of both `findMatchedSignature` calls, the candidate-logging lambdas and
   * closing braces) -- restore them from the upstream file before compiling.
   *
   * @param isStatic   selects the branch used to obtain `cls`
   * @param objId      id used to look the target up in `server.jvmObjectTracker` when the
   *                   call is not static; also used in log and error messages
   * @param methodName method to invoke, or "<init>" to request a constructor
   * @param numArgs    number of arguments to read from `dis`
   * @param dis        stream positioned at the serialized arguments
   * @param dos        stream the status code and result are written to
   */
  def handleMethodCall(
      isStatic: Boolean,
      objId: String,
      methodName: String,
      numArgs: Int,
      dis: DataInputStream,
      dos: DataOutputStream): Unit = {
    // Receiver of the reflective call; stays null unless the non-static branch assigns it.
    var obj: Object = null
    try {
      // NOTE(review): the static branch body and the end of the else branch are missing.
      val cls = if (isStatic) {
      } else {
        obj = server.jvmObjectTracker(JVMObjectId(objId))

      val args = readArgs(numArgs, dis)

      // Candidates are the public methods of cls that carry the requested name.
      val methods = cls.getMethods
      val selectedMethods = methods.filter(m => m.getName == methodName)
      if (selectedMethods.length > 0) {
        val index = findMatchedSignature(

        if (index.isEmpty) {
          // No overload accepts the supplied arguments: log, then fail the call.
          logWarning(s"cannot find matching method ${cls}.$methodName. "
            + s"Candidates are:")
          selectedMethods.foreach { method =>
          throw new Exception(s"No matched method found for $cls.$methodName")

        val ret = selectedMethods(index.get).invoke(obj, args : _*)

        // Write status bit
        writeInt(dos, 0)
        writeObject(dos, ret, server.jvmObjectTracker)
      } else if (methodName == "<init>") {
        // methodName should be "<init>" for constructor
        val ctors = cls.getConstructors
        val index = findMatchedSignature(

        if (index.isEmpty) {
          logWarning(s"cannot find matching constructor for ${cls}. "
            + s"Candidates are:")
          ctors.foreach { ctor =>
          throw new Exception(s"No matched constructor found for $cls")

        // This local `obj` shadows the outer `var obj` within the constructor branch.
        val obj = ctors(index.get).newInstance(args : _*)

        writeInt(dos, 0)
        writeObject(dos, obj.asInstanceOf[AnyRef], server.jvmObjectTracker)
      } else {
        throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId)
    } catch {
      case e: Exception =>
        logError(s"$methodName on $objId failed", e)
        writeInt(dos, -1)
        // Writing the error message of the cause for the exception. This will be returned
        // to user in the R process.
        // NOTE(review): e.getCause is null for exceptions thrown directly in this method
        // (e.g. the "No matched ..." ones); confirm Utils.exceptionString accepts null.
        writeString(dos, Utils.exceptionString(e.getCause))

  // Read a number of arguments from the data input stream
  def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
    // Decode the arguments one after another, in stream order. A non-positive count yields
    // an empty array because the range `0 until numArgs` is then empty.
    (0 until numArgs).map { _ =>
      readObject(dis, server.jvmObjectTracker)
    }.toArray
  }

  // Find a matching method signature in an array of signatures of constructors
  // or methods of the same name according to the passed arguments. Arguments
  // may be converted in order to match a signature.
  // Note that in Java reflection, constructors and normal methods are of different
  // classes, and share no parent class that provides methods for reflection uses.
  // There is no unified way to handle them in this function. So an array of signatures
  // is passed in instead of an array of candidate constructors or methods.
  // Returns an Option[Int] which is the index of the matched signature in the array.
  /**
   * Returns the index of the first signature that accepts `args`.
   *
   * A signature accepts the arguments when it has the same arity and every argument is
   * either a Java array offered for a Scala `Seq` parameter, a null offered for a
   * non-primitive parameter, or an instance of the (boxed) parameter type.
   *
   * @param parameterTypesOfMethods one parameter-type array per candidate constructor/method
   * @param args the decoded call arguments; when a signature matches, any Java array passed
   *             for a `Seq` parameter is replaced in this array by the equivalent `Seq`
   * @return `Some(index)` of the first matching signature, or `None` if none matches
   */
  def findMatchedSignature(
      parameterTypesOfMethods: Array[Array[Class[_]]],
      args: Array[Object]): Option[Int] = {
    val numArgs = args.length

    // True when a Java array was supplied for a Scala Seq parameter. The null check keeps a
    // null argument from throwing on getClass; null then falls through to the ordinary
    // reference-type rule in `accepts`.
    def isArrayForSeq(parameterType: Class[_], arg: Object): Boolean =
      parameterType == classOf[Seq[Any]] && arg != null && arg.getClass.isArray

    // Primitive parameters are compared against a boxed class because args holds objects.
    def boxedType(parameterType: Class[_]): Class[_] = parameterType match {
      case java.lang.Integer.TYPE => classOf[java.lang.Integer]
      // Kept exactly as in the original: a long parameter is matched by an Integer argument.
      case java.lang.Long.TYPE => classOf[java.lang.Integer]
      case java.lang.Double.TYPE => classOf[java.lang.Double]
      case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
      case _ => parameterType
    }

    // Whether a single argument is acceptable for a single parameter.
    def accepts(parameterType: Class[_], arg: Object): Boolean = {
      if (isArrayForSeq(parameterType, arg)) {
        // The array will be converted to a Seq below if this signature is selected.
        true
      } else {
        // Null is only acceptable for non-primitive parameters.
        val typeMismatch = (parameterType.isPrimitive || arg != null) &&
          !boxedType(parameterType).isInstance(arg)
        !typeMismatch
      }
    }

    // For now, we return the first matching method.
    // TODO: find best method in matching methods.
    val matched = parameterTypesOfMethods.indexWhere { parameterTypes =>
      parameterTypes.length == numArgs &&
        parameterTypes.indices.forall(i => accepts(parameterTypes(i), args(i)))
    }

    if (matched < 0) {
      None
    } else {
      // Convert args if needed: a Java array passed for a Seq parameter becomes a Scala Seq.
      val parameterTypes = parameterTypesOfMethods(matched)
      for (i <- 0 until numArgs if isArrayForSeq(parameterTypes(i), args(i))) {
        args(i) = args(i).asInstanceOf[Array[_]].toSeq
      }
      Some(matched)
    }
  }


