Spark MapOutputTracker Source Code

  • 2022-10-20

Spark MapOutputTracker code

File path: /core/src/main/scala/org/apache/spark/MapOutputTracker.scala
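
A rough sketch of how the two halves of this file are used (illustrative only, not part of
MapOutputTracker.scala; these classes are private[spark], so only Spark-internal code such as the
DAGScheduler and the shuffle readers calls them, typically via SparkEnv.get.mapOutputTracker):

    // Driver side: the scheduler registers a shuffle and the outputs of finished map tasks.
    val master = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
    master.registerShuffle(shuffleId = 0, numMaps = 2, numReduces = 4)
    // master.registerMapOutput(0, mapIndex = 0, status)  // status produced by a ShuffleMapTask

    // Executor side: a reduce task asks which blocks to fetch for partition range [0, 1).
    val sizes = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
      shuffleId = 0, startMapIndex = 0, endMapIndex = Int.MaxValue,
      startPartition = 0, endPartition = 1)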

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.io.{ByteArrayInputStream, InputStream, IOException, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.locks.ReentrantReadWriteLock

import scala.collection.JavaConverters._
import scala.collection.mutable.{HashMap, ListBuffer, Map}
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream}
import org.roaringbitmap.RoaringBitmap

import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus}
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
import org.apache.spark.util._
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

/**
 * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single
 * ShuffleMapStage.
 *
 * This class maintains a mapping from map index to `MapStatus`. It also maintains a cache of
 * serialized map statuses in order to speed up tasks' requests for map output statuses.
 *
 * All public methods of this class are thread-safe.
 */
private class ShuffleStatus(
    numPartitions: Int,
    numReducers: Int = -1) extends Logging {

  private val (readLock, writeLock) = {
    val lock = new ReentrantReadWriteLock()
    (lock.readLock(), lock.writeLock())
  }

  // All accesses to the following state must be guarded with `withReadLock` or `withWriteLock`.
  private def withReadLock[B](fn: => B): B = {
    readLock.lock()
    try {
      fn
    } finally {
      readLock.unlock()
    }
  }

  private def withWriteLock[B](fn: => B): B = {
    writeLock.lock()
    try {
      fn
    } finally {
      writeLock.unlock()
    }
  }
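
  // Note on the convention used below: read-only accessors such as numAvailableMapOutputs or
  // findMissingPartitions take withReadLock, while anything that mutates mapStatuses,
  // mergeStatuses or the serialized caches takes withWriteLock and invalidates the corresponding
  // cache when the underlying statuses change.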

  /**
   * MapStatus for each partition. The index of the array is the map partition id.
   * Each value in the array is the MapStatus for a partition, or null if the partition
   * is not available. Even though in theory a task may run multiple times (due to speculation,
   * stage retries, etc.), in practice the likelihood of a map output being available at multiple
   * locations is so small that we choose to ignore that case and store only a single location
   * for each output.
   */
  // Exposed for testing
  val mapStatuses = new Array[MapStatus](numPartitions)

  /**
   * Keep the previous deleted MapStatus for recovery.
   */
  val mapStatusesDeleted = new Array[MapStatus](numPartitions)

  /**
   * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the
   * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for
   * a shuffle partition, or null if not available. When push-based shuffle is enabled, this array
   * provides a reducer oriented view of the shuffle status specifically for the results of
   * merging shuffle partition blocks into per-partition merged shuffle files.
   */
  val mergeStatuses = if (numReducers > 0) {
    new Array[MergeStatus](numReducers)
  } else {
    Array.empty[MergeStatus]
  }

  /**
   * The cached result of serializing the map statuses array. This cache is lazily populated when
   * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed.
   */
  private[this] var cachedSerializedMapStatus: Array[Byte] = _

  /**
   * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]]
   * serializes the map statuses array it may detect that the result is too large to send in a
   * single RPC, in which case it places the serialized array into a broadcast variable and then
   * sends a serialized broadcast variable instead. This variable holds a reference to that
   * broadcast variable in order to keep it from being garbage collected and to allow for it to be
   * explicitly destroyed later on when the ShuffleMapStage is garbage-collected.
   */
  private[spark] var cachedSerializedBroadcast: Broadcast[Array[Array[Byte]]] = _

  /**
   * Similar to cachedSerializedMapStatus and cachedSerializedBroadcast, but for MergeStatus.
   */
  private[this] var cachedSerializedMergeStatus: Array[Byte] = _

  private[this] var cachedSerializedBroadcastMergeStatus: Broadcast[Array[Array[Byte]]] = _

  /**
   * Counter tracking the number of partitions that have output. This is a performance optimization
   * to avoid having to count the number of non-null entries in the `mapStatuses` array and should
   * be equivalent to `mapStatuses.count(_ ne null)`.
   */
  private[this] var _numAvailableMapOutputs: Int = 0

  /**
   * Counter tracking the number of MergeStatus results received so far from the shuffle services.
   */
  private[this] var _numAvailableMergeResults: Int = 0

  private[this] var shufflePushMergerLocations: Seq[BlockManagerId] = Seq.empty

  /**
   * Register a map output. If there is already a registered location for the map output then it
   * will be replaced by the new location.
   */
  def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
    if (mapStatuses(mapIndex) == null) {
      _numAvailableMapOutputs += 1
      invalidateSerializedMapOutputStatusCache()
    }
    mapStatuses(mapIndex) = status
  }

  /**
   * Update the map output location (e.g. during migration).
   */
  def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock {
    try {
      val mapStatusOpt = mapStatuses.find(x => x != null && x.mapId == mapId)
      mapStatusOpt match {
        case Some(mapStatus) =>
          logInfo(s"Updating map output for ${mapId} to ${bmAddress}")
          mapStatus.updateLocation(bmAddress)
          invalidateSerializedMapOutputStatusCache()
        case None =>
          val index = mapStatusesDeleted.indexWhere(x => x != null && x.mapId == mapId)
          if (index >= 0 && mapStatuses(index) == null) {
            val mapStatus = mapStatusesDeleted(index)
            mapStatus.updateLocation(bmAddress)
            mapStatuses(index) = mapStatus
            _numAvailableMapOutputs += 1
            invalidateSerializedMapOutputStatusCache()
            mapStatusesDeleted(index) = null
            logInfo(s"Recover ${mapStatus.mapId} ${mapStatus.location}")
          } else {
            logWarning(s"Asked to update map output ${mapId} for untracked map status.")
          }
      }
    } catch {
      case e: java.lang.NullPointerException =>
        logWarning(s"Unable to update map output for ${mapId}, status removed in-flight")
    }
  }

  /**
   * Remove the map output which was served by the specified block manager.
   * This is a no-op if there is no registered map output or if the registered output is from a
   * different block manager.
   */
  def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
    logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
    if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
      _numAvailableMapOutputs -= 1
      mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex)
      mapStatuses(mapIndex) = null
      invalidateSerializedMapOutputStatusCache()
    }
  }

  /**
   * Register a merge result.
   */
  def addMergeResult(reduceId: Int, status: MergeStatus): Unit = withWriteLock {
    if (mergeStatuses(reduceId) != status) {
      _numAvailableMergeResults += 1
      invalidateSerializedMergeOutputStatusCache()
    }
    mergeStatuses(reduceId) = status
  }

  def registerShuffleMergerLocations(shuffleMergers: Seq[BlockManagerId]): Unit = withWriteLock {
    if (shufflePushMergerLocations.isEmpty) {
      shufflePushMergerLocations = shuffleMergers
    }
  }

  def removeShuffleMergerLocations(): Unit = withWriteLock {
    shufflePushMergerLocations = Nil
  }

  // TODO support updateMergeResult for similar use cases as updateMapOutput

  /**
   * Remove the merge result which was served by the specified block manager.
   */
  def removeMergeResult(reduceId: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
    if (mergeStatuses(reduceId) != null && mergeStatuses(reduceId).location == bmAddress) {
      _numAvailableMergeResults -= 1
      mergeStatuses(reduceId) = null
      invalidateSerializedMergeOutputStatusCache()
    }
  }

  /**
   * Removes all shuffle outputs associated with this host. Note that this will also remove
   * outputs which are served by an external shuffle server (if one exists).
   */
  def removeOutputsOnHost(host: String): Unit = withWriteLock {
    logDebug(s"Removing outputs for host ${host}")
    removeOutputsByFilter(x => x.host == host)
    removeMergeResultsByFilter(x => x.host == host)
  }

  /**
   * Removes all map outputs associated with the specified executor. Note that this will also
   * remove outputs which are served by an external shuffle server (if one exists), as they are
   * still registered with that execId.
   */
  def removeOutputsOnExecutor(execId: String): Unit = withWriteLock {
    logDebug(s"Removing outputs for execId ${execId}")
    removeOutputsByFilter(x => x.executorId == execId)
  }

  /**
   * Removes all shuffle outputs which satisfy the filter. Note that this will also
   * remove outputs which are served by an external shuffle server (if one exists).
   */
  def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
    for (mapIndex <- mapStatuses.indices) {
      if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
        _numAvailableMapOutputs -= 1
        mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex)
        mapStatuses(mapIndex) = null
        invalidateSerializedMapOutputStatusCache()
      }
    }
  }

  /**
   * Removes all shuffle merge results which satisfy the filter.
   */
  def removeMergeResultsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
    for (reduceId <- mergeStatuses.indices) {
      if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) {
        _numAvailableMergeResults -= 1
        mergeStatuses(reduceId) = null
        invalidateSerializedMergeOutputStatusCache()
      }
    }
  }

  /**
   * Number of partitions that have shuffle map outputs.
   */
  def numAvailableMapOutputs: Int = withReadLock {
    _numAvailableMapOutputs
  }

  /**
   * Number of shuffle partitions that have already been merge finalized when push-based
   * shuffle is enabled.
   */
  def numAvailableMergeResults: Int = withReadLock {
    _numAvailableMergeResults
  }

  /**
   * Returns the sequence of partition ids that are missing (i.e. need to be computed).
   */
  def findMissingPartitions(): Seq[Int] = withReadLock {
    val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
    assert(missing.size == numPartitions - _numAvailableMapOutputs,
      s"${missing.size} missing, expected ${numPartitions - _numAvailableMapOutputs}")
    missing
  }

  /**
   * Serializes the mapStatuses array into an efficient compressed format. See the comments on
   * `MapOutputTracker.serializeOutputStatuses()` for more details on the serialization format.
   *
   * This method is designed to be called multiple times and implements caching in order to speed
   * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
   * serialize the map statuses then serialization will only be performed in a single thread and
   * all other threads will block until the cache is populated.
   */
  def serializedMapStatus(
      broadcastManager: BroadcastManager,
      isLocal: Boolean,
      minBroadcastSize: Int,
      conf: SparkConf): Array[Byte] = {
    var result: Array[Byte] = null
    withReadLock {
      if (cachedSerializedMapStatus != null) {
        result = cachedSerializedMapStatus
      }
    }

    if (result == null) withWriteLock {
      if (cachedSerializedMapStatus == null) {
        val serResult = MapOutputTracker.serializeOutputStatuses[MapStatus](
          mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf)
        cachedSerializedMapStatus = serResult._1
        cachedSerializedBroadcast = serResult._2
      }
      // The following line has to be outside if statement since it's possible that another thread
      // initializes cachedSerializedMapStatus in-between `withReadLock` and `withWriteLock`.
      result = cachedSerializedMapStatus
    }
    result
  }
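
  // Descriptive note: the read-lock probe followed by the separate write-lock block above is a
  // double-checked locking pattern. A thread that misses the cache under the read lock re-checks
  // it under the write lock, so at most one thread pays the serialization cost while late
  // arrivals simply return the freshly cached bytes (or the serialized broadcast reference when
  // the payload exceeds minBroadcastSize).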

  /**
   * Serializes the mapStatuses and mergeStatuses array into an efficient compressed format.
   * See the comments on `MapOutputTracker.serializeOutputStatuses()` for more details
   * on the serialization format.
   *
   * This method is designed to be called multiple times and implements caching in order to speed
   * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
   * serialize the statuses array then serialization will only be performed in a single thread and
   * all other threads will block until the cache is populated.
   */
  def serializedMapAndMergeStatus(
      broadcastManager: BroadcastManager,
      isLocal: Boolean,
      minBroadcastSize: Int,
      conf: SparkConf): (Array[Byte], Array[Byte]) = {
    val mapStatusesBytes: Array[Byte] =
      serializedMapStatus(broadcastManager, isLocal, minBroadcastSize, conf)
    var mergeStatusesBytes: Array[Byte] = null

    withReadLock {
      if (cachedSerializedMergeStatus != null) {
        mergeStatusesBytes = cachedSerializedMergeStatus
      }
    }

    if (mergeStatusesBytes == null) withWriteLock {
      if (cachedSerializedMergeStatus == null) {
        val serResult = MapOutputTracker.serializeOutputStatuses[MergeStatus](
          mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf)
        cachedSerializedMergeStatus = serResult._1
        cachedSerializedBroadcastMergeStatus = serResult._2
      }

      // The following line has to be outside if statement since it's possible that another
      // thread initializes cachedSerializedMergeStatus in-between `withReadLock` and
      // `withWriteLock`.
      mergeStatusesBytes = cachedSerializedMergeStatus
    }
    (mapStatusesBytes, mergeStatusesBytes)
  }

  // Used in testing.
  def hasCachedSerializedBroadcast: Boolean = withReadLock {
    cachedSerializedBroadcast != null
  }

  /**
   * Helper function which provides thread-safe access to the mapStatuses array.
   * The function should NOT mutate the array.
   */
  def withMapStatuses[T](f: Array[MapStatus] => T): T = withReadLock {
    f(mapStatuses)
  }

  def withMergeStatuses[T](f: Array[MergeStatus] => T): T = withReadLock {
    f(mergeStatuses)
  }

  def getShufflePushMergerLocations: Seq[BlockManagerId] = withReadLock {
    shufflePushMergerLocations
  }

  /**
   * Clears the cached serialized map output statuses.
   */
  def invalidateSerializedMapOutputStatusCache(): Unit = withWriteLock {
    if (cachedSerializedBroadcast != null) {
      // Prevent errors during broadcast cleanup from crashing the DAGScheduler (see SPARK-21444)
      Utils.tryLogNonFatalError {
        // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup
        // RPCs to dead executors.
        cachedSerializedBroadcast.destroy()
      }
      cachedSerializedBroadcast = null
    }
    cachedSerializedMapStatus = null
  }

  /**
   * Clears the cached serialized merge result statuses.
   */
  def invalidateSerializedMergeOutputStatusCache(): Unit = withWriteLock {
    if (cachedSerializedBroadcastMergeStatus != null) {
      Utils.tryLogNonFatalError {
        // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup
        // RPCs to dead executors.
        cachedSerializedBroadcastMergeStatus.destroy()
      }
      cachedSerializedBroadcastMergeStatus = null
    }
    cachedSerializedMergeStatus = null
  }
}
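
// Typical driver-side lifecycle of a ShuffleStatus (illustrative summary; it is driven by the
// MapOutputTrackerMaster defined further below):
//   registerShuffle            -> new ShuffleStatus(numMaps[, numReduces])
//   registerMapOutput          -> addMapOutput(mapIndex, status)
//   executor / host lost       -> removeOutputsOnExecutor / removeOutputsOnHost (epoch is bumped)
//   reducer asks for metadata  -> serializedMapStatus / serializedMapAndMergeStatus (cached)
//   unregisterShuffle          -> serialized caches and broadcasts are invalidated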

private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int)
  extends MapOutputTrackerMessage
private[spark] case class GetMapAndMergeResultStatuses(shuffleId: Int)
  extends MapOutputTrackerMessage
private[spark] case class GetShufflePushMergerLocations(shuffleId: Int)
  extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage

private[spark] sealed trait MapOutputTrackerMasterMessage
private[spark] case class GetMapOutputMessage(shuffleId: Int,
  context: RpcCallContext) extends MapOutputTrackerMasterMessage
private[spark] case class GetMapAndMergeOutputMessage(shuffleId: Int,
  context: RpcCallContext) extends MapOutputTrackerMasterMessage
private[spark] case class GetShufflePushMergersMessage(shuffleId: Int,
  context: RpcCallContext) extends MapOutputTrackerMasterMessage
private[spark] case class MapSizesByExecutorId(
  iter: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], enableBatchFetch: Boolean)

/** RpcEndpoint class for MapOutputTrackerMaster */
private[spark] class MapOutputTrackerMasterEndpoint(
    override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf)
  extends RpcEndpoint with Logging {

  logDebug("init") // force eager creation of logger

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case GetMapOutputStatuses(shuffleId: Int) =>
      val hostPort = context.senderAddress.hostPort
      logInfo(s"Asked to send map output locations for shuffle $shuffleId to $hostPort")
      tracker.post(GetMapOutputMessage(shuffleId, context))

    case GetMapAndMergeResultStatuses(shuffleId: Int) =>
      val hostPort = context.senderAddress.hostPort
      logInfo(s"Asked to send map/merge result locations for shuffle $shuffleId to $hostPort")
      tracker.post(GetMapAndMergeOutputMessage(shuffleId, context))

    case GetShufflePushMergerLocations(shuffleId: Int) =>
      logInfo(s"Asked to send shuffle push merger locations for shuffle" +
        s" $shuffleId to ${context.senderAddress.hostPort}")
      tracker.post(GetShufflePushMergersMessage(shuffleId, context))

    case StopMapOutputTracker =>
      logInfo("MapOutputTrackerMasterEndpoint stopped!")
      context.reply(true)
      stop()
  }
}

/**
 * Class that keeps track of the location of the map output of a stage. This is abstract because the
 * driver and executor have different versions of the MapOutputTracker. In principle the driver-
 * and executor-side classes don't need to share a common base class; the current shared base class
 * is maintained primarily for backwards-compatibility in order to avoid having to update existing
 * test code.
 */
private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
  /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */
  var trackerEndpoint: RpcEndpointRef = _

  /**
   * The driver-side counter is incremented every time that a map output is lost. This value is sent
   * to executors as part of tasks, where executors compare the new epoch number to the highest
   * epoch number that they received in the past. If the new epoch number is higher then executors
   * will clear their local caches of map output statuses and will re-fetch (possibly updated)
   * statuses from the driver.
   */
  protected var epoch: Long = 0
  protected val epochLock = new AnyRef

  /**
   * Send a message to the trackerEndpoint and get its result within a default timeout, or
   * throw a SparkException if this fails.
   */
  protected def askTracker[T: ClassTag](message: Any): T = {
    try {
      trackerEndpoint.askSync[T](message)
    } catch {
      case e: Exception =>
        logError("Error communicating with MapOutputTracker", e)
        throw new SparkException("Error communicating with MapOutputTracker", e)
    }
  }

  /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */
  protected def sendTracker(message: Any): Unit = {
    val response = askTracker[Boolean](message)
    if (response != true) {
      throw new SparkException(
        "Error reply received from MapOutputTracker. Expecting true, got " + response.toString)
    }
  }

  // For testing
  def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
    getMapSizesByExecutorId(shuffleId, 0, Int.MaxValue, reduceId, reduceId + 1)
  }

  // For testing
  def getPushBasedShuffleMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
      : MapSizesByExecutorId = {
    getPushBasedShuffleMapSizesByExecutorId(shuffleId, 0, Int.MaxValue, reduceId, reduceId + 1)
  }

  /**
   * Called from executors to get the server URIs and output sizes for each shuffle block that
   * needs to be read from a given range of map output partitions (startPartition is included but
   * endPartition is excluded from the range) within a range of mappers (startMapIndex is included
   * but endMapIndex is excluded) when push based shuffle is not enabled for the specific shuffle
   * dependency. If endMapIndex=Int.MaxValue, the actual endMapIndex will be changed to the total
   * number of map outputs.
   *
   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
   *         and the second item is a sequence of (shuffle block id, shuffle block size, map index)
   *         tuples describing the shuffle blocks that are stored at that block manager.
   *         Note that zero-sized blocks are excluded in the result.
   */
  def getMapSizesByExecutorId(
      shuffleId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
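
  // Illustrative shape of one element of the returned iterator (hypothetical values):
  //   (BlockManagerId("exec-1", "host-a", 7337),
  //    Seq((ShuffleBlockId(shuffleId = 0, mapId = 12L, reduceId = 2), 1024L, 5)))
  // i.e. from the block manager on host-a, fetch block shuffle_0_12_2, which is 1024 bytes and
  // was produced by the map task at index 5 (mapId is the task's unique id, not its index).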

  /**
   * Called from executors to get the server URIs and output sizes for each shuffle block that
   * needs to be read from a given range of map output partitions (startPartition is included but
   * endPartition is excluded from the range) within a range of mappers (startMapIndex is included
   * but endMapIndex is excluded) when push based shuffle is enabled for the specific shuffle
   * dependency. If endMapIndex=Int.MaxValue, the actual endMapIndex will be changed to the total
   * number of map outputs.
   *
   * @return A case class object which includes two attributes. The first attribute is a sequence
   *         of 2-item tuples, where the first item in the tuple is a BlockManagerId, and the
   *         second item is a sequence of (shuffle block id, shuffle block size, map index)
   *         tuples describing the shuffle blocks that are stored at that block manager. Note that
   *         zero-sized blocks are excluded in the result. The second attribute is a boolean flag,
   *         indicating whether batch fetch can be enabled.
   */
  def getPushBasedShuffleMapSizesByExecutorId(
      shuffleId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int): MapSizesByExecutorId

  /**
   * Called from executors upon fetch failure on an entire merged shuffle reduce partition.
   * Such failures can happen if the shuffle client fails to fetch the metadata for the given
   * merged shuffle partition. This method gets the server URIs and output sizes for each
   * shuffle block that is merged in the specified merged shuffle block, so a fetch failure on a
   * merged shuffle block can fall back to fetching the unmerged blocks.
   *
   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
   *         and the second item is a sequence of (shuffle block ID, shuffle block size, map index)
   *         tuples describing the shuffle blocks that are stored at that block manager.
   */
  def getMapSizesForMergeResult(
      shuffleId: Int,
      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

  /**
   * Called from executors upon fetch failure on a merged shuffle reduce partition chunk. This is
   * to get the server URIs and output sizes for each shuffle block that is merged in the specified
   * merged shuffle partition chunk so fetch failure on a merged shuffle block chunk can fall back
   * to fetching the unmerged blocks.
   *
   * chunkBitmap tracks the mapIds which are part of the current merged chunk; this way, if there
   * is a fetch failure on the merged chunk, it can fall back to fetching the corresponding
   * original blocks that are part of this merged chunk.
   *
   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
   *         and the second item is a sequence of (shuffle block ID, shuffle block size, map index)
   *         tuples describing the shuffle blocks that are stored at that block manager.
   */
  def getMapSizesForMergeResult(
      shuffleId: Int,
      partitionId: Int,
      chunkBitmap: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

  /**
   * Called from executors whenever a task with push-based shuffle enabled doesn't have shuffle
   * mergers available. This typically happens when the initial stages don't have enough shuffle
   * mergers available, since very few executors have registered. This is done on a best-effort
   * basis; if there are not enough shuffle mergers available for this stage, an empty sequence
   * is returned, indicating that the task should avoid shuffle push.
   * @param shuffleId
   */
  def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId]

  /**
   * Deletes map output status information for the specified shuffle stage.
   */
  def unregisterShuffle(shuffleId: Int): Unit

  def stop(): Unit = {}
}
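
// Descriptive sketch of the epoch protocol used by the two subclasses below: the driver bumps
// `epoch` (incrementEpoch) whenever a map output is lost, the current value travels to executors
// as part of each task, and MapOutputTrackerWorker.updateEpoch clears the executor-local
// mapStatuses / mergeStatuses caches when it sees a newer value, forcing a re-fetch from the
// driver on the next read.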

/**
 * Driver-side class that keeps track of the location of the map output of a stage.
 *
 * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics
 * for performing locality-aware reduce task scheduling.
 *
 * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine
 * which tasks need to be run.
 */
private[spark] class MapOutputTrackerMaster(
    conf: SparkConf,
    private[spark] val broadcastManager: BroadcastManager,
    private[spark] val isLocal: Boolean)
  extends MapOutputTracker(conf) {

  // The size at which we use Broadcast to send the map output statuses to the executors
  private val minSizeForBroadcast = conf.get(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST).toInt

  /** Whether to compute locality preferences for reduce tasks */
  private val shuffleLocalityEnabled = conf.get(SHUFFLE_REDUCE_LOCALITY_ENABLE)

  // Number of map and reduce tasks above which we do not assign preferred locations based on map
  // output sizes. We limit the size of jobs for which we assign preferred locations, as
  // computing the top locations by size becomes expensive.
  private val SHUFFLE_PREF_MAP_THRESHOLD = 1000
  // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that
  private val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000

  // Fraction of total map output that must be at a location for it to be considered a preferred
  // location for a reduce task. Making this larger will focus on fewer locations where most data
  // can be read locally, but may lead to more delay in scheduling if those locations are busy.
  private val REDUCER_PREF_LOCS_FRACTION = 0.2

  // HashMap for storing shuffleStatuses in the driver.
  // Statuses are dropped only by explicit de-registering.
  // Exposed for testing
  val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala

  private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)

  // requests for MapOutputTrackerMasterMessages
  private val mapOutputTrackerMasterMessages =
    new LinkedBlockingQueue[MapOutputTrackerMasterMessage]

  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf, isDriver = true)

  // Thread pool used for handling map output status requests. This is a separate thread pool
  // to ensure we don't block the normal dispatcher threads.
  private val threadpool: ThreadPoolExecutor = {
    val numThreads = conf.get(SHUFFLE_MAPOUTPUT_DISPATCHER_NUM_THREADS)
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher")
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }

  // Make sure that we aren't going to exceed the max RPC message size: large map output statuses
  // are sent via broadcast, so the broadcast threshold must not exceed the RPC limit.
  if (minSizeForBroadcast > maxRpcMessageSize) {
    val msg = s"${SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST.key} ($minSizeForBroadcast bytes) " +
      s"must be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an " +
      "rpc message that is too large."
    logError(msg)
    throw new IllegalArgumentException(msg)
  }

  def post(message: MapOutputTrackerMasterMessage): Unit = {
    mapOutputTrackerMasterMessages.offer(message)
  }

  /** Message loop used for dispatching messages. */
  private class MessageLoop extends Runnable {
    private def handleStatusMessage(
        shuffleId: Int,
        context: RpcCallContext,
        needMergeOutput: Boolean): Unit = {
      val hostPort = context.senderAddress.hostPort
      val shuffleStatus = shuffleStatuses.get(shuffleId).head
      logDebug(s"Handling request to send ${if (needMergeOutput) "map/merge" else "map"}" +
        s" output locations for shuffle $shuffleId to $hostPort")
      if (needMergeOutput) {
        context.reply(
          shuffleStatus.
            serializedMapAndMergeStatus(broadcastManager, isLocal, minSizeForBroadcast, conf))
      } else {
        context.reply(
          shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, conf))
      }
    }

    override def run(): Unit = {
      try {
        while (true) {
          try {
            val data = mapOutputTrackerMasterMessages.take()
            if (data == PoisonPill) {
              // Put PoisonPill back so that other MessageLoops can see it.
              mapOutputTrackerMasterMessages.offer(PoisonPill)
              return
            }

            data match {
              case GetMapOutputMessage(shuffleId, context) =>
                handleStatusMessage(shuffleId, context, false)
              case GetMapAndMergeOutputMessage(shuffleId, context) =>
                handleStatusMessage(shuffleId, context, true)
              case GetShufflePushMergersMessage(shuffleId, context) =>
                logDebug(s"Handling request to send shuffle push merger locations for shuffle" +
                  s" $shuffleId to ${context.senderAddress.hostPort}")
                context.reply(shuffleStatuses.get(shuffleId).map(_.getShufflePushMergerLocations)
                  .getOrElse(Seq.empty[BlockManagerId]))
            }
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case ie: InterruptedException => // exit
      }
    }
  }

  /** A poison message that indicates MessageLoop should exit its message loop. */
  private val PoisonPill = GetMapOutputMessage(-99, null)
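
  // Dispatch flow (descriptive): MapOutputTrackerMasterEndpoint.receiveAndReply only enqueues a
  // message via post(); one of the "map-output-dispatcher" MessageLoop threads takes it from
  // mapOutputTrackerMasterMessages, serializes the requested statuses (possibly behind a
  // broadcast) and replies on the original RpcCallContext. stop() enqueues PoisonPill, and each
  // loop re-offers it before returning so every dispatcher thread eventually exits.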

  // Used only in unit tests.
  private[spark] def getNumCachedSerializedBroadcast: Int = {
    shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
  }

  def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int): Unit = {
    if (pushBasedShuffleEnabled) {
      if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps, numReduces)).isDefined) {
        throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
      }
    } else {
      if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
        throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
      }
    }
  }

  def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = {
    shuffleStatuses.get(shuffleId) match {
      case Some(shuffleStatus) =>
        shuffleStatus.updateMapOutput(mapId, bmAddress)
      case None =>
        logError(s"Asked to update map output for unknown shuffle ${shuffleId}")
    }
  }

  def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
    shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
  }

  /** Unregister map output information of the given shuffle, mapper and block manager */
  def unregisterMapOutput(shuffleId: Int, mapIndex: Int, bmAddress: BlockManagerId): Unit = {
    shuffleStatuses.get(shuffleId) match {
      case Some(shuffleStatus) =>
        shuffleStatus.removeMapOutput(mapIndex, bmAddress)
        incrementEpoch()
      case None =>
        throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
    }
  }

  /** Unregister all map and merge output information of the given shuffle. */
  def unregisterAllMapAndMergeOutput(shuffleId: Int): Unit = {
    shuffleStatuses.get(shuffleId) match {
      case Some(shuffleStatus) =>
        shuffleStatus.removeOutputsByFilter(x => true)
        shuffleStatus.removeMergeResultsByFilter(x => true)
        shuffleStatus.removeShuffleMergerLocations()
        incrementEpoch()
      case None =>
        throw new SparkException(
          s"unregisterAllMapAndMergeOutput called for nonexistent shuffle ID $shuffleId.")
    }
  }

  def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus): Unit = {
    shuffleStatuses(shuffleId).addMergeResult(reduceId, status)
  }

  def registerMergeResults(shuffleId: Int, statuses: Seq[(Int, MergeStatus)]): Unit = {
    statuses.foreach {
      case (reduceId, status) => registerMergeResult(shuffleId, reduceId, status)
    }
  }

  def registerShufflePushMergerLocations(
      shuffleId: Int,
      shuffleMergers: Seq[BlockManagerId]): Unit = {
    shuffleStatuses(shuffleId).registerShuffleMergerLocations(shuffleMergers)
  }

  /**
   * Unregisters a merge result corresponding to the reduceId if present. If the optional mapIndex
   * is specified, it will only unregister the merge result if the mapIndex is part of that merge
   * result.
   *
   * @param shuffleId the shuffleId.
   * @param reduceId  the reduceId.
   * @param bmAddress block manager address.
   * @param mapIndex  the optional mapIndex which should be checked to see it was part of the
   *                  merge result.
   */
  def unregisterMergeResult(
      shuffleId: Int,
      reduceId: Int,
      bmAddress: BlockManagerId,
      mapIndex: Option[Int] = None): Unit = {
    shuffleStatuses.get(shuffleId) match {
      case Some(shuffleStatus) =>
        val mergeStatus = shuffleStatus.mergeStatuses(reduceId)
        if (mergeStatus != null &&
          (mapIndex.isEmpty || mergeStatus.tracker.contains(mapIndex.get))) {
          shuffleStatus.removeMergeResult(reduceId, bmAddress)
          incrementEpoch()
        }
      case None =>
        throw new SparkException("unregisterMergeResult called for nonexistent shuffle ID")
    }
  }

  def unregisterAllMergeResult(shuffleId: Int): Unit = {
    shuffleStatuses.get(shuffleId) match {
      case Some(shuffleStatus) =>
        shuffleStatus.removeMergeResultsByFilter(x => true)
        incrementEpoch()
      case None =>
        throw new SparkException(
          s"unregisterAllMergeResult called for nonexistent shuffle ID $shuffleId.")
    }
  }

  /** Unregister shuffle data */
  def unregisterShuffle(shuffleId: Int): Unit = {
    shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
      // SPARK-39553: Add protection for Scala 2.13 due to https://github.com/scala/bug/issues/12613
      // We should revert this if Scala 2.13 solves this issue.
      if (shuffleStatus != null) {
        shuffleStatus.invalidateSerializedMapOutputStatusCache()
        shuffleStatus.invalidateSerializedMergeOutputStatusCache()
      }
    }
  }

  /**
   * Removes all shuffle outputs associated with this host. Note that this will also remove
   * outputs which are served by an external shuffle server (if one exists).
   */
  def removeOutputsOnHost(host: String): Unit = {
    shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) }
    incrementEpoch()
  }

  /**
   * Removes all shuffle outputs associated with this executor. Note that this will also remove
   * outputs which are served by an external shuffle server (if one exists), as they are still
   * registered with this execId.
   */
  def removeOutputsOnExecutor(execId: String): Unit = {
    shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) }
    incrementEpoch()
  }

  /** Check if the given shuffle is being tracked */
  def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId)

  def getNumAvailableOutputs(shuffleId: Int): Int = {
    shuffleStatuses.get(shuffleId).map(_.numAvailableMapOutputs).getOrElse(0)
  }

  /** VisibleForTest. Invoked in test only. */
  private[spark] def getNumAvailableMergeResults(shuffleId: Int): Int = {
    shuffleStatuses.get(shuffleId).map(_.numAvailableMergeResults).getOrElse(0)
  }

  /**
   * Returns the sequence of partition ids that are missing (i.e. need to be computed), or None
   * if the MapOutputTrackerMaster doesn't know about this shuffle.
   */
  def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = {
    shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
  }

  /**
   * Grouped function for a Range; this avoids traversing all elements of the Range, which
   * IterableLike's grouped function would do.
   */
  def rangeGrouped(range: Range, size: Int): Seq[Range] = {
    val start = range.start
    val step = range.step
    val end = range.end
    for (i <- start.until(end, size * step)) yield {
      i.until(i + size * step, step)
    }
  }

  /**
   * To divide n elements equally into m buckets: each bucket gets n/m elements, and each of the
   * first n%m buckets gets one extra element to absorb the remainder.
   */
  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
    val elementsPerBucket = numElements / numBuckets
    val remaining = numElements % numBuckets
    val splitPoint = (elementsPerBucket + 1) * remaining
    if (elementsPerBucket == 0) {
      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
    } else {
      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
        rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
    }
  }
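
  // Worked example (illustrative): equallyDivide(10, 3) splits ids 0..9 into
  //   Seq(0 to 3, 4 to 6, 7 to 9)
  // i.e. the first 10 % 3 = 1 bucket receives one extra element.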

  /**
   * Return statistics about all of the outputs for a given shuffle.
   */
  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
    shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
      val totalSizes = new Array[Long](dep.partitioner.numPartitions)
      val parallelAggThreshold = conf.get(
        SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD)
      val parallelism = math.min(
        Runtime.getRuntime.availableProcessors(),
        statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt
      if (parallelism <= 1) {
        statuses.filter(_ != null).foreach { s =>
          for (i <- totalSizes.indices) {
            totalSizes(i) += s.getSizeForBlock(i)
          }
        }
      } else {
        val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate")
        try {
          implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
          val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map {
            reduceIds => Future {
              statuses.filter(_ != null).foreach { s =>
                reduceIds.foreach(i => totalSizes(i) += s.getSizeForBlock(i))
              }
            }
          }
          ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf)
        } finally {
          threadPool.shutdown()
        }
      }
      new MapOutputStatistics(dep.shuffleId, totalSizes)
    }
  }
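
  // Illustrative sizing of the aggregation above: parallelism = min(#cores,
  // numMapStatuses * numReducePartitions / parallelAggThreshold + 1). With 10,000 map statuses,
  // 5,000 reduce partitions and a threshold of 10^7 this is min(#cores, 6), so the per-reducer
  // size sums are computed by up to 6 "map-output-aggregate" threads; smaller shuffles take the
  // single-threaded loop.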

  /**
   * Return the preferred hosts on which to run the given map output partition in a given shuffle,
   * i.e. the nodes that have the most outputs for that partition. If the map output is
   * pre-merged, then return the node where the merged block is located if the merge ratio is
   * above the threshold.
   *
   * @param dep shuffle dependency object
   * @param partitionId map output partition that we want to read
   * @return a sequence of host names
   */
  def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int)
      : Seq[String] = {
    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
    if (shuffleStatus != null) {
      // Check if the map output is pre-merged and if the merge ratio is above the threshold.
      // If so, the location of the merged block is the preferred location.
      val preferredLoc = if (pushBasedShuffleEnabled) {
        shuffleStatus.withMergeStatuses { statuses =>
          val status = statuses(partitionId)
          val numMaps = dep.rdd.partitions.length
          if (status != null && status.getNumMissingMapOutputs(numMaps).toDouble / numMaps
            <= (1 - REDUCER_PREF_LOCS_FRACTION)) {
            Seq(status.location.host)
          } else {
            Nil
          }
        }
      } else {
        Nil
      }
      if (preferredLoc.nonEmpty) {
        preferredLoc
      } else {
        if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
          dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
          val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
            dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
          if (blockManagerIds.nonEmpty) {
            blockManagerIds.get.map(_.host)
          } else {
            Nil
          }
        } else {
          Nil
        }
      }
    } else {
      Nil
    }
  }

  /**
   * Return a list of locations that each have a fraction of the map output greater than the
   * specified threshold.
   *
   * @param shuffleId id of the shuffle
   * @param reducerId id of the reduce task
   * @param numReducers total number of reducers in the shuffle
   * @param fractionThreshold fraction of total map output size that a location must have
   *                          for it to be considered large.
   */
  def getLocationsWithLargestOutputs(
      shuffleId: Int,
      reducerId: Int,
      numReducers: Int,
      fractionThreshold: Double)
    : Option[Array[BlockManagerId]] = {

    val shuffleStatus = shuffleStatuses.get(shuffleId).orNull
    if (shuffleStatus != null) {
      shuffleStatus.withMapStatuses { statuses =>
        if (statuses.nonEmpty) {
          // HashMap to add up sizes of all blocks at the same location
          val locs = new HashMap[BlockManagerId, Long]
          var totalOutputSize = 0L
          var mapIdx = 0
          while (mapIdx < statuses.length) {
            val status = statuses(mapIdx)
            // status may be null here if we are called between registerShuffle, which creates an
            // array with null entries for each output, and registerMapOutputs, which populates it
            // with valid status entries. This is possible if one thread schedules a job which
            // depends on an RDD which is currently being computed by another thread.
            if (status != null) {
              val blockSize = status.getSizeForBlock(reducerId)
              if (blockSize > 0) {
                locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize
                totalOutputSize += blockSize
              }
            }
            mapIdx = mapIdx + 1
          }
          val topLocs = locs.filter { case (loc, size) =>
            size.toDouble / totalOutputSize >= fractionThreshold
          }
          // Return if we have any locations which satisfy the required threshold
          if (topLocs.nonEmpty) {
            return Some(topLocs.keys.toArray)
          }
        }
      }
    }
    None
  }
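
  // Illustrative: with fractionThreshold = 0.2 (REDUCER_PREF_LOCS_FRACTION) and a reducer whose
  // map-side input totals 100 MB, only block managers holding at least 20 MB of that reducer's
  // blocks are returned as preferred locations; if no location crosses the threshold the result
  // is None and the reduce task gets no locality preference.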

  /**
   * Return the locations where the mappers ran. Each location includes both a host and an
   * executor id on that host.
   *
   * @param dep shuffle dependency object
   * @param startMapIndex the start map index
   * @param endMapIndex the end map index (exclusive)
   * @return a sequence of hosts where the map tasks ran.
   */
  def getMapLocation(
      dep: ShuffleDependency[_, _, _],
      startMapIndex: Int,
      endMapIndex: Int): Seq[String] =
  {
    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
    if (shuffleStatus != null) {
      shuffleStatus.withMapStatuses { statuses =>
        if (startMapIndex < endMapIndex &&
          (startMapIndex >= 0 && endMapIndex <= statuses.length)) {
          val statusesPicked = statuses.slice(startMapIndex, endMapIndex).filter(_ != null)
          statusesPicked.map(_.location.host).toSeq
        } else {
          Nil
        }
      }
    } else {
      Nil
    }
  }

  def incrementEpoch(): Unit = {
    epochLock.synchronized {
      epoch += 1
      logDebug("Increasing epoch to " + epoch)
    }
  }

  /** Called to get current epoch number. */
  def getEpoch: Long = {
    epochLock.synchronized {
      return epoch
    }
  }

  // This method is only called in local-mode.
  override def getMapSizesByExecutorId(
      shuffleId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
    val mapSizesByExecutorId = getPushBasedShuffleMapSizesByExecutorId(
      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
    assert(mapSizesByExecutorId.enableBatchFetch == true)
    mapSizesByExecutorId.iter
  }

  // This method is only called in local-mode.
  override def getPushBasedShuffleMapSizesByExecutorId(
      shuffleId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int): MapSizesByExecutorId = {
    logDebug(s"Fetching outputs for shuffle $shuffleId")
    shuffleStatuses.get(shuffleId) match {
      case Some(shuffleStatus) =>
        shuffleStatus.withMapStatuses { statuses =>
          val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
          logDebug(s"Convert map statuses for shuffle $shuffleId, " +
            s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
          MapOutputTracker.convertMapStatuses(
            shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
        }
      case None =>
        MapSizesByExecutorId(Iterator.empty, true)
    }
  }

  // This method is only called in local-mode. Since push-based shuffle won't be
  // enabled in local-mode, this method returns an empty iterator.
  override def getMapSizesForMergeResult(
      shuffleId: Int,
      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
    Seq.empty.iterator
  }

  // This method is only called in local-mode. Since push-based shuffle won't be
  // enabled in local-mode, this method returns an empty iterator.
  override def getMapSizesForMergeResult(
      shuffleId: Int,
      partitionId: Int,
      chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
    Seq.empty.iterator
  }

  // This method is only called in local-mode.
  override def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] = {
    shuffleStatuses(shuffleId).getShufflePushMergerLocations
  }

  override def stop(): Unit = {
    mapOutputTrackerMasterMessages.offer(PoisonPill)
    threadpool.shutdown()
    try {
      sendTracker(StopMapOutputTracker)
    } catch {
      case e: SparkException =>
        logError("Could not tell tracker we are stopping.", e)
    }
    trackerEndpoint = null
    shuffleStatuses.clear()
  }
}

/**
 * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster.
 * Note that this is not used in local-mode; instead, local-mode Executors access the
 * MapOutputTrackerMaster directly (which is possible because the master and worker share a common
 * superclass).
 */
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {

  val mapStatuses: Map[Int, Array[MapStatus]] =
    new ConcurrentHashMap[Int, Array[MapStatus]]().asScala

  val mergeStatuses: Map[Int, Array[MergeStatus]] =
    new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala

  // This must be lazy to ensure that it is initialized when the first task is run and not at
  // executor startup time. At startup time, user-added libraries may not have been
  // downloaded to the executor, causing `isPushBasedShuffleEnabled` to fail when it tries to
  // instantiate a serializer. See the followup to SPARK-36705 for more details.
  private lazy val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf, isDriver = false)

  /**
   * [[shufflePushMergerLocations]] tracks shuffle push merger locations for the latest
   * shuffle execution
   *
   * Exposed for testing
   */
  val shufflePushMergerLocations = new ConcurrentHashMap[Int, Seq[BlockManagerId]]().asScala

  /**
   * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching
   * the same shuffle block.
   */
  private val fetchingLock = new KeyLock[Int]

  override def getMapSizesByExecutorId(
      shuffleId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
    val mapSizesByExecutorId = getMapSizesByExecutorIdImpl(
      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition, useMergeResult = false)
    assert(mapSizesByExecutorId.enableBatchFetch == true)
    mapSizesByExecutorId.iter
  }

  override def getPushBasedShuffleMapSizesByExecutorId(
      shuffleId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int): MapSizesByExecutorId = {
    getMapSizesByExecutorIdImpl(
      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition, useMergeResult = true)
  }

  private def getMapSizesByExecutorIdImpl(
      shuffleId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int,
      useMergeResult: Boolean): MapSizesByExecutorId = {
    logDebug(s"Fetching outputs for shuffle $shuffleId")
    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf,
      // enableBatchFetch can be set to false during stage retry when the
      // shuffleDependency.isShuffleMergeFinalizedMarked is set to false, and Driver
      // has already collected the mergedStatus for its shuffle dependency.
      // In this case, boolean check helps to ensure that the unnecessary
      // mergeStatus won't be fetched, thus mergedOutputStatuses won't be
      // passed to convertMapStatuses. See details in [SPARK-37023].
      if (useMergeResult) fetchMergeResult else false)
    try {
      val actualEndMapIndex =
        if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
      logDebug(s"Convert map statuses for shuffle $shuffleId, " +
        s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
      MapOutputTracker.convertMapStatuses(
        shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex,
          actualEndMapIndex, Option(mergedOutputStatuses))
    } catch {
      case e: MetadataFetchFailedException =>
        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
        mapStatuses.clear()
        mergeStatuses.clear()
        throw e
    }
  }

  override def getMapSizesForMergeResult(
      shuffleId: Int,
      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
    // Fetch the map statuses and merge statuses again since they might have already been
    // cleared by another task running in the same executor.
    val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf, fetchMergeResult)
    try {
      val mergeStatus = mergeResultStatuses(partitionId)
      // If the original MergeStatus is no longer available, we cannot identify the list of
      // unmerged blocks to fetch, so throw a MetadataFetchFailedException.
      MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId)
      // Use the MergeStatus's partition level bitmap since we are doing partition level fallback
      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId,
        mapOutputStatuses, mergeStatus.tracker)
    } catch {
      // We experienced a fetch failure so our mapStatuses cache is outdated; clear it
      case e: MetadataFetchFailedException =>
        mapStatuses.clear()
        mergeStatuses.clear()
        throw e
    }
  }

  override def getMapSizesForMergeResult(
      shuffleId: Int,
      partitionId: Int,
      chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
    // Fetch the map statuses and merge statuses again since they might have already been
    // cleared by another task running in the same executor.
    val (mapOutputStatuses, _) = getStatuses(shuffleId, conf, fetchMergeResult)
    try {
      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses,
        chunkTracker)
    } catch {
      // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
      case e: MetadataFetchFailedException =>
        mapStatuses.clear()
        mergeStatuses.clear()
        throw e
    }
  }

  override def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] = {
    shufflePushMergerLocations.getOrElse(shuffleId, getMergerLocations(shuffleId))
  }

  private def getMergerLocations(shuffleId: Int): Seq[BlockManagerId] = {
    fetchingLock.withLock(shuffleId) {
      var fetchedMergers = shufflePushMergerLocations.get(shuffleId).orNull
      if (null == fetchedMergers) {
        fetchedMergers =
          askTracker[Seq[BlockManagerId]](GetShufflePushMergerLocations(shuffleId))
        if (fetchedMergers.nonEmpty) {
          shufflePushMergerLocations(shuffleId) = fetchedMergers
        } else {
          fetchedMergers = Seq.empty[BlockManagerId]
        }
      }
      fetchedMergers
    }
  }

  /**
   * Get or fetch the array of MapStatuses (and MergeStatuses, if push-based shuffle is enabled)
   * for a given shuffle ID. NOTE: clients MUST synchronize on these arrays when reading them,
   * because on the driver, we may be changing them in place.
   *
   * (It would be nice to remove this restriction in the future.)
   */
  private def getStatuses(
      shuffleId: Int,
      conf: SparkConf,
      canFetchMergeResult: Boolean): (Array[MapStatus], Array[MergeStatus]) = {
    if (canFetchMergeResult) {
      val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
      val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull

      if (mapOutputStatuses == null || mergeOutputStatuses == null) {
        logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them")
        val startTimeNs = System.nanoTime()
        fetchingLock.withLock(shuffleId) {
          var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull
          var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull
          if (fetchedMapStatuses == null || fetchedMergeStatuses == null) {
            logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
            val fetchedBytes =
              askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId))
            try {
              fetchedMapStatuses =
                MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf)
              fetchedMergeStatuses =
                MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, conf)
            } catch {
              case e: SparkException =>
                throw new MetadataFetchFailedException(shuffleId, -1,
                  s"Unable to deserialize broadcasted map/merge statuses" +
                    s" for shuffle $shuffleId: " + e.getCause)
            }
            logInfo("Got the map/merge output locations")
            mapStatuses.put(shuffleId, fetchedMapStatuses)
            mergeStatuses.put(shuffleId, fetchedMergeStatuses)
          }
          logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId took " +
            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
          (fetchedMapStatuses, fetchedMergeStatuses)
        }
      } else {
        (mapOutputStatuses, mergeOutputStatuses)
      }
    } else {
      val statuses = mapStatuses.get(shuffleId).orNull
      if (statuses == null) {
        logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
        val startTimeNs = System.nanoTime()
        fetchingLock.withLock(shuffleId) {
          var fetchedStatuses = mapStatuses.get(shuffleId).orNull
          if (fetchedStatuses == null) {
            logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
            val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
            try {
              fetchedStatuses =
                MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf)
            } catch {
              case e: SparkException =>
                throw new MetadataFetchFailedException(shuffleId, -1,
                  s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " +
                    e.getCause)
            }
            logInfo("Got the map output locations")
            mapStatuses.put(shuffleId, fetchedStatuses)
          }
          logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
          (fetchedStatuses, null)
        }
      } else {
        (statuses, null)
      }
    }
  }
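
  // -----------------------------------------------------------------------------------------
  // A minimal sketch (not part of the Spark source) of the caching discipline getStatuses()
  // above relies on: check the local cache, take a per-shuffle-ID lock, re-check under the
  // lock, and only then pay for the remote fetch, so concurrent tasks for the same shuffle
  // trigger a single RPC. `StatusCacheSketch` and `fetchRemote` are hypothetical names, and
  // Array[String] stands in for Array[MapStatus].
  // -----------------------------------------------------------------------------------------
  private object StatusCacheSketch {
    import java.util.concurrent.ConcurrentHashMap

    private val cache = new ConcurrentHashMap[Int, Array[String]]()
    private val locks = new ConcurrentHashMap[Int, Object]()

    def getOrFetch(shuffleId: Int)(fetchRemote: Int => Array[String]): Array[String] = {
      val cached = cache.get(shuffleId)
      if (cached != null) return cached
      // One lock object per shuffle ID, so unrelated shuffles never wait on each other.
      val lock = locks.computeIfAbsent(shuffleId, (_: Int) => new Object)
      lock.synchronized {
        val recheck = cache.get(shuffleId) // another thread may have completed the fetch already
        if (recheck != null) {
          recheck
        } else {
          val fetched = fetchRemote(shuffleId) // stands in for askTracker(...)
          cache.put(shuffleId, fetched)
          fetched
        }
      }
    }
  }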

  /** Unregister shuffle data. */
  def unregisterShuffle(shuffleId: Int): Unit = {
    mapStatuses.remove(shuffleId)
    mergeStatuses.remove(shuffleId)
    shufflePushMergerLocations.remove(shuffleId)
  }

  /**
   * Called from executors to update the epoch number, potentially clearing old outputs
   * because of a fetch failure. Each executor task calls this with the latest epoch
   * number on the driver at the time it was created.
   */
  def updateEpoch(newEpoch: Long): Unit = {
    epochLock.synchronized {
      if (newEpoch > epoch) {
        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
        epoch = newEpoch
        mapStatuses.clear()
        mergeStatuses.clear()
        shufflePushMergerLocations.clear()
      }
    }
  }
}
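
// ---------------------------------------------------------------------------------------------
// A minimal sketch (not part of the Spark source) of the epoch protocol behind updateEpoch()
// above: the driver bumps its epoch whenever map outputs are invalidated (for example after a
// fetch failure), every task carries the driver epoch it was launched with, and the
// executor-side cache is cleared the first time a newer epoch is observed. All names below are
// hypothetical.
// ---------------------------------------------------------------------------------------------
private object EpochCacheSketch {
  import scala.collection.concurrent.TrieMap

  private val epochLock = new Object
  private var epoch = 0L
  private val cachedStatuses = TrieMap.empty[Int, Array[Long]]

  def updateEpoch(newEpoch: Long): Unit = epochLock.synchronized {
    if (newEpoch > epoch) { // older or equal epochs are ignored, so reordered tasks are harmless
      epoch = newEpoch
      cachedStatuses.clear() // forces a re-fetch from the driver on the next lookup
    }
  }

  def currentEpoch: Long = epochLock.synchronized(epoch)
}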

private[spark] object MapOutputTracker extends Logging {

  val ENDPOINT_NAME = "MapOutputTracker"
  private val DIRECT = 0
  private val BROADCAST = 1

  val SHUFFLE_PUSH_MAP_ID = -1

  // Serialize an array of map/merge output locations into an efficient byte format so that we can
  // send it to reduce tasks. We do this by compressing the serialized bytes with the configured
  // map-status codec (zstd by default). They will generally be pretty compressible because many
  // outputs will be on the same hostname.
  def serializeOutputStatuses[T <: ShuffleOutputStatus](
      statuses: Array[T],
      broadcastManager: BroadcastManager,
      isLocal: Boolean,
      minBroadcastSize: Int,
      conf: SparkConf): (Array[Byte], Broadcast[Array[Array[Byte]]]) = {
    // ByteArrayOutputStream has the 2GB limit so use ChunkedByteBufferOutputStream instead
    val out = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
    out.write(DIRECT)
    val codec = CompressionCodec.createCodec(conf, conf.get(MAP_STATUS_COMPRESSION_CODEC))
    val objOut = new ObjectOutputStream(codec.compressedOutputStream(out))
    Utils.tryWithSafeFinally {
      // Since statuses can be modified in parallel, sync on it
      statuses.synchronized {
        objOut.writeObject(statuses)
      }
    } {
      objOut.close()
    }
    val chunkedByteBuf = out.toChunkedByteBuffer
    val arrSize = out.size
    if (arrSize >= minBroadcastSize) {
      // Use broadcast instead.
      // Important: arr(0) is the tag (== DIRECT); ignore it while deserializing!
      // arr is a nested Array so that it can handle over 2GB serialized data
      val arr = chunkedByteBuf.getChunks().map(_.array())
      val bcast = broadcastManager.newBroadcast(arr, isLocal)
      // Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard one
      // This implementation doesn't reallocate the whole memory block but allocates
      // additional buffers. This way no buffers need to be garbage collected and
      // the contents don't have to be copied to the new buffer.
      val out = new ApacheByteArrayOutputStream()
      out.write(BROADCAST)
      val oos = new ObjectOutputStream(codec.compressedOutputStream(out))
      Utils.tryWithSafeFinally {
        oos.writeObject(bcast)
      } {
        oos.close()
      }
      val outArr = out.toByteArray
      logInfo("Broadcast outputstatuses size = " + outArr.length + ", actual size = " + arrSize)
      (outArr, bcast)
    } else {
      (chunkedByteBuf.toArray, null)
    }
  }
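
  // -----------------------------------------------------------------------------------------
  // A minimal sketch (not part of the Spark source) of the size-threshold decision made at the
  // end of serializeOutputStatuses() above: small results ride directly in the RPC reply, while
  // large ones are published as a broadcast variable and the reply only carries a small
  // BROADCAST-tagged envelope. The Reply ADT and `chooseReply` are hypothetical, not Spark API.
  // -----------------------------------------------------------------------------------------
  private object BroadcastThresholdSketch {
    sealed trait Reply
    final case class DirectBytes(payload: Array[Byte]) extends Reply
    final case class BroadcastEnvelope(broadcastId: Long, actualSize: Long) extends Reply

    def chooseReply(
        serialized: Array[Byte], minBroadcastSize: Int, newBroadcastId: => Long): Reply =
      if (serialized.length >= minBroadcastSize) {
        BroadcastEnvelope(newBroadcastId, serialized.length.toLong) // ship a handle, not the bytes
      } else {
        DirectBytes(serialized) // small enough to inline in the reply itself
      }
  }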

  // Opposite of serializeOutputStatuses.
  def deserializeOutputStatuses[T <: ShuffleOutputStatus](
      bytes: Array[Byte], conf: SparkConf): Array[T] = {
    assert (bytes.length > 0)

    def deserializeObject(in: InputStream): AnyRef = {
      val codec = CompressionCodec.createCodec(conf, conf.get(MAP_STATUS_COMPRESSION_CODEC))
      // The ZStd codec is wrapped in a `BufferedInputStream`, which avoids the excessive JNI
      // call overhead of decompressing a small amount of data for each element
      // of `MapStatuses`.
      val objIn = new ObjectInputStream(codec.compressedInputStream(in))
      Utils.tryWithSafeFinally {
        objIn.readObject()
      } {
        objIn.close()
      }
    }

    val in = new ByteArrayInputStream(bytes, 1, bytes.length - 1)
    bytes(0) match {
      case DIRECT =>
        deserializeObject(in).asInstanceOf[Array[T]]
      case BROADCAST =>
        try {
          // deserialize the Broadcast, pull .value array out of it, and then deserialize that
          val bcast = deserializeObject(in).asInstanceOf[Broadcast[Array[Array[Byte]]]]
          logInfo("Broadcast outputstatuses size = " + bytes.length +
            ", actual size = " + bcast.value.foldLeft(0L)(_ + _.length))
          val bcastIn = new ChunkedByteBuffer(bcast.value.map(ByteBuffer.wrap)).toInputStream()
          // Important: ignore the DIRECT tag! Start from offset 1.
          bcastIn.skip(1)
          deserializeObject(bcastIn).asInstanceOf[Array[T]]
        } catch {
          case e: IOException =>
            logWarning("Exception encountered during deserializing broadcasted" +
              " output statuses: ", e)
            throw new SparkException("Unable to deserialize broadcasted" +
              " output statuses", e)
        }
      case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0))
    }
  }
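
  // -----------------------------------------------------------------------------------------
  // A minimal sketch (not part of the Spark source) of the byte layout handled by the two
  // methods above, reduced to a DIRECT-tagged frame: one tag byte followed by a compressed,
  // Java-serialized payload. GZIP stands in for the configurable map-status codec and
  // Array[String] stands in for Array[MapStatus]; both substitutions are illustrative only.
  // -----------------------------------------------------------------------------------------
  private object TagFramingSketch {
    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import java.io.{ObjectInputStream, ObjectOutputStream}
    import java.util.zip.{GZIPInputStream, GZIPOutputStream}

    private val DIRECT_TAG = 0

    def serialize(statuses: Array[String]): Array[Byte] = {
      val out = new ByteArrayOutputStream()
      out.write(DIRECT_TAG) // the tag byte is written uncompressed, before the codec stream
      val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
      try objOut.writeObject(statuses) finally objOut.close()
      out.toByteArray
    }

    def deserialize(bytes: Array[Byte]): Array[String] = {
      require(bytes.nonEmpty && bytes(0) == DIRECT_TAG, s"unexpected tag ${bytes(0)}")
      // Skip the tag byte, then decompress and deserialize the remainder.
      val in = new ByteArrayInputStream(bytes, 1, bytes.length - 1)
      val objIn = new ObjectInputStream(new GZIPInputStream(in))
      try objIn.readObject().asInstanceOf[Array[String]] finally objIn.close()
    }
  }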

  /**
   * Given an array of map statuses and a range of map output partitions, returns a sequence that,
   * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes
   * stored at that block manager.
   * Note that empty blocks are filtered out of the result.
   *
   * If push-based shuffle is enabled and an array of merge statuses is available, prioritize
   * the locations of the merged shuffle partitions over unmerged shuffle blocks.
   *
   * If any of the statuses is null (indicating a missing location due to a failed mapper),
   * throws a FetchFailedException.
   *
   * @param shuffleId Identifier for the shuffle
   * @param startPartition Start of map output partition ID range (included in range)
   * @param endPartition End of map output partition ID range (excluded from range)
   * @param mapStatuses List of map statuses, indexed by map partition index.
   * @param startMapIndex Start Map index.
   * @param endMapIndex End Map index.
   * @param mergeStatusesOpt Optional array of merge statuses, indexed by reduce ID.
   * @return A MapSizesByExecutorId whose iterator yields 2-item tuples, where the first item in
   *         the tuple is a BlockManagerId, and the second item is a sequence of (shuffle block
   *         id, shuffle block size, map index) tuples describing the shuffle blocks that are
   *         stored at that block manager.
   */
  def convertMapStatuses(
      shuffleId: Int,
      startPartition: Int,
      endPartition: Int,
      mapStatuses: Array[MapStatus],
      startMapIndex: Int,
      endMapIndex: Int,
      mergeStatusesOpt: Option[Array[MergeStatus]] = None): MapSizesByExecutorId = {
    assert (mapStatuses != null)
    val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
    var enableBatchFetch = true
    // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle
    // partition consists of blocks merged in random order, we are unable to serve map index
    // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of
    // map outputs, it usually indicates skewed partitions which push-based shuffle delegates
    // to AQE to handle.
    // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
    // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
    // TODO: map indexes
    if (mergeStatusesOpt.exists(_.exists(_ != null)) && startMapIndex == 0
      && endMapIndex == mapStatuses.length) {
      enableBatchFetch = false
      logDebug(s"Disable shuffle batch fetch as Push based shuffle is enabled for $shuffleId.")
      val mergeStatuses = mergeStatusesOpt.get
      for (partId <- startPartition until endPartition) {
        val mergeStatus = mergeStatuses(partId)
        if (mergeStatus != null && mergeStatus.totalSize > 0) {
          // If MergeStatus is available for the given partition, add location of the
          // pre-merged shuffle partition for this partition ID. Here we create a
          // ShuffleMergedBlockId to indicate this is a merged shuffle block.
          splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) +=
            ((ShuffleMergedBlockId(shuffleId, mergeStatus.shuffleMergeId, partId),
              mergeStatus.totalSize, SHUFFLE_PUSH_MAP_ID))
        }
      }

      // Add location for the mapper shuffle partition blocks
      for ((mapStatus, mapIndex) <- mapStatuses.iterator.zipWithIndex) {
        validateStatus(mapStatus, shuffleId, startPartition)
        for (partId <- startPartition until endPartition) {
          // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper
          // shuffle partition blocks, fetch the original map produced shuffle partition blocks
          val mergeStatus = mergeStatuses(partId)
          if (mergeStatus == null || mergeStatus.totalSize == 0 ||
            !mergeStatus.tracker.contains(mapIndex)) {
            val size = mapStatus.getSizeForBlock(partId)
            if (size != 0) {
              splitsByAddress.getOrElseUpdate(mapStatus.location, ListBuffer()) +=
                ((ShuffleBlockId(shuffleId, mapStatus.mapId, partId), size, mapIndex))
            }
          }
        }
      }
    } else {
      val iter = mapStatuses.iterator.zipWithIndex
      for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
        validateStatus(status, shuffleId, startPartition)
        for (part <- startPartition until endPartition) {
          val size = status.getSizeForBlock(part)
          if (size != 0) {
            splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
              ((ShuffleBlockId(shuffleId, status.mapId, part), size, mapIndex))
          }
        }
      }
    }

    MapSizesByExecutorId(splitsByAddress.mapValues(_.toSeq).iterator, enableBatchFetch)
  }
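
  // -----------------------------------------------------------------------------------------
  // A minimal sketch (not part of the Spark source) of the block-selection rule
  // convertMapStatuses() applies above when merge statuses are present: for one reduce
  // partition, the MergeStatus bitmap records which map outputs were folded into the merged
  // block, and every map index missing from the bitmap is a "hole" whose original block must
  // still be fetched from its mapper. `Pick`, the string descriptions and the plain size
  // arrays are placeholders rather than Spark block IDs and MapStatus objects.
  // -----------------------------------------------------------------------------------------
  private object MergedPartitionSketch {
    final case class Pick(description: String, sizeBytes: Long)

    def blocksToFetch(
        partId: Int,
        mergedBitmap: RoaringBitmap, // map indexes covered by the merged block
        mergedSize: Long,            // total size of the merged block for this partition
        mapSizes: Array[Long]        // per-mapper output size for this reduce partition
      ): Seq[Pick] = {
      val merged =
        if (mergedSize > 0) Seq(Pick(s"merged block for partition $partId", mergedSize))
        else Seq.empty[Pick]
      val holes = mapSizes.zipWithIndex.collect {
        case (size, mapIndex) if size > 0 && !mergedBitmap.contains(mapIndex) =>
          Pick(s"original block from map $mapIndex", size)
      }
      merged ++ holes.toSeq
    }

    // Maps 0 and 2 were merged into one 2048-byte block; map 1 is a hole and is fetched directly.
    def example(): Seq[Pick] =
      blocksToFetch(7, RoaringBitmap.bitmapOf(0, 2), 2048L, Array(1024L, 512L, 1024L))
  }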

  /**
   * Given a shuffle ID, a partition ID, an array of map statuses, and bitmap corresponding
   * to either a merged shuffle partition or a merged shuffle partition chunk, identify
   * the metadata about the shuffle partition blocks that are merged into the merged shuffle
   * partition or partition chunk represented by the bitmap.
   *
   * @param shuffleId Identifier for the shuffle
   * @param partitionId The partition ID of the MergeStatus for which we look for the metadata
   *                    of the merged shuffle partition blocks
   * @param mapStatuses List of map statuses, indexed by map ID
   * @param tracker     bitmap containing mapIndexes that belong to the merged block or merged
   *                    block chunk.
   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
   *         and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
   *         describing the shuffle blocks that are stored at that block manager.
   */
  def getMapStatusesForMergeStatus(
      shuffleId: Int,
      partitionId: Int,
      mapStatuses: Array[MapStatus],
      tracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
    assert (mapStatuses != null && tracker != null)
    val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
    for ((status, mapIndex) <- mapStatuses.zipWithIndex) {
      // Only add blocks that are merged
      if (tracker.contains(mapIndex)) {
        MapOutputTracker.validateStatus(status, shuffleId, partitionId)
        splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
          ((ShuffleBlockId(shuffleId, status.mapId, partitionId),
            status.getSizeForBlock(partitionId), mapIndex))
      }
    }
    splitsByAddress.mapValues(_.toSeq).iterator
  }
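
  // -----------------------------------------------------------------------------------------
  // A minimal sketch (not part of the Spark source) of the grouping idiom shared by the
  // methods above: blocks are accumulated per fetch location with getOrElseUpdate, so each
  // location ends up with a single entry listing all of its (blockId, size, mapIndex) tuples.
  // Plain strings stand in for BlockManagerId and BlockId here.
  // -----------------------------------------------------------------------------------------
  private object GroupByLocationSketch {
    import scala.collection.mutable.{HashMap, ListBuffer}

    def group(
        blocks: Seq[(String, String, Long, Int)]): Iterator[(String, Seq[(String, Long, Int)])] = {
      val byLocation = new HashMap[String, ListBuffer[(String, Long, Int)]]
      for ((location, blockId, size, mapIndex) <- blocks) {
        // One buffer per location, created lazily on first use.
        byLocation.getOrElseUpdate(location, ListBuffer()) += ((blockId, size, mapIndex))
      }
      byLocation.mapValues(_.toSeq).iterator
    }
  }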

  def validateStatus(status: ShuffleOutputStatus, shuffleId: Int, partition: Int) : Unit = {
    if (status == null) {
      val errorMessage = s"Missing an output location for shuffle $shuffleId partition $partition"
      logError(errorMessage)
      throw new MetadataFetchFailedException(shuffleId, partition, errorMessage)
    }
  }
}
