spark AccumulatorV2 源码

  • 2022-10-20
spark AccumulatorV2 代码


/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import java.{lang => jl}
import java.io.ObjectInputStream
import java.util.ArrayList
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.util.AccumulatorContext.internOption

/**
 * Immutable metadata attached to an [[AccumulatorV2]] when it is registered.
 *
 * @param id identifier of the accumulator; `AccumulatorV2.register` obtains it from
 *           `AccumulatorContext.newId()`
 * @param name optional name of the accumulator
 * @param countFailedValues whether values from failed tasks should be accumulated
 */
private[spark] case class AccumulatorMetadata(
    id: Long,
    name: Option[String],
    countFailedValues: Boolean) extends Serializable

/**
 * The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of
 * type `OUT`.
 *
 * `OUT` should be a type that can be read atomically (e.g., Int, Long), or thread-safely
 * (e.g., synchronized collections) because it will be read from other threads.
 */
abstract class AccumulatorV2[IN, OUT] extends Serializable {
  // Assigned by `register` on the driver, or copied onto the instance in `writeReplace`.
  private[spark] var metadata: AccumulatorMetadata = _
  // Flipped by `readObject` every time the accumulator crosses a serialization boundary.
  private[this] var atDriverSide = true

  /**
   * Assigns fresh metadata to this accumulator and registers it with [[AccumulatorContext]]
   * and, if the context has one, with the context cleaner.
   *
   * @param sc the context whose cleaner should track this accumulator
   * @param name optional name of the accumulator
   * @param countFailedValues whether values from failed tasks should be accumulated
   * @throws IllegalStateException if this accumulator already has metadata assigned
   */
  private[spark] def register(
      sc: SparkContext,
      name: Option[String] = None,
      countFailedValues: Boolean = false): Unit = {
    if (this.metadata != null) {
      throw new IllegalStateException("Cannot register an Accumulator twice.")
    }
    this.metadata = AccumulatorMetadata(AccumulatorContext.newId(), name, countFailedValues)
    AccumulatorContext.register(this)
    sc.cleaner.foreach(_.registerAccumulatorForCleanup(this))
  }

  /**
   * Returns true if this accumulator has been registered.
   *
   * @note All accumulators must be registered before use, or it will throw exception.
   */
  final def isRegistered: Boolean =
    metadata != null && AccumulatorContext.get(metadata.id).isDefined

  /** Fails fast with an `IllegalStateException` when no metadata has been assigned yet. */
  private def assertMetadataNotNull(): Unit = {
    if (metadata == null) {
      throw new IllegalStateException("The metadata of this accumulator has not been assigned yet.")
    }
  }

  /**
   * Returns the id of this accumulator, can only be called after registration.
   */
  final def id: Long = {
    assertMetadataNotNull()
    metadata.id
  }

  /**
   * Returns the name of this accumulator, can only be called after registration.
   */
  final def name: Option[String] = {
    assertMetadataNotNull()

    if (atDriverSide) {
      // A driver-side copy may carry no name of its own (see `writeReplace`), so fall back to
      // the name of the original accumulator registered under the same id.
      metadata.name.orElse(AccumulatorContext.get(id).flatMap(_.metadata.name))
    } else {
      metadata.name
    }
  }

  /**
   * Whether to accumulate values from failed tasks. This is set to true for system and time
   * metrics like serialization time or bytes spilled, and false for things with absolute values
   * like number of input rows.  This should be used for internal metrics only.
   */
  private[spark] final def countFailedValues: Boolean = {
    assertMetadataNotNull()
    metadata.countFailedValues
  }

