spark Executor 源码
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* See the License for the specific language governing permissions and
* limitations under the License.
package org.apache.spark.executor
import{File, NotSerializableException}
import java.lang.Thread.UncaughtExceptionHandler
import{URI, URL}
import java.nio.ByteBuffer
import java.util.{Locale, Properties}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.collection.mutable.{ArrayBuffer, HashMap, Map, WrappedArray}
import scala.concurrent.duration._
import scala.util.control.NonFatal
import org.slf4j.MDC
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.internal.plugin.PluginContainer
import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.metrics.source.JVMCPUSource
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.scheduler._
import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
import{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
* Spark executor, backed by a threadpool to run tasks.
* This can be used with Mesos, YARN, kubernetes and the standalone scheduler.
* An internal RPC interface is used for communication with the driver,
* except in the case of Mesos fine-grained mode.
private[spark] class Executor(
executorId: String,
executorHostname: String,
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
isLocal: Boolean = false,
uncaughtExceptionHandler: UncaughtExceptionHandler = new SparkUncaughtExceptionHandler,
resources: immutable.Map[String, ResourceInformation])
extends Logging {
logInfo(s"Starting executor ID $executorId on host $executorHostname")
private val executorShutdown = new AtomicBoolean(false)
val stopHookReference = ShutdownHookManager.addShutdownHook(
() => stop()
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
private val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
private val currentArchives: HashMap[String, Long] = new HashMap[String, Long]()
private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
private[executor] val conf = env.conf
// SPARK-40235: updateDependencies() uses a ReentrantLock instead of the `synchronized` keyword
// so that tasks can exit quickly if they are interrupted while waiting on another task to
// finish downloading dependencies.
private val updateDependenciesLock = new ReentrantLock()
// No ip or host:port - just hostname
// must not have port specified.
assert (0 == Utils.parseHostPort(executorHostname)._2)
// Make sure the local hostname we report matches the cluster scheduler's name for this host
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
// Start worker thread pool
// Use UninterruptibleThread to run tasks so that we can allow running code without being
// interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
// will hang forever if some methods are interrupted.
private[executor] val threadPool = {
val threadFactory = new ThreadFactoryBuilder()
.setNameFormat("Executor task launch worker-%d")
.setThreadFactory((r: Runnable) => new UninterruptibleThread(r, "unused"))
private val schemes = conf.get(EXECUTOR_METRICS_FILESYSTEM_SCHEMES)
private val executorSource = new ExecutorSource(threadPool, executorId, schemes)
// Pool used for threads that supervise task killing / cancellation
private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
// For tasks which are in the process of being killed, this map holds the most recently created
// TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't
// a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding
// the integrity of the map's internal state). The purpose of this map is to prevent the creation
// of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to
// track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise
// create. The map key is a task id.
private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]()
val executorMetricsSource =
Some(new ExecutorMetricsSource)
} else {
if (!isLocal) {
env.metricsSystem.registerSource(new JVMCPUSource())
} else {
// This enables the registration of the executor source in local mode.
// The actual registration happens in SparkContext,
// it cannot be done here as the appId is not available yet
Executor.executorSourceLocalModeOnly = executorSource
// Whether to load classes in user jars before those in Spark jars
private val userClassPathFirst = conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
// Whether to monitor killed / interrupted tasks
private val taskReaperEnabled = conf.get(TASK_REAPER_ENABLED)
private val killOnFatalErrorDepth = conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH)
// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
// Set the classloader for serializer
// SPARK-21928. SerializerManager's internal instance of Kryo might get used in netty threads
// for fetching remote cached RDD blocks, so need to make sure it uses the right classloader too.
// Max size of direct result. If task result is bigger than this, we use the block manager
// to send the result back.
private val maxDirectResultSize = Math.min(
private val maxResultSize = conf.get(MAX_RESULT_SIZE)
// Maintains the list of running tasks.
private[executor] val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
// Kill mark TTL in milliseconds - 10 seconds.
private val KILL_MARK_TTL_MS = 10000L
// Kill marks with interruptThread flag, kill reason and timestamp.
// This is to avoid dropping the kill event when killTask() is called before launchTask().
private[executor] val killMarks = new ConcurrentHashMap[Long, (Boolean, String, Long)]
private val killMarkCleanupTask = new Runnable {
override def run(): Unit = {
val oldest = System.currentTimeMillis() - KILL_MARK_TTL_MS
val iter = killMarks.entrySet().iterator()
while (iter.hasNext) {
if ( < oldest) {
// Kill mark cleanup thread executor.
private val killMarkCleanupService =
* When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES`
* times, it should kill itself. The default value is 60. For example, if max failures is 60 and
* heartbeat interval is 10s, then it will try to send heartbeats for up to 600s (10 minutes).
* Whether to drop empty accumulators from heartbeats sent to the driver. Including the empty
* accumulators (that satisfy isZero) can make the size of the heartbeat message very large.
* Interval to send heartbeats, in milliseconds
* Interval to poll for executor metrics, in milliseconds
// True exactly when METRICS_POLLING_INTERVAL_MS is not positive.
// NOTE(review): presumably metrics are then polled on each heartbeat rather than on a
// separate schedule - confirm against ExecutorMetricsPoller.
private val pollOnHeartbeat = !(METRICS_POLLING_INTERVAL_MS > 0)
// Poller for the memory metrics. Visible for testing.
private[executor] val metricsPoller = new ExecutorMetricsPoller(
// Executor for the heartbeat task.
private val heartbeater = new Heartbeater(
() => Executor.this.reportHeartBeat(),
// must be initialized before running startDriverHeartbeat()
private val heartbeatReceiverRef =
RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv)
* Count the failure times of heartbeat. It should only be accessed in the heartbeat thread. Each
* successful heartbeat will reset it to 0.
private var heartbeatFailures = 0
* Flag to prevent launching new tasks while decommissioned. There could be a race condition
* accessing this, but decommissioning is only intended to help not be a hard stop.
private var decommissioned = false
private val appStartTime = conf.getLong("", 0)
// To allow users to distribute plugins and their required files
// specified by --jars, --files and --archives on application submission, those
// jars/files/archives should be downloaded and added to the class loader via
// updateDependencies. This should be done before plugin initialization below
// because executors search plugins from the class loader and initialize them.
private val Seq(initialUserJars, initialUserFiles, initialUserArchives) =
Seq("jar", "file", "archive").map { key =>
conf.getOption(s"$key.urls").map { urls =>
Map(urls.split(",").map(url => (url, appStartTime)): _*)
updateDependencies(initialUserFiles, initialUserJars, initialUserArchives)
// Plugins need to load using a class loader that includes the executor's user classpath.
// Plugins also need to be initialized after the heartbeater started
// to avoid blocking to send heartbeat (see SPARK-32175).
private val plugins: Option[PluginContainer] = Utils.withContextClassLoader(replClassLoader) {
PluginContainer(env, resources.asJava)
/** Number of entries currently held in the `runningTasks` map. */
private[executor] def numRunningTasks: Int = {
  runningTasks.size()
}
/**
 * Flags this executor as decommissioned.
 *
 * The flag is read by `launchTask`, which logs an error when a task is launched after this
 * method has been called.
 */
private[spark] def decommission(): Unit = {
  this.decommissioned = true
}
/**
 * Builds the [[TaskRunner]] for a task description, handing it this executor's plugin
 * container.
 *
 * @param context backend passed to the new runner
 * @param taskDescription description of the task the runner will execute
 */
private[executor] def createTaskRunner(
    context: ExecutorBackend,
    taskDescription: TaskDescription) = {
  new TaskRunner(context, taskDescription, plugins)
}
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val taskId = taskDescription.taskId
val tr = createTaskRunner(context, taskDescription)
runningTasks.put(taskId, tr)
val killMark = killMarks.get(taskId)
if (killMark != null) {
tr.kill(killMark._1, killMark._2)
if (decommissioned) {
log.error(s"Launching a task while in decommissioned state.")
def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
killMarks.put(taskId, (interruptThread, reason, System.currentTimeMillis()))
val taskRunner = runningTasks.get(taskId)
if (taskRunner != null) {
if (taskReaperEnabled) {
val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized {
val shouldCreateReaper = taskReaperForTask.get(taskId) match {
case None => true
case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
if (shouldCreateReaper) {
val taskReaper = new TaskReaper(
taskRunner, interruptThread = interruptThread, reason = reason)
taskReaperForTask(taskId) = taskReaper
} else {
// Execute the TaskReaper from outside of the synchronized block.
} else {
taskRunner.kill(interruptThread = interruptThread, reason = reason)
// Safe to remove kill mark as we got a chance with the TaskRunner.
/**
 * Kills every task currently registered in `runningTasks` by calling `killTask` on each id.
 * This can be called by executor back-ends to kill the
 * tasks instead of taking the JVM down.
 *
 * @param interruptThread whether to interrupt the task thread
 * @param reason kill reason forwarded to each `killTask` call
 */
def killAllTasks(interruptThread: Boolean, reason: String) : Unit = {
  val taskIds = runningTasks.keys().asScala
  taskIds.foreach { taskId =>
    killTask(taskId, interruptThread = interruptThread, reason = reason)
  }
}
def stop(): Unit = {
if (!executorShutdown.getAndSet(true)) {
try {
if (metricsPoller != null) {
} catch {
case NonFatal(e) =>
logWarning("Unable to stop executor metrics poller", e)
try {
if (heartbeater != null) {
} catch {
case NonFatal(e) =>
logWarning("Unable to stop heartbeater", e)
if (threadPool != null) {
if (killMarkCleanupService != null) {
if (replClassLoader != null && plugins != null) {
// Notify plugins that executor is shutting down so they can terminate cleanly
Utils.withContextClassLoader(replClassLoader) {
if (!isLocal) {
/** Returns the total amount of time this JVM process has spent in garbage collection. */
private def computeTotalGcTime(): Long = {
class TaskRunner(
execBackend: ExecutorBackend,
private val taskDescription: TaskDescription,
private val plugins: Option[PluginContainer])
extends Runnable {
val taskId = taskDescription.taskId
val taskName =
val threadName = s"Executor task launch worker for $taskName"
val mdcProperties =
/** If specified, this task has been killed and this option contains the reason. */
@volatile private var reasonIfKilled: Option[String] = None
@volatile private var threadId: Long = -1
def getThreadId: Long = threadId
/** Whether this task has been finished. */
private var finished = false
def isFinished: Boolean = synchronized { finished }
/** How much the JVM process has spent in GC when the task starts to run. */
@volatile var startGCTime: Long = _
* The task to run. This will be set in run() by deserializing the task binary coming
* from the driver. Once it is set, it will never be changed.
@volatile var task: Task[Any] = _
/**
 * Records the kill reason and, if the task has already been deserialized and has not
 * finished, forwards the kill request to it.
 *
 * @param interruptThread passed through to `task.kill`
 * @param reason kill reason; logged, stored in `reasonIfKilled` and passed to `task.kill`
 */
def kill(interruptThread: Boolean, reason: String): Unit = {
  logInfo(s"Executor is trying to kill $taskName, reason: $reason")
  reasonIfKilled = Some(reason)
  // `task` stays null until run() has deserialized it; until then only the reason is recorded.
  val deserialized = task != null
  if (deserialized) {
    this.synchronized {
      // `finished` is read under this TaskRunner's monitor, as in isFinished.
      val stillRunning = !finished
      if (stillRunning) {
        task.kill(interruptThread, reason)
      }
    }
  }
}
* Set the finished flag to true and clear the current thread's interrupt status
private def setTaskFinishedAndClearInterruptStatus(): Unit = synchronized {
this.finished = true
// SPARK-14234 - Reset the interrupted status of the thread to avoid the
// ClosedByInterruptException during execBackend.statusUpdate which causes
// Executor to crash
// Notify any waiting TaskReapers. Generally there will only be one reaper per task but there
// is a rare corner-case where one task can have two reapers in case cancel(interrupt=False)
// is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup:
* Utility function to:
* 1. Report executor runtime and JVM gc time if possible
* 2. Collect accumulator updates
* 3. Set the finished flag to true and clear current thread's interrupt status
private def collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs: Long) = {
// Report executor runtime and JVM gc time
Option(task).foreach(t => {
// SPARK-32898: it's possible that a task is killed when taskStartTimeNs has the initial
// value(=0) still. In this case, the executorRunTime should be considered as 0.
if (taskStartTimeNs > 0) System.nanoTime() - taskStartTimeNs else 0))
t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
// Collect latest accumulator values to report back to the driver
val accums: Seq[AccumulatorV2[_, _]] =
Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty)
val accUpdates = => acc.toInfo(Some(acc.value), None))
(accums, accUpdates)
override def run(): Unit = {
setMDCForTask(taskName, mdcProperties)
threadId = Thread.currentThread.getId
val threadMXBean = ManagementFactory.getThreadMXBean
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTimeNs = System.nanoTime()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
} else 0L
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStartTimeNs: Long = 0
var taskStartCpu: Long = 0
startGCTime = computeTotalGcTime()
var taskStarted: Boolean = false
try {
// Must be set before updateDependencies() is called, in case fetching dependencies
// requires access to properties contained within (e.g. for access control).
taskDescription.addedFiles, taskDescription.addedJars, taskDescription.addedArchives)
task = ser.deserialize[Task[Any]](
taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
task.localProperties =
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
val killReason = reasonIfKilled
if (killReason.isDefined) {
// Throw an exception rather than returning, because returning within a try{} block
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
throw new TaskKilledException(killReason.get)
// The purpose of updating the epoch here is to invalidate executor map output status cache
// in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
// MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
// we don't need to make any special calls here.
if (!isLocal) {
logDebug(s"$taskName's epoch is ${task.epoch}")
metricsPoller.onTaskStart(taskId, task.stageId, task.stageAttemptId)
taskStarted = true
// Run the actual task and measure its runtime.
taskStartTimeNs = System.nanoTime()
taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
} else 0L
var threwException = true
val value = Utils.tryWithSafeFinally {
val res =
taskAttemptId = taskId,
attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem,
cpus = taskDescription.cpus,
resources = taskDescription.resources,
plugins = plugins)
threwException = false
} {
val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
if (freedMemory > 0 && !threwException) {
val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, $taskName"
throw new SparkException(errMsg)
} else {
if (releasedLocks.nonEmpty && !threwException) {
val errMsg =
s"${releasedLocks.size} block locks were not released by $taskName\n" +
releasedLocks.mkString("[", ", ", "]")
throw new SparkException(errMsg)
} else {
task.context.fetchFailed.foreach { fetchFailure =>
// uh-oh. it appears the user code has caught the fetch-failure without throwing any
// other exceptions. It's *possible* this is what the user meant to do (though highly
// unlikely). So we will log an error and keep going.
logError(s"$taskName completed successfully though internally it encountered " +
s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
val taskFinishNs = System.nanoTime()
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
} else 0L
// If the task has been killed, let's fail it.
val resultSer = env.serializer.newInstance()
val beforeSerializationNs = System.nanoTime()
val valueBytes = resultSer.serialize(value)
val afterSerializationNs = System.nanoTime()
// Deserialization happens in two parts: first, we deserialize a Task object, which
// includes the Partition. Second, the task deserializes the RDD and function to be run.
(taskStartTimeNs - deserializeStartTimeNs) + task.executorDeserializeTimeNs))
(taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
// We need to subtract the task's deserialization time to avoid double-counting
(taskFinishNs - taskStartTimeNs) - task.executorDeserializeTimeNs))
(taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
afterSerializationNs - beforeSerializationNs))
// Expose task metrics using the Dropwizard metrics system.
// Update task metrics counters
// Note: accumulator updates must be collected after TaskMetrics is updated
val accumUpdates = task.collectAccumulatorUpdates()
val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId)
// TODO: do not serialize value twice
val directResult = new DirectTaskResult(valueBytes, accumUpdates, metricPeaks)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit()
// directSend = sending directly back to the driver
val serializedResult: ByteBuffer = {
if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(s"Finished $taskName. Result is larger than maxResultSize " +
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
s"dropping it.")
ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
} else if (resultSize > maxDirectResultSize) {
val blockId = TaskResultBlockId(taskId)
new ChunkedByteBuffer(serializedDirectResult.duplicate()),
logInfo(s"Finished $taskName. $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished $taskName. $resultSize bytes result sent to driver")
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
case t: TaskKilledException =>
logInfo(s"Executor killed $taskName, reason: ${t.reason}")
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
// Here and below, put task metric peaks in a WrappedArray to expose them as a Seq
// without requiring a copy.
val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
case _: InterruptedException | NonFatal(_) if
task != null && task.reasonIfKilled.isDefined =>
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
logInfo(s"Executor interrupted and killed $taskName, reason: $killReason")
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
case t: Throwable if hasFetchFailure && !Executor.isFatalError(t, killOnFatalErrorDepth) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
if (!t.isInstanceOf[FetchFailedException]) {
// there was a fetch failure in the task, but some user code wrapped that exception
// and threw something else. Regardless, we treat it as a fetch failure.
val fetchFailedCls = classOf[FetchFailedException].getName
logWarning(s"$taskName encountered a ${fetchFailedCls} and " +
s"failed, but the ${fetchFailedCls} was hidden by another " +
s"exception. Spark is handling this like a fetch failure and ignoring the " +
s"other exception: $t")
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskCommitDeniedReason
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
case t: Throwable if env.isStopped =>
// Log the expected exception after executor.stop without stack traces
// see: SPARK-19147
logError(s"Exception in $taskName: ${t.getMessage}")
case t: Throwable =>
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
// the default uncaught exception handler, which will terminate the Executor.
logError(s"Exception in $taskName", t)
// SPARK-20904: Do not report failure to driver if it happened during shut down. Because
// libraries may set up shutdown hooks that race with running tasks during shutdown,
// spurious failures may occur and can result in improper accounting in the driver (e.g.
// the task failure would not be ignored if the shutdown happened because of preemption,
// instead of an app issue).
if (!ShutdownHookManager.inShutdown()) {
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
val (taskFailureReason, serializedTaskFailureReason) = {
try {
val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
(ef, ser.serialize(ef))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
(ef, ser.serialize(ef))
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
} else {
logInfo("Not reporting error to driver during JVM shutdown.")
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Executor.isFatalError(t, killOnFatalErrorDepth)) {
uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
} finally {
cleanMDCForTask(taskName, mdcProperties)
if (taskStarted) {
// This means the task was successfully deserialized, its stageId and stageAttemptId
// are known, and metricsPoller.onTaskStart was called.
metricsPoller.onTaskCompletion(taskId, task.stageId, task.stageAttemptId)
/** True when the task has been deserialized, has a context, and that context recorded a fetch failure. */
private def hasFetchFailure: Boolean = {
  // Both `task` and its context can still be null here, so each is wrapped before use.
  Option(task).flatMap(t => Option(t.context)).exists(_.fetchFailed.isDefined)
}
/**
 * Puts the given key/value pairs into the SLF4J MDC, followed by "mdc.taskName".
 *
 * @param taskName value stored under "mdc.taskName"
 * @param mdc key/value pairs to put, in order
 */
private def setMDCForTask(taskName: String, mdc: Seq[(String, String)]): Unit = {
  try {
    // The task name goes last so a user-supplied "mdc.taskName" entry cannot override it.
    val entries = mdc :+ ("mdc.taskName" -> taskName)
    entries.foreach { case (name, setting) => MDC.put(name, setting) }
  } catch {
    case _: NoSuchFieldError => logInfo("MDC is not supported.")
  }
}
/**
 * Removes from the SLF4J MDC every key listed in `mdc`, plus the "mdc.taskName" entry that
 * setMDCForTask always adds.
 *
 * @param taskName name of the task being cleaned up (the MDC key itself is fixed)
 * @param mdc key/value pairs whose keys are removed; the values are ignored
 */
private def cleanMDCForTask(taskName: String, mdc: Seq[(String, String)]): Unit = {
  try {
    mdc.foreach { case (key, _) => MDC.remove(key) }
    // setMDCForTask unconditionally puts "mdc.taskName"; remove it too so the name of a
    // completed task does not linger on a reused pool thread.
    MDC.remove("mdc.taskName")
  } catch {
    case _: NoSuchFieldError => logInfo("MDC is not supported.")
  }
}
* Supervises the killing / cancellation of a task by sending the interrupted flag, optionally
* sending a Thread.interrupt(), and monitoring the task until it finishes.
* Spark's current task cancellation / task killing mechanism is "best effort" because some tasks
* may not be interruptible or may not respond to their "killed" flags being set. If a significant
* fraction of a cluster's task slots are occupied by tasks that have been marked as killed but
* remain running then this can lead to a situation where new jobs and tasks are starved of
* resources that are being used by these zombie tasks.
* The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie
* tasks. For backwards-compatibility / backportability this component is disabled by default
* and must be explicitly enabled by setting `spark.task.reaper.enabled=true`.
* A TaskReaper is created for a particular task when that task is killed / cancelled. Typically
* a task will have only one TaskReaper, but it's possible for a task to have up to two reapers
* in case kill is called twice with different values for the `interrupt` parameter.
* Once created, a TaskReaper will run until its supervised task has finished running. If the
* TaskReaper has not been configured to kill the JVM after a timeout (i.e. if
* `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely
* if the supervised task never exits.
private class TaskReaper(
taskRunner: TaskRunner,
val interruptThread: Boolean,
val reason: String)
extends Runnable {
private[this] val taskId: Long = taskRunner.taskId
private[this] val killPollingIntervalMs: Long = conf.get(TASK_REAPER_POLLING_INTERVAL)
private[this] val killTimeoutNs: Long = {
private[this] val takeThreadDump: Boolean = conf.get(TASK_REAPER_THREAD_DUMP)
override def run(): Unit = {
setMDCForTask(taskRunner.taskName, taskRunner.mdcProperties)
val startTimeNs = System.nanoTime()
def elapsedTimeNs = System.nanoTime() - startTimeNs
def timeoutExceeded(): Boolean = killTimeoutNs > 0 && elapsedTimeNs > killTimeoutNs
try {
// Only attempt to kill the task once. If interruptThread = false then a second kill
// attempt would be a no-op and if interruptThread = true then it may not be safe or
// effective to interrupt multiple times:
taskRunner.kill(interruptThread = interruptThread, reason = reason)
// Monitor the killed task until it exits. The synchronization logic here is complicated
// because we don't want to synchronize on the taskRunner while possibly taking a thread
// dump, but we also need to be careful to avoid races between checking whether the task
// has finished and wait()ing for it to finish.
var finished: Boolean = false
while (!finished && !timeoutExceeded()) {
taskRunner.synchronized {
// We need to synchronize on the TaskRunner while checking whether the task has
// finished in order to avoid a race where the task is marked as finished right after
// we check and before we call wait().
if (taskRunner.isFinished) {
finished = true
} else {
if (taskRunner.isFinished) {
finished = true
} else {
val elapsedTimeMs = TimeUnit.NANOSECONDS.toMillis(elapsedTimeNs)
logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms")
if (takeThreadDump) {
try {
Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread =>
if (thread.threadName == taskRunner.threadName) {
logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}")
} catch {
case NonFatal(e) =>
logWarning("Exception thrown while obtaining thread dump: ", e)
if (!taskRunner.isFinished && timeoutExceeded()) {
val killTimeoutMs = TimeUnit.NANOSECONDS.toMillis(killTimeoutNs)
if (isLocal) {
logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " +
"not killing JVM because we are running in local mode.")
} else {
// In non-local-mode, the exception thrown here will bubble up to the uncaught exception
// handler and cause the executor JVM to exit.
throw new SparkException(
s"Killing executor JVM because killed task $taskId could not be stopped within " +
s"$killTimeoutMs ms.")
} finally {
cleanMDCForTask(taskRunner.taskName, taskRunner.mdcProperties)
// Clean up entries in the taskReaperForTask map.
taskReaperForTask.synchronized {
taskReaperForTask.get(taskId).foreach { taskReaperInMap =>
if (taskReaperInMap eq this) {
} else {
// This must have been a TaskReaper where interruptThread == false where a subsequent
// killTask() call for the same task had interruptThread == true and overwrote the
// map entry.
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
private def createClassLoader(): MutableURLClassLoader = {
// Bootstrap the list of jars with the user class path.
val now = System.currentTimeMillis()
userClassPath.foreach { url =>
currentJars(url.getPath().split("/").last) = now
val currentLoader = Utils.getContextOrSparkClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
val urls = userClassPath.toArray ++ { uri =>
new File(uri.split("/").last).toURI.toURL
logInfo(s"Starting executor with user classpath (userClassPathFirst = $userClassPathFirst): " +
urls.mkString("'", ",", "'"))
if (userClassPathFirst) {
new ChildFirstURLClassLoader(urls, currentLoader)
} else {
new MutableURLClassLoader(urls, currentLoader)
/**
 * If the REPL is in use, add another ClassLoader that will read
 * new classes defined by the REPL as the user types code.
 *
 * The REPL is considered in use when `spark.repl.class.uri` is set in the configuration.
 *
 * @param parent the class loader handed to the REPL class loader's constructor
 * @return a reflectively created `org.apache.spark.repl.ExecutorClassLoader`, or `parent`
 *         itself when no REPL class URI is configured
 */
private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
  val classUri = conf.get("spark.repl.class.uri", null)
  if (classUri != null) {
    logInfo("Using REPL class URI: " + classUri)
    try {
      // Box explicitly: Constructor.newInstance takes Object arguments, while the constructor
      // itself is looked up with the primitive `classOf[Boolean]`.
      val _userClassPathFirst: java.lang.Boolean = userClassPathFirst
      // Loaded by name, so the REPL classes are only needed at runtime.
      val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader")
        .asInstanceOf[Class[_ <: ClassLoader]]
      val constructor = klass.getConstructor(classOf[SparkConf], classOf[SparkEnv],
        classOf[String], classOf[ClassLoader], classOf[Boolean])
      constructor.newInstance(conf, env, classUri, parent, _userClassPathFirst)
    } catch {
      case _: ClassNotFoundException =>
        logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
        // NOTE(review): the statements after this log call were missing in the reviewed text.
        // The branch must either yield a ClassLoader or terminate; exiting the JVM is assumed
        // here -- confirm against the intended behaviour.
        System.exit(1)
        null
    }
  } else {
    // No REPL class URI configured: tasks keep using the loader they were given.
    parent
  }
}
/**
 * Download any missing dependencies if we receive a new set of files and JARs from the
 * SparkContext. Also adds any new JARs we fetched to the class loader.
 * Visible for testing.
 */
// Fetches files, archives and jars whose timestamp is newer than the one recorded in
// currentFiles / currentArchives / currentJars, and records the new timestamp afterwards.
// Each map argument goes from a resource name (URI string) to its timestamp.
// NOTE(review): this block appears to have lost several statements (see the notes below);
// confirm against the complete file before relying on it.
private[executor] def updateDependencies(
newFiles: Map[String, Long],
newJars: Map[String, Long],
newArchives: Map[String, Long],
// The two latches are test hooks; their use is not visible in this block.
testStartLatch: Option[CountDownLatch] = None,
testEndLatch: Option[CountDownLatch] = None): Unit = {
// Lazy: the Hadoop configuration is only built if something actually has to be fetched.
lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
// NOTE(review): the try/finally below has no visible finally body; presumably a lock is taken
// before the try and released in the finally -- TODO confirm.
try {
// For testing, so we can simulate a slow file download:
// NOTE(review): the statement this comment announces (presumably on testStartLatch) is absent.
// Fetch missing dependencies
// Plain files: fetched into the SparkFiles root directory when unknown (-1) or outdated.
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
// Archives: downloaded without the URI fragment into a temp dir, then unpacked.
for ((name, timestamp) <- newArchives if currentArchives.getOrElse(name, -1L) < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
val sourceURI = new URI(name)
// The fragment is stripped for the download; it is used below as the destination name.
val uriToDownload = UriBuilder.fromUri(sourceURI).fragment(null).build()
val source = Utils.fetchFile(uriToDownload.toString, Utils.createTempDir(), conf,
hadoopConf, timestamp, useCache = !isLocal, shouldUntar = false)
// Destination is named after the URI fragment when present, else after the fetched file.
// NOTE(review): as written this is a relative path; a parent-directory argument may be missing.
val dest = new File(
if (sourceURI.getFragment != null) sourceURI.getFragment else source.getName)
// NOTE(review): the next line is a bare string; an opening `logInfo(` line looks missing.
s"Unpacking an archive $name from ${source.getAbsolutePath} to ${dest.getAbsolutePath}")
Utils.unpack(source, dest)
currentArchives(name) = timestamp
// Jars: fetched like plain files and, per the comment below, added to the class loader.
for ((name, timestamp) <- newJars) {
// File name of the jar: last segment of the URI path.
val localName = new URI(name).getPath.split("/").last
// NOTE(review): `get` yields an Option, yet the result is compared with a Long below; a
// fallback for unknown jars appears to be missing here.
val currentTimeStamp = currentJars.get(name)
if (currentTimeStamp < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
hadoopConf, timestamp, useCache = !isLocal)
currentJars(name) = timestamp
// Add it to our class loader
val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
// Skip URLs the class loader already knows.
if (!urlClassLoader.getURLs().contains(url)) {
logInfo(s"Adding $url to class loader")
// NOTE(review): only the log call is visible; the call that adds `url` is absent.
// For testing, so we can simulate a slow file download:
// NOTE(review): the statement this comment announces (presumably on testEndLatch) is absent.
} finally {
/**
 * Reports heartbeat and metrics for active tasks to the driver.
 *
 * Collects accumulator updates per running task, sends them in a `Heartbeat` message and
 * tracks consecutive failures in `heartbeatFailures` (reset to 0 after a successful exchange).
 * NOTE(review): this block appears to have lost several statements (see the notes below);
 * confirm against the complete file before relying on it.
 */
private def reportHeartBeat(): Unit = {
// list of (task id, accumUpdates) to send back to the driver
val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulatorV2[_, _]])]()
// Total GC time so far; each task reports the delta since its own startGCTime.
val curGCTime = computeTotalGcTime()
// NOTE(review): the body of this `if` is not visible; presumably metrics are polled here
// before the executor updates are read -- TODO confirm.
if (pollOnHeartbeat) {
val executorUpdates = metricsPoller.getExecutorUpdates()
for (taskRunner <- runningTasks.values().asScala) {
// Runners whose task is still null are skipped.
if (taskRunner.task != null) {
taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
// NOTE(review): the right-hand side is missing; the dangling `} else {` suggests a
// conditional choosing which accumulators to report.
val accumulatorsToReport =
} else {
accumUpdates += ((taskRunner.taskId, accumulatorsToReport))
// NOTE(review): the argument list is cut after blockManagerId.
val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId,
try {
// Synchronous ask; NOTE(review): its arguments (message, timeout) are not visible.
val response = heartbeatReceiverRef.askSync[HeartbeatResponse](
// Only honoured while the executor is not shutting down.
if (!executorShutdown.get && response.reregisterBlockManager) {
logInfo("Told to re-register on heartbeat")
// NOTE(review): only the log call is visible; the re-registration call itself is absent.
// A successful exchange clears the failure counter.
heartbeatFailures = 0
} catch {
// Non-fatal errors are counted instead of propagated.
case NonFatal(e) =>
logWarning("Issue communicating with driver in heartbeater", e)
heartbeatFailures += 1
if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) {
logError(s"Exit as unable to send heartbeats to driver " +
s"more than $HEARTBEAT_MAX_FAILURES times")
// NOTE(review): the message says the executor exits, but no exit call is visible here.
private[spark] object Executor {
  // This is reserved for internal use by components that need to read task properties before a
  // task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be
  // used instead.
  val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties]

  // Used to store executorSource, for local mode only
  var executorSourceLocalModeOnly: ExecutorSource = null

  /**
   * Whether a `Throwable` thrown from a task is a fatal error. We will use this to decide whether
   * to kill the executor.
   *
   * @param t the throwable to inspect; its cause chain is followed up to `depthToCheck` levels
   * @param depthToCheck The max depth of the exception chain we should search for a fatal error. 0
   *                     means not checking any fatal error (in other words, return false), 1 means
   *                     checking only the exception but not the cause, and so on. This is to avoid
   *                     `StackOverflowError` when hitting a cycle in the exception chain.
   * @return true if `t`, or a cause within the allowed depth, is fatal per `Utils.isFatalError`
   */
  def isFatalError(t: Throwable, depthToCheck: Int): Boolean = {
    if (depthToCheck <= 0) {
      // Depth budget exhausted (or checking disabled): report "not fatal", as documented above.
      false
    } else {
      t match {
        // Matched first, so a SparkOutOfMemoryError is never reported as fatal and its cause
        // chain is not searched.
        case _: SparkOutOfMemoryError => false
        case e if Utils.isFatalError(e) => true
        // Not fatal itself: look one level deeper, spending one unit of depth.
        case e if e.getCause != null => isFatalError(e.getCause, depthToCheck - 1)
        case _ => false
      }
    }
  }
}
