// 2022-10-20
// Spark BarrierTaskContext source code


package org.apache.spark

import java.util.{Properties, Timer, TimerTask}

import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.util.{Failure, Success => ScalaSuccess, Try}

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._

/**
 * :: Experimental ::
 * A [[TaskContext]] with extra contextual info and tooling for tasks in a barrier stage.
 * Use [[BarrierTaskContext#get]] to obtain the barrier context for a running barrier task.
 */
@Experimental
class BarrierTaskContext private[spark] (
    taskContext: TaskContext) extends TaskContext with Logging {

  import BarrierTaskContext._

  // Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls.
  private val barrierCoordinator: RpcEndpointRef = {
    val env = SparkEnv.get
    RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv)
  }

  // Local barrierEpoch that identifies a barrier() call from the current task; it shall be
  // identical to the driver side epoch. Incremented after every successful global sync.
  private var barrierEpoch = 0

  /**
   * Sends a sync request to the barrier coordinator and blocks until the reply arrives or the
   * task is killed.
   *
   * @param message payload forwarded to the coordinator inside the `RequestToSync` request
   * @param requestMethod which kind of sync is requested (see [[barrier]] and [[allGather]])
   * @return the messages carried by the completed RPC reply
   */
  private def runBarrier(message: String, requestMethod: RequestMethod.Value): Array[String] = {
    logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
      s"the global sync, current barrier epoch is $barrierEpoch.")
    logTrace("Current callSite: " + Utils.getCallSite())

    val startTime = System.currentTimeMillis()
    val timerTask = new TimerTask {
      override def run(): Unit = {
        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
          s"under the global sync since $startTime, has been waiting for " +
          s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
          s"current barrier epoch is $barrierEpoch.")
      }
    }
    // Log the update of global sync every 60 seconds.
    timer.schedule(timerTask, 60000, 60000)

    try {
      val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]](
        message = RequestToSync(numPartitions, stageId, stageAttemptNumber, taskAttemptId,
          barrierEpoch, partitionId, message, requestMethod),
        // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
        // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
        timeout = new RpcTimeout(365.days, "barrierTimeout"))

      // Wait for the RPC future to complete, but wake up every second to check whether the
      // current Spark task has been killed. If it has, abort the RPC with the kill exception;
      // otherwise keep waiting until the RPC completes.
      while (!abortableRpcFuture.future.isCompleted) {
        try {
          // wait RPC future for at most 1 second
          Thread.sleep(1000)
        } catch {
          case _: InterruptedException => // task is killed by driver
        } finally {
          Try(taskContext.killTaskIfInterrupted()) match {
            case ScalaSuccess(_) => // task is still running healthily
            case Failure(e) => abortableRpcFuture.abort(e)
          }
        }
      }
      // messages which consist of all barrier tasks' messages. The future will return the
      // desired messages if it is completed successfully. Otherwise, exception could be thrown.
      // The loop above only exits once the future is completed, so `value` is defined here and
      // the trailing `get` deliberately rethrows a failed result.
      val messages = abortableRpcFuture.future.value.get.get

      barrierEpoch += 1
      logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
        "global sync successfully, waited for " +
        s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
        s"current barrier epoch is $barrierEpoch.")
      messages
    } catch {
      case e: SparkException =>
        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
          "to perform global sync, waited for " +
          s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
          s"current barrier epoch is $barrierEpoch.")
        throw e
    } finally {
      // Always stop the periodic progress logging, whether the sync succeeded or failed, and
      // drop the cancelled task from the shared timer's queue.
      timerTask.cancel()
      timer.purge()
    }
  }

  /**
   * :: Experimental ::
   * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
   * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same
   * stage have reached this routine.
   *
   * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
   * possible code branches. Otherwise, you may get the job hanging or a SparkException after
   * timeout. Some examples of '''misuses''' are listed below:
   * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
   * shall lead to timeout of the function call.
   * {{{
   *   rdd.barrier().mapPartitions { iter =>
   *       val context = BarrierTaskContext.get()
   *       if (context.partitionId() == 0) {
   *           // Do nothing.
   *       } else {
   *           context.barrier()
   *       }
   *       iter
   *   }
   * }}}
   *
   * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
   * second function call.
   * {{{
   *   rdd.barrier().mapPartitions { iter =>
   *       val context = BarrierTaskContext.get()
   *       try {
   *           // Do something that might throw an Exception.
   *           doSomething()
   *           context.barrier()
   *       } catch {
   *           case e: Exception => logWarning("...", e)
   *       }
   *       context.barrier()
   *       iter
   *   }
   * }}}
   */
  @Experimental
  def barrier(): Unit = runBarrier("", RequestMethod.BARRIER)

  /**
   * :: Experimental ::
   * Blocks until all tasks in the same stage have reached this routine. Each task passes in
   * a message and returns with a list of all the messages passed in by each of those tasks.
   *
   * CAUTION! The allGather method requires the same precautions as the barrier method.
   *
   * The message is type String rather than Array[Byte] because it is more convenient for
   * the user at the cost of worse performance.
   */
  @Experimental
  def allGather(message: String): Array[String] = runBarrier(message, RequestMethod.ALL_GATHER)

  /**
   * :: Experimental ::
   * Returns [[BarrierTaskInfo]] for all tasks in this barrier stage, ordered by partition ID.
   */
  @Experimental
  def getTaskInfos(): Array[BarrierTaskInfo] = {
    // The task addresses are read from the comma-separated "addresses" local property; a
    // missing property is treated as an empty string.
    val addressesStr = Option(taskContext.getLocalProperty("addresses")).getOrElse("")
    addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_))
  }

  // delegate methods: everything below simply forwards to the wrapped taskContext.

  override def isCompleted(): Boolean = taskContext.isCompleted()

  override def isInterrupted(): Boolean = taskContext.isInterrupted()

  override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
    taskContext.addTaskCompletionListener(listener)
    this
  }

  override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
    taskContext.addTaskFailureListener(listener)
    this
  }

  override def stageId(): Int = taskContext.stageId()

  override def stageAttemptNumber(): Int = taskContext.stageAttemptNumber()

  override def partitionId(): Int = taskContext.partitionId()

  override def numPartitions(): Int = taskContext.numPartitions()

  override def attemptNumber(): Int = taskContext.attemptNumber()

  override def taskAttemptId(): Long = taskContext.taskAttemptId()

  override def getLocalProperty(key: String): String = taskContext.getLocalProperty(key)

  override def taskMetrics(): TaskMetrics = taskContext.taskMetrics()

  override def getMetricsSources(sourceName: String): Seq[Source] = {
    taskContext.getMetricsSources(sourceName)
  }

  override def cpus(): Int = taskContext.cpus()

  override def resources(): Map[String, ResourceInformation] = taskContext.resources()

  override def resourcesJMap(): java.util.Map[String, ResourceInformation] = {
    // Java view of resources(), converted through the JavaConverters import of this file.
    resources().asJava
  }

  override private[spark] def killTaskIfInterrupted(): Unit = taskContext.killTaskIfInterrupted()

  override private[spark] def getKillReason(): Option[String] = taskContext.getKillReason()

  override private[spark] def taskMemoryManager(): TaskMemoryManager = {
    taskContext.taskMemoryManager()
  }

  override private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {
    taskContext.registerAccumulator(a)
  }

  override private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
    taskContext.setFetchFailed(fetchFailed)
  }

  override private[spark] def markInterrupted(reason: String): Unit = {
    taskContext.markInterrupted(reason)
  }

  override private[spark] def markTaskFailed(error: Throwable): Unit = {
    taskContext.markTaskFailed(error)
  }

  override private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = {
    taskContext.markTaskCompleted(error)
  }

  override private[spark] def fetchFailed: Option[FetchFailedException] = {
    taskContext.fetchFailed
  }

  override private[spark] def getLocalProperties: Properties = taskContext.getLocalProperties
}

@Experimental
object BarrierTaskContext {
  /**
   * :: Experimental ::
   * Returns the currently active BarrierTaskContext. This can be called inside of user functions to
   * access contextual information about running barrier tasks.
   */
  @Experimental
  def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext]

  // Single timer held by this companion object; runBarrier() schedules its periodic
  // "still waiting" log task on it and cancels/purges that task when the sync call ends.
  private val timer = new Timer("Barrier task timer for barrier() calls.")
}