  /**
   * Creates an [[AccumulableInfo]] representation of this [[AccumulatorV2]] with the provided
   * values.
   */
  private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
    val isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))
    AccumulableInfo(id, name, internOption(update), internOption(value), isInternal,
      countFailedValues)
  }

  final private[spark] def isAtDriverSide: Boolean = atDriverSide

  /**
   * Returns if this accumulator is zero value or not. e.g. for a counter accumulator, 0 is zero
   * value; for a list accumulator, Nil is zero value.
   */
  def isZero: Boolean

  /**
   * Creates a new copy of this accumulator, which is zero value. i.e. call `isZero` on the copy
   * must return true.
   */
  def copyAndReset(): AccumulatorV2[IN, OUT] = {
    val copyAcc = copy()
    copyAcc.reset()
    copyAcc
  }

  /**
   * Creates a new copy of this accumulator.
   */
  def copy(): AccumulatorV2[IN, OUT]

  /**
   * Resets this accumulator, which is zero value. i.e. call `isZero` must
   * return true.
   */
  def reset(): Unit

  /**
   * Takes the inputs and accumulates.
   */
  def add(v: IN): Unit

  /**
   * Merges another same-type accumulator into this one and update its state, i.e. this should be
   * merge-in-place.
   */
  def merge(other: AccumulatorV2[IN, OUT]): Unit

  /**
   * Defines the current value of this accumulator
   */
  def value: OUT

  // Serialize the buffer of this accumulator before sending back this accumulator to the driver.
  // By default this method does nothing.
  protected def withBufferSerialized(): AccumulatorV2[IN, OUT] = this

  // Called by Java when serializing an object.
  // On the driver a zero-valued copy carrying this accumulator's metadata is serialized in place
  // of `this`; elsewhere the accumulator itself (after `withBufferSerialized`) is sent.
  final protected def writeReplace(): Any = {
    if (atDriverSide) {
      if (!isRegistered) {
        throw new UnsupportedOperationException(
          "Accumulator must be registered before send to executor")
      }
      val copyAcc = copyAndReset()
      assert(copyAcc.isZero, "copyAndReset must return a zero value copy")
      val isInternalAcc = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))
      if (isInternalAcc) {
        // Do not serialize the name of internal accumulator and send it to executor.
        copyAcc.metadata = metadata.copy(name = None)
      } else {
        // For non-internal accumulators, we still need to send the name because users may need to
        // access the accumulator name at executor side, or they may keep the accumulators sent from
        // executors and access the name when the registered accumulator is already garbage
        // collected(e.g. SQLMetrics).
        copyAcc.metadata = metadata
      }
      copyAcc
    } else {
      withBufferSerialized()
    }
  }

  // Called by Java when deserializing an object.
  // Each deserialization flips `atDriverSide`, so the flag tracks which side of the
  // driver/executor boundary this instance currently lives on.
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    if (atDriverSide) {
      atDriverSide = false

      // Automatically register the accumulator when it is deserialized with the task closure.
      // This is for external accumulators and internal ones that do not represent task level
      // metrics, e.g. internal SQL metrics, which are per-operator.
      val taskContext = TaskContext.get()
      if (taskContext != null) {
        taskContext.registerAccumulator(this)
      }
    } else {
      atDriverSide = true
    }
  }

  /** Describes the accumulator by class, id, name and value, or marks it as un-registered. */
  override def toString: String = {
    // getClass.getSimpleName can cause Malformed class name error,
    // call safer `Utils.getSimpleName` instead
    if (metadata == null) {
      "Un-registered Accumulator: " + Utils.getSimpleName(getClass)
    } else {
      Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)"
    }
  }
}
/**
 * An internal class used to track accumulators by Spark itself.
 */
