Spark ShuffleBlockPusher Source Code
File path: /core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle
import java.io.{File, FileNotFoundException}
import java.net.ConnectException
import java.nio.ByteBuffer
import java.util.concurrent.ExecutorService
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
import scala.util.control.NonFatal
import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.annotation.Since
import org.apache.spark.executor.{CoarseGrainedExecutorBackend, ExecutorBackend}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.server.BlockPushNonFatalFailure
import org.apache.spark.network.shuffle.BlockPushingListener
import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
import org.apache.spark.network.util.TransportConf
import org.apache.spark.shuffle.ShuffleBlockPusher._
import org.apache.spark.storage.{BlockId, BlockManagerId, ShufflePushBlockId}
import org.apache.spark.util.{ThreadUtils, Utils}
/**
* Used for pushing shuffle blocks to remote shuffle services when push-based shuffle is enabled.
* When push-based shuffle is enabled, it is created after the shuffle writer finishes writing the
* shuffle file and initiates the block push process.
*
* @param conf spark configuration
*/
@Since("3.2.0")
private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
private[this] val maxBlockSizeToPush = conf.get(SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH)
private[this] val maxBlockBatchSize = conf.get(SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH)
private[this] val maxBytesInFlight = conf.get(REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024
private[this] val maxReqsInFlight = conf.get(REDUCER_MAX_REQS_IN_FLIGHT)
private[this] val maxBlocksInFlightPerAddress = conf.get(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
private[shuffle] var bytesInFlight = 0L
private[this] var reqsInFlight = 0
private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()
private[this] val deferredPushRequests = new HashMap[BlockManagerId, Queue[PushRequest]]()
private[this] val pushRequests = new Queue[PushRequest]
private[this] val errorHandler = createErrorHandler()
// VisibleForTesting
private[shuffle] val unreachableBlockMgrs = new HashSet[BlockManagerId]()
private[this] var shuffleId = -1
private[this] var mapIndex = -1
private[this] var shuffleMergeId = -1
private[this] var pushCompletionNotified = false
// VisibleForTesting
private[shuffle] def createErrorHandler(): BlockPushErrorHandler = {
new BlockPushErrorHandler() {
// For a connection exception against a particular host, we will stop pushing any
// blocks to just that host and continue pushing blocks to other hosts. So, here the push of
// all blocks will only stop when it is "Too Late" or an "Invalid Block Push".
// Also see updateStateAndCheckIfPushMore.
override def shouldRetryError(t: Throwable): Boolean = {
// If it is a FileNotFoundException originating from the client while pushing the shuffle
// blocks to the server, then we stop pushing all the blocks because this indicates the
// shuffle files are deleted and subsequent block push will also fail.
if (t.getCause != null && t.getCause.isInstanceOf[FileNotFoundException]) {
return false
}
// If the block push is too late or invalid, or the attempt is not the latest one, there is
// no need to retry it
!(t.isInstanceOf[BlockPushNonFatalFailure] &&
BlockPushNonFatalFailure.
shouldNotRetryErrorCode(t.asInstanceOf[BlockPushNonFatalFailure].getReturnCode));
}
}
}
// VisibleForTesting
private[shuffle] def isPushCompletionNotified = pushCompletionNotified
/**
* Initiates the block push.
*
* @param dataFile mapper generated shuffle data file
* @param partitionLengths array of shuffle block sizes so we can tell shuffle block boundaries
* @param dep shuffle dependency to get shuffle ID and the location of remote shuffle
* services to push local shuffle blocks
* @param mapIndex map index of the shuffle map task
*/
private[shuffle] def initiateBlockPush(
dataFile: File,
partitionLengths: Array[Long],
dep: ShuffleDependency[_, _, _],
mapIndex: Int): Unit = {
val numPartitions = dep.partitioner.numPartitions
val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
this.shuffleId = dep.shuffleId
this.shuffleMergeId = dep.shuffleMergeId
this.mapIndex = mapIndex
val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId,
dep.shuffleMergeId, dataFile, partitionLengths, dep.getMergerLocs, transportConf)
// Randomize the orders of the PushRequest, so different mappers pushing blocks at the same
// time won't be pushing the same ranges of shuffle partitions.
pushRequests ++= Utils.randomize(requests)
if (pushRequests.isEmpty) {
notifyDriverAboutPushCompletion()
} else {
submitTask(() => {
tryPushUpToMax()
})
}
}
private[shuffle] def tryPushUpToMax(): Unit = {
try {
pushUpToMax()
} catch {
case NonFatal(e) =>
logWarning("Failure during push so stopping the block push", e)
}
}
/**
* Triggers the push. It's a separate method for testing.
* VisibleForTesting
*/
protected def submitTask(task: Runnable): Unit = {
if (BLOCK_PUSHER_POOL != null && !BLOCK_PUSHER_POOL.isShutdown) {
BLOCK_PUSHER_POOL.execute(task)
}
}
/**
* Since multiple block push threads could potentially be calling pushUpToMax for the same
* mapper, we synchronize access to this method so that only one thread can push blocks for
* a given mapper. This helps to simplify access to the shared states. The down side of this
* is that we could unnecessarily block other mappers' block pushes if all the threads
* are occupied by block pushes from the same mapper.
*
* This code is similar to ShuffleBlockFetcherIterator#fetchUpToMaxBytes in how it throttles
* the data transfer between shuffle client/server.
*/
private def pushUpToMax(): Unit = synchronized {
// Process any outstanding deferred push requests if possible.
if (deferredPushRequests.nonEmpty) {
for ((remoteAddress, defReqQueue) <- deferredPushRequests) {
while (isRemoteBlockPushable(defReqQueue) &&
!isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
val request = defReqQueue.dequeue()
logDebug(s"Processing deferred push request for $remoteAddress with "
+ s"${request.blocks.length} blocks")
sendRequest(request)
if (defReqQueue.isEmpty) {
deferredPushRequests -= remoteAddress
}
}
}
}
// Process any regular push requests if possible.
while (isRemoteBlockPushable(pushRequests)) {
val request = pushRequests.dequeue()
val remoteAddress = request.address
if (isRemoteAddressMaxedOut(remoteAddress, request)) {
logDebug(s"Deferring push request for $remoteAddress with ${request.blocks.size} blocks")
deferredPushRequests.getOrElseUpdate(remoteAddress, new Queue[PushRequest]())
.enqueue(request)
} else {
sendRequest(request)
}
}
def isRemoteBlockPushable(pushReqQueue: Queue[PushRequest]): Boolean = {
pushReqQueue.nonEmpty &&
(bytesInFlight == 0 ||
(reqsInFlight + 1 <= maxReqsInFlight &&
bytesInFlight + pushReqQueue.front.size <= maxBytesInFlight))
}
// Checks if sending a new push request will exceed the max no. of blocks being pushed to a
// given remote address.
def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: PushRequest): Boolean = {
(numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0)
+ request.blocks.size) > maxBlocksInFlightPerAddress
}
}
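// A worked example of the throttling above (illustrative values, not necessarily defaults):
// with maxReqsInFlight = 4, maxBytesInFlight = 48MB and a 2MB request at the queue front,
// isRemoteBlockPushable returns true only while reqsInFlight <= 3 and
// bytesInFlight + 2MB <= 48MB. The bytesInFlight == 0 branch lets a single request larger
// than maxBytesInFlight go out when nothing else is in flight, so pushing never stalls.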
/**
* Push blocks to the remote shuffle server. The callback listener will invoke #pushUpToMax
* again to trigger pushing the next batch of blocks once some block transfer is done in the
* current batch. This way, we decouple the map task from the block push process, since it is
* the netty client thread instead of the task execution thread that takes care of the
* majority of the block pushes.
*/
private def sendRequest(request: PushRequest): Unit = {
bytesInFlight += request.size
reqsInFlight += 1
numBlocksInFlightPerAddress(request.address) = numBlocksInFlightPerAddress.getOrElseUpdate(
request.address, 0) + request.blocks.length
val sizeMap = request.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val address = request.address
val blockIds = request.blocks.map(_._1.toString)
val remainingBlocks = new HashSet[String]() ++= blockIds
val blockPushListener = new BlockPushingListener {
// Initiating a connection and pushing blocks to a remote shuffle service is always handled by
// the block-push-threads. We should not initiate the connection creation in the
// blockPushListener callbacks which are invoked by the netty eventloop because:
// 1. TransportClient.createConnection(...) blocks for connection to be established and it's
// recommended to avoid any blocking operations in the eventloop;
// 2. The actual connection creation is a task that gets added to the task queue of another
// eventloop, which could end up with eventloops blocking each other.
// Once the blockPushListener is notified of the block push success or failure, we
// just delegate it to block-push-threads.
def handleResult(result: PushResult): Unit = {
submitTask(() => {
if (updateStateAndCheckIfPushMore(
sizeMap(result.blockId), address, remainingBlocks, result)) {
tryPushUpToMax()
}
})
}
override def onBlockPushSuccess(blockId: String, data: ManagedBuffer): Unit = {
logTrace(s"Push for block $blockId to $address successful.")
handleResult(PushResult(blockId, null))
}
override def onBlockPushFailure(blockId: String, exception: Throwable): Unit = {
// Check the message or its cause to see if it needs to be logged.
if (!errorHandler.shouldLogError(exception)) {
logTrace(s"Pushing block $blockId to $address failed.", exception)
} else {
logWarning(s"Pushing block $blockId to $address failed.", exception)
}
handleResult(PushResult(blockId, exception))
}
}
// In addition to randomizing the order of the push requests, further randomize the order
// of blocks within the push request to further reduce the likelihood of shuffle server side
// collision of pushed blocks. This does not increase the cost of reading unmerged shuffle
// files on the executor side, because we are still reading MB-size chunks and only randomize
// the in-memory sliced buffers post reading.
val (blockPushIds, blockPushBuffers) = Utils.randomize(blockIds.zip(
sliceReqBufferIntoBlockBuffers(request.reqBuffer, request.blocks.map(_._2)))).unzip
SparkEnv.get.blockManager.blockStoreClient.pushBlocks(
address.host, address.port, blockPushIds.toArray,
blockPushBuffers.toArray, blockPushListener)
}
/**
* Given the ManagedBuffer representing all the continuous blocks inside the shuffle data file
* for a PushRequest and an array of individual block sizes, load the buffer from disk into
* memory and slice it into multiple smaller buffers representing each block.
*
* With nio ByteBuffer, the individual block buffers share data with the initial in memory
* buffer loaded from disk. Thus only one copy of the block data is kept in memory.
* @param reqBuffer A {{FileSegmentManagedBuffer}} representing all the continuous blocks in
* the shuffle data file for a PushRequest
* @param blockSizes Array of block sizes
* @return Array of in memory buffer for each individual block
*/
private def sliceReqBufferIntoBlockBuffers(
reqBuffer: ManagedBuffer,
blockSizes: Seq[Int]): Array[ManagedBuffer] = {
if (blockSizes.size == 1) {
Array(reqBuffer)
} else {
val inMemoryBuffer = reqBuffer.nioByteBuffer()
val blockOffsets = new Array[Int](blockSizes.size)
var offset = 0
for (index <- blockSizes.indices) {
blockOffsets(index) = offset
offset += blockSizes(index)
}
blockOffsets.zip(blockSizes).map {
case (offset, size) =>
new NioManagedBuffer(inMemoryBuffer.duplicate()
.position(offset)
.limit(offset + size).asInstanceOf[ByteBuffer].slice())
}.toArray
}
}
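// A worked example of the slicing above: for blockSizes = Seq(10, 20, 30), the computed
// blockOffsets are [0, 10, 30], producing three NioManagedBuffers over byte ranges
// [0, 10), [10, 30) and [30, 60) of the request buffer. duplicate() gives each slice its
// own position/limit while all three still share the single backing ByteBuffer.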
/**
* Updates the stats and based on the previous push result decides whether to push more blocks
* or stop.
*
* @param bytesPushed number of bytes pushed.
* @param address address of the remote service
* @param remainingBlocks remaining blocks
* @param pushResult result of the last push
* @return true if more blocks should be pushed; false otherwise.
*/
private def updateStateAndCheckIfPushMore(
bytesPushed: Long,
address: BlockManagerId,
remainingBlocks: HashSet[String],
pushResult: PushResult): Boolean = synchronized {
remainingBlocks -= pushResult.blockId
bytesInFlight -= bytesPushed
numBlocksInFlightPerAddress(address) -= 1
if (remainingBlocks.isEmpty) {
reqsInFlight -= 1
}
if (pushResult.failure != null && pushResult.failure.getCause.isInstanceOf[ConnectException]) {
// Remove all the blocks for this address just once because removing from pushRequests
// is expensive. If there is a ConnectException for the first block, all the subsequent
// blocks to that address will fail as well, so we should avoid removing them multiple times.
if (!unreachableBlockMgrs.contains(address)) {
var removed = 0
unreachableBlockMgrs.add(address)
removed += pushRequests.dequeueAll(req => req.address == address).length
removed += deferredPushRequests.remove(address).map(_.length).getOrElse(0)
logWarning(s"Received a ConnectException from $address. " +
s"Dropping $removed push-requests and " +
s"not pushing any more blocks to this address.")
}
}
if (pushResult.failure != null && !errorHandler.shouldRetryError(pushResult.failure)) {
logDebug(s"Encountered an exception from $address which indicates that push needs to " +
s"stop.")
return false
} else {
if (reqsInFlight <= 0 && pushRequests.isEmpty && deferredPushRequests.isEmpty) {
notifyDriverAboutPushCompletion()
}
remainingBlocks.isEmpty && (pushRequests.nonEmpty || deferredPushRequests.nonEmpty)
}
}
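// Note on the return value above: barring a non-retryable failure, a completion callback
// only returns true once its whole request has finished (remainingBlocks.isEmpty) and more
// work is queued. For a request of 3 blocks, the first two success callbacks return false;
// only the third one can trigger the next round of pushUpToMax.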
/**
* Notifies the driver once all the blocks generated by the current map task have been pushed.
* This enables the DAGScheduler to finalize shuffle merge as soon as sufficient map tasks have
* completed push instead of always waiting for a fixed amount of time.
*
* VisibleForTesting
*/
protected def notifyDriverAboutPushCompletion(): Unit = {
assert(shuffleId >= 0 && mapIndex >= 0)
if (!pushCompletionNotified) {
SparkEnv.get.executorBackend match {
case Some(cb: CoarseGrainedExecutorBackend) =>
cb.notifyDriverAboutPushCompletion(shuffleId, shuffleMergeId, mapIndex)
case Some(eb: ExecutorBackend) =>
logWarning(s"Currently $eb doesn't support push-based shuffle")
case None =>
}
pushCompletionNotified = true
}
}
/**
* Convert the shuffle data file of the current mapper into a list of PushRequest. Basically,
* continuous blocks in the shuffle file are grouped into a single request to allow more
* efficient read of the block data. Each mapper for a given shuffle will receive the same
* list of BlockManagerIds as the target location to push the blocks to. All mappers in the
* same shuffle will map shuffle partition ranges to individual target locations in a consistent
* manner to make sure each target location receives shuffle blocks belonging to the same set
* of partition ranges. Zero-length blocks and blocks larger than maxBlockSizeToPush are skipped.
*
* @param numPartitions number of shuffle partitions in the shuffle file
* @param partitionId map index of the current mapper
* @param shuffleId shuffleId of current shuffle
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param dataFile shuffle data file
* @param partitionLengths array of sizes of blocks in the shuffle data file
* @param mergerLocs target locations to push blocks to
* @param transportConf transportConf used to create FileSegmentManagedBuffer
* @return List of the PushRequest, randomly shuffled.
*
* VisibleForTesting
*/
private[shuffle] def prepareBlockPushRequests(
numPartitions: Int,
partitionId: Int,
shuffleId: Int,
shuffleMergeId: Int,
dataFile: File,
partitionLengths: Array[Long],
mergerLocs: Seq[BlockManagerId],
transportConf: TransportConf): Seq[PushRequest] = {
var offset = 0L
var currentReqSize = 0
var currentReqOffset = 0L
var currentMergerId = 0
val numMergers = mergerLocs.length
val requests = new ArrayBuffer[PushRequest]
var blocks = new ArrayBuffer[(BlockId, Int)]
for (reduceId <- 0 until numPartitions) {
val blockSize = partitionLengths(reduceId)
logDebug(
s"Block ${ShufflePushBlockId(shuffleId, shuffleMergeId, partitionId,
reduceId)} is of size $blockSize")
// Skip 0-length blocks; blocks larger than maxBlockSizeToPush are filtered out below
if (blockSize > 0) {
val mergerId = math.min(math.floor(reduceId * 1.0 / numPartitions * numMergers),
numMergers - 1).asInstanceOf[Int]
// Start a new PushRequest if the current request goes beyond the max batch size,
// or the number of blocks in the current request goes beyond the limit per destination,
// or the next block push location is for a different shuffle service, or the next block
// exceeds the max block size to push limit. This guarantees that each PushRequest
// represents continuous blocks in the shuffle file to be pushed to the same shuffle
// service, and does not go beyond existing limitations.
if (currentReqSize + blockSize <= maxBlockBatchSize
&& blocks.size < maxBlocksInFlightPerAddress
&& mergerId == currentMergerId && blockSize <= maxBlockSizeToPush) {
// Add current block to current batch
currentReqSize += blockSize.toInt
} else {
if (blocks.nonEmpty) {
// Convert the previous batch into a PushRequest
requests += PushRequest(mergerLocs(currentMergerId), blocks.toSeq,
createRequestBuffer(transportConf, dataFile, currentReqOffset, currentReqSize))
blocks = new ArrayBuffer[(BlockId, Int)]
}
// Start a new batch
currentReqSize = 0
// Set currentReqOffset to -1 so we are able to distinguish between the initial value
// of currentReqOffset and when we are about to start a new batch
currentReqOffset = -1
currentMergerId = mergerId
}
// Only push blocks under the size limit
if (blockSize <= maxBlockSizeToPush) {
val blockSizeInt = blockSize.toInt
blocks += ((ShufflePushBlockId(shuffleId, shuffleMergeId, partitionId,
reduceId), blockSizeInt))
// Only update currentReqOffset if the current block is the first in the request
if (currentReqOffset == -1) {
currentReqOffset = offset
}
if (currentReqSize == 0) {
currentReqSize += blockSizeInt
}
}
}
offset += blockSize
}
// Add in the final request
if (blocks.nonEmpty) {
requests += PushRequest(mergerLocs(currentMergerId), blocks.toSeq,
createRequestBuffer(transportConf, dataFile, currentReqOffset, currentReqSize))
}
requests.toSeq
}
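// A worked example of the merger mapping above: with numPartitions = 8 and numMergers = 3,
// mergerId = min(floor(reduceId / 8.0 * 3), 2), so reduce partitions 0-2 map to merger 0,
// 3-5 to merger 1 and 6-7 to merger 2. Every mapper computes the same mapping, so each
// merger location receives blocks for a consistent range of reduce partitions.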
// Visible for testing
protected def createRequestBuffer(
conf: TransportConf,
dataFile: File,
offset: Long,
length: Long): ManagedBuffer = {
new FileSegmentManagedBuffer(conf, dataFile, offset, length)
}
}
private[spark] object ShuffleBlockPusher {
/**
* A request to push blocks to a remote shuffle service
* @param address remote shuffle service location to push blocks to
* @param blocks list of block IDs and their sizes
* @param reqBuffer a chunk of data in the shuffle data file corresponding to the continuous
* blocks represented in this request
*/
private[spark] case class PushRequest(
address: BlockManagerId,
blocks: Seq[(BlockId, Int)],
reqBuffer: ManagedBuffer) {
val size = blocks.map(_._2).sum
}
/**
* Result of the block push.
* @param blockId blockId
* @param failure exception if the push was unsuccessful; null otherwise
*/
private case class PushResult(blockId: String, failure: Throwable)
private val BLOCK_PUSHER_POOL: ExecutorService = {
val conf = SparkEnv.get.conf
if (Utils.isPushBasedShuffleEnabled(conf,
isDriver = SparkContext.DRIVER_IDENTIFIER == SparkEnv.get.executorId)) {
val numThreads = conf.get(SHUFFLE_NUM_PUSH_THREADS)
.getOrElse(conf.getInt(SparkLauncher.EXECUTOR_CORES, 1))
ThreadUtils.newDaemonFixedThreadPool(numThreads, "shuffle-block-push-thread")
} else {
null
}
}
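// For example, with spark.shuffle.push.numPushThreads unset and spark.executor.cores = 4,
// the pool above is created with 4 daemon threads named shuffle-block-push-thread.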
/**
* Stop the shuffle pusher pool if it isn't null.
*/
private[spark] def stop(): Unit = {
if (BLOCK_PUSHER_POOL != null) {
BLOCK_PUSHER_POOL.shutdown()
}
}
}
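For context, here is a minimal sketch of how this class gets wired in after a map task finishes writing its shuffle file. It mirrors the hand-off that Spark's ShuffleWriteProcessor performs when push-based shuffle is enabled; the object PushHandOffSketch and the method maybePushBlocks are illustrative assumptions, not an excerpt from Spark.
package org.apache.spark.shuffle
import org.apache.spark.{ShuffleDependency, SparkEnv}
// Sketch only: initiateBlockPush is private[shuffle], so a caller like this has to live in
// the org.apache.spark.shuffle package. Assumes the driver has already populated
// dep.getMergerLocs with merger locations.
object PushHandOffSketch {
  def maybePushBlocks(
      dep: ShuffleDependency[_, _, _],
      mapId: Long,
      mapIndex: Int,
      partitionLengths: Array[Long]): Unit = {
    SparkEnv.get.shuffleManager.shuffleBlockResolver match {
      case resolver: IndexShuffleBlockResolver if dep.getMergerLocs.nonEmpty =>
        // Push from the data file the map task just wrote; blocks are read from it in place.
        val dataFile = resolver.getDataFile(dep.shuffleId, mapId)
        new ShuffleBlockPusher(SparkEnv.get.conf)
          .initiateBlockPush(dataFile, partitionLengths, dep, mapIndex)
      case _ => // No merger locations or a different resolver: nothing to push.
    }
  }
}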