private[spark] object AccumulatorContext extends Logging {

  /**
   * This global map holds the original accumulator objects that are created on the driver.
   * It keeps weak references to these objects so that accumulators can be garbage-collected
   * once the RDDs and user-code that reference them are cleaned up.
   * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
   */
  private val originals = new ConcurrentHashMap[Long, jl.ref.WeakReference[AccumulatorV2[_, _]]]

  private[this] val nextId = new AtomicLong(0L)

  // Shared instances handed out by `internOption` for the values -1 and 0.
  private[this] val someOfMinusOne = Some(-1L)
  private[this] val someOfZero = Some(0L)

  /**
   * Returns a globally unique ID for a new [[AccumulatorV2]].
   * Note: Once you copy the [[AccumulatorV2]] the ID is no longer unique.
   */
  def newId(): Long = nextId.getAndIncrement

  /** Returns the number of accumulators registered. Used in testing. */
  def numAccums: Int = originals.size

  /**
   * Registers an [[AccumulatorV2]] created on the driver such that it can be used on the executors.
   *
   * All accumulators registered here can later be used as a container for accumulating partial
   * values across multiple tasks. This is what `org.apache.spark.scheduler.DAGScheduler` does.
   * Note: if an accumulator is registered here, it should also be registered with the active
   * context cleaner for cleanup so as to avoid memory leaks.
   *
   * If an [[AccumulatorV2]] with the same ID was already registered, this does nothing instead
   * of overwriting it. We will never register same accumulator twice, this is just a sanity check.
   */
  def register(a: AccumulatorV2[_, _]): Unit = {
    originals.putIfAbsent(a.id, new jl.ref.WeakReference[AccumulatorV2[_, _]](a))
  }

  /**
   * Unregisters the [[AccumulatorV2]] with the given ID, if any.
   */
  def remove(id: Long): Unit = {
    originals.remove(id)
  }

  /**
   * Returns the [[AccumulatorV2]] registered with the given ID, if any.
   */
  def get(id: Long): Option[AccumulatorV2[_, _]] = {
    val ref = originals.get(id)
    if (ref eq null) {
      None
    } else {
      // Since we are storing weak references, warn when the underlying data is not valid.
      val acc = ref.get
      if (acc eq null) {
        logWarning(s"Attempted to access garbage collected accumulator $id")
      }
      // `Option(...)` maps a cleared weak reference (null) to None.
      Option(acc)
    }
  }

  /**
   * Clears all registered [[AccumulatorV2]]s. For testing only.
   */
  def clear(): Unit = {
    originals.clear()
  }

  /**
   * Naive way to reduce the duplicate Some objects for values 0 and -1
   * TODO: Eventually if this spreads out to more values then using
   * Guava's weak interner would be a better solution.
   */
  def internOption(value: Option[Any]): Option[Any] = {
    value match {
      case Some(0L) => someOfZero
      case Some(-1L) => someOfMinusOne
      case _ => value
    }
  }

  // Identifier for distinguishing SQL metrics from other accumulators
  private[spark] val SQL_ACCUM_IDENTIFIER = "sql"
}

/**
 * An [[AccumulatorV2 accumulator]] for computing sum, count, and average of 64-bit integers.
 *
 * @since 2.0.0
 */
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
  private var _sum = 0L
  private var _count = 0L

  /**
   * Returns false if this accumulator has had any values added to it or the sum is non-zero.
   * @since 2.0.0
   */
  override def isZero: Boolean = _sum == 0L && _count == 0

  /** Creates a new accumulator holding the same sum and count as this one. */
  override def copy(): LongAccumulator = {
    val newAcc = new LongAccumulator
    newAcc._count = this._count
    newAcc._sum = this._sum
    newAcc
  }

  /** Sets both the sum and the count back to zero. */
  override def reset(): Unit = {
    _sum = 0L
    _count = 0L
  }

  /**
   * Adds v to the accumulator, i.e. increment sum by v and count by 1.
   * @since 2.0.0
   */
  override def add(v: jl.Long): Unit = {
    _sum += v
    _count += 1
  }

  /**
   * Adds v to the accumulator, i.e. increment sum by v and count by 1.
   * @since 2.0.0
   */
  def add(v: Long): Unit = {
    _sum += v
    _count += 1
  }

  /**
   * Returns the number of elements added to the accumulator.
   * @since 2.0.0
   */
  def count: Long = _count

  /**
   * Returns the sum of elements added to the accumulator.
   * @since 2.0.0
   */
  def sum: Long = _sum

  /**
   * Returns the average of elements added to the accumulator.
   * @since 2.0.0
   */
  def avg: Double = _sum.toDouble / _count

  /**
   * Adds the sum and count of another [[LongAccumulator]] into this one.
   *
   * @throws UnsupportedOperationException if `other` is not a [[LongAccumulator]]
   */
  override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
    case o: LongAccumulator =>
      _sum += o.sum
      _count += o.count
    case _ =>
      throw new UnsupportedOperationException(
        s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  // Overwrites the sum only; the count is left untouched.
  private[spark] def setValue(newValue: Long): Unit = _sum = newValue

  /** The current value of this accumulator, i.e. the sum. */
  override def value: jl.Long = _sum
}

/**
 * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for double precision
 * floating numbers.
 *
 * @since 2.0.0
 */
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
  private var _sum = 0.0
  private var _count = 0L

  /**
   * Returns false if this accumulator has had any values added to it or the sum is non-zero.
   */
  override def isZero: Boolean = _sum == 0.0 && _count == 0

  /** Creates a new accumulator holding the same sum and count as this one. */
  override def copy(): DoubleAccumulator = {
    val newAcc = new DoubleAccumulator
    newAcc._count = this._count
    newAcc._sum = this._sum
    newAcc
  }

  /** Sets both the sum and the count back to zero. */
  override def reset(): Unit = {
    _sum = 0.0
    _count = 0L
  }

  /**
   * Adds v to the accumulator, i.e. increment sum by v and count by 1.
   * @since 2.0.0
   */
  override def add(v: jl.Double): Unit = {
    _sum += v
    _count += 1
  }

  /**
   * Adds v to the accumulator, i.e. increment sum by v and count by 1.
   * @since 2.0.0
   */
  def add(v: Double): Unit = {
    _sum += v
    _count += 1
  }

  /**
   * Returns the number of elements added to the accumulator.
   * @since 2.0.0
   */
  def count: Long = _count

  /**
   * Returns the sum of elements added to the accumulator.
   * @since 2.0.0
   */
  def sum: Double = _sum

  /**
   * Returns the average of elements added to the accumulator.
   * @since 2.0.0
   */
  def avg: Double = _sum / _count

  /**
   * Adds the sum and count of another [[DoubleAccumulator]] into this one.
   *
   * @throws UnsupportedOperationException if `other` is not a [[DoubleAccumulator]]
   */
  override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
    case o: DoubleAccumulator =>
      _sum += o.sum
      _count += o.count
    case _ =>
      throw new UnsupportedOperationException(
        s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  // Overwrites the sum only; the count is left untouched.
  private[spark] def setValue(newValue: Double): Unit = _sum = newValue

  /** The current value of this accumulator, i.e. the sum. */
  override def value: jl.Double = _sum
}

/**
 * An [[AccumulatorV2 accumulator]] for collecting a list of elements.
 *
 * @since 2.0.0
 */
class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
  // Backing list; null until first use and after `reset()`.
  private var _list: java.util.List[T] = _

  // Returns the backing list, allocating it first if it does not exist yet.
  // Every call site in this class invokes it while holding this instance's monitor.
  private def getOrCreate: java.util.List[T] = {
    _list = Option(_list).getOrElse(new java.util.ArrayList[T]())
    _list
  }

  /**
   * Returns false if this accumulator instance has any values in it.
   */
  override def isZero: Boolean = this.synchronized(getOrCreate.isEmpty)

  /** Returns a fresh, empty accumulator without copying the current elements. */
  override def copyAndReset(): CollectionAccumulator[T] = new CollectionAccumulator

  /** Creates a new accumulator containing the elements currently held by this one. */
  override def copy(): CollectionAccumulator[T] = {
    val newAcc = new CollectionAccumulator[T]
    this.synchronized {
      newAcc.getOrCreate.addAll(getOrCreate)
    }
    newAcc
  }

  /** Drops the backing list so that the accumulator is empty again. */
  override def reset(): Unit = this.synchronized {
    _list = null
  }

  /** Appends `v` to the collected elements. */
  override def add(v: T): Unit = this.synchronized(getOrCreate.add(v))

  /**
   * Appends all elements of another [[CollectionAccumulator]] to this one.
   *
   * @throws UnsupportedOperationException if `other` is not a [[CollectionAccumulator]]
   */
  override def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit = other match {
    case o: CollectionAccumulator[T] => this.synchronized(getOrCreate.addAll(o.value))
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  /** Returns an unmodifiable snapshot (a copy) of the elements collected so far. */
  override def value: java.util.List[T] = this.synchronized {
    java.util.Collections.unmodifiableList(new ArrayList[T](getOrCreate))
  }

  // Replaces the current contents with the elements of `newValue`.
  private[spark] def setValue(newValue: java.util.List[T]): Unit = this.synchronized {
    _list = null
    getOrCreate.addAll(newValue)
  }
}


