Spark PythonRunner source code
File path: /core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.python
import java.io._
import java.net._
import java.nio.charset.StandardCharsets
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.{Files => JavaFiles, Path}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
import org.apache.spark.internal.config.Python._
import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY}
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._
/**
* Enumerate the type of command that will be sent to the Python worker
*/
private[spark] object PythonEvalType {
val NON_UDF = 0
val SQL_BATCHED_UDF = 100
val SQL_SCALAR_PANDAS_UDF = 200
val SQL_GROUPED_MAP_PANDAS_UDF = 201
val SQL_GROUPED_AGG_PANDAS_UDF = 202
val SQL_WINDOW_AGG_PANDAS_UDF = 203
val SQL_SCALAR_PANDAS_ITER_UDF = 204
val SQL_MAP_PANDAS_ITER_UDF = 205
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
val SQL_MAP_ARROW_ITER_UDF = 207
val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF"
case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
}
}
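// Note: these eval-type constants are mirrored on the Python side (the
// PythonEvalType class in pyspark); the two lists must be kept in sync whenever
// a new eval type is added.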
private object BasePythonRunner {
private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
private def faultHandlerLogPath(pid: Int): Path = {
new File(faultHandlerLogDir, pid.toString).toPath
}
}
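// When faulthandler support is enabled, the worker is pointed at this directory via
// the PYTHON_FAULTHANDLER_DIR environment variable (see compute() below), and a
// crashing worker leaves a dump file named after its PID. ReaderIterator's
// handleException reads and deletes that file to surface the crash details.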
/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
* funcs is a list of independent Python functions, each one of them is a list of chained Python
* functions (from bottom to top).
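*
* For example, `rdd.map(f).map(g)` may be pipelined on the Python side into a
* single chain in which `f` is applied first (bottom) and `g` last (top).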
*/
private[spark] abstract class BasePythonRunner[IN, OUT](
protected val funcs: Seq[ChainedPythonFunctions],
protected val evalType: Int,
protected val argOffsets: Array[Array[Int]])
extends Logging {
require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
private val conf = SparkEnv.get.conf
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
protected val simplifiedTraceback: Boolean = false
// All the Python functions should have the same exec, version and envvars.
protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
protected val pythonExec: String = funcs.head.funcs.head.pythonExec
protected val pythonVer: String = funcs.head.funcs.head.pythonVer
// TODO: support accumulators in multiple UDFs
protected val accumulator: PythonAccumulatorV2 = funcs.head.funcs.head.accumulator
// Python accumulator is always set in production except in tests. See SPARK-27893
private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator)
// Expose a ServerSocket to support method calls via socket from Python side.
private[spark] var serverSocket: Option[ServerSocket] = None
// Authentication helper used when serving method calls via socket from Python side.
private lazy val authHelper = new SocketAuthHelper(conf)
// Each Python worker gets an equal part of the allocation. The worker pool will grow to the
// number of concurrent tasks, which is determined by the number of cores in this executor.
private def getWorkerMemoryMb(mem: Option[Long], cores: Int): Option[Long] = {
mem.map(_ / cores)
}
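// Illustrative arithmetic: with spark.executor.pyspark.memory=4g (4096 MB) and 4
// executor cores, each Python worker gets a budget of 4096 / 4 = 1024 MB, passed to
// the worker below via the PYSPARK_EXECUTOR_MEMORY_MB environment variable.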
def compute(
inputIterator: Iterator[IN],
partitionIndex: Int,
context: TaskContext): Iterator[OUT] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
// Get the executor cores and pyspark memory, they are passed via the local properties when
// the user specified them in a ResourceProfile.
val execCoresProp = Option(context.getLocalProperty(EXECUTOR_CORES_LOCAL_PROPERTY))
val memoryMb = Option(context.getLocalProperty(PYSPARK_MEMORY_LOCAL_PROPERTY)).map(_.toLong)
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
// if OMP_NUM_THREADS is not explicitly set, set it to the number of executor cores
if (conf.getOption("spark.executorEnv.OMP_NUM_THREADS").isEmpty) {
// SPARK-28843: limit the OpenMP thread pool to the number of cores assigned to this executor
// this avoids high memory consumption with pandas/numpy because of a large OpenMP thread pool
// see https://github.com/numpy/numpy/issues/10455
execCoresProp.foreach(envVars.put("OMP_NUM_THREADS", _))
}
envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
if (reuseWorker) {
envVars.put("SPARK_REUSE_WORKER", "1")
}
if (simplifiedTraceback) {
envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
}
// SPARK-30299 this could be wrong with standalone mode when executor
// cores might not be correct because it defaults to all cores on the box.
val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES))
val workerMemoryMb = getWorkerMemoryMb(memoryMb, execCores)
if (workerMemoryMb.isDefined) {
envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString)
}
envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
if (faultHandlerEnabled) {
envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString)
}
val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
pythonExec, envVars.asScala.toMap)
// Whether the worker has been released into the idle pool or closed. Any code that tries to
// release or close a worker should use `releasedOrClosed.compareAndSet` to flip the state,
// making sure there is only one winner that gets to release or close the worker.
val releasedOrClosed = new AtomicBoolean(false)
// Start a thread to feed the process input from our parent's iterator
val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context)
context.addTaskCompletionListener[Unit] { _ =>
writerThread.shutdownOnTaskCompletion()
if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) {
try {
worker.close()
} catch {
case e: Exception =>
logWarning("Failed to close worker socket", e)
}
}
}
writerThread.start()
new WriterMonitorThread(SparkEnv.get, worker, writerThread, context).start()
if (reuseWorker) {
val key = (worker, context.taskAttemptId)
// SPARK-35009: avoid creating multiple monitor threads for the same python worker
// and task context
if (PythonRunner.runningMonitorThreads.add(key)) {
new MonitorThread(SparkEnv.get, worker, context).start()
}
} else {
new MonitorThread(SparkEnv.get, worker, context).start()
}
// Return an iterator that reads the results from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
val stdoutIterator = newReaderIterator(
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context)
new InterruptibleIterator(context, stdoutIterator)
}
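// In short, compute() wires up four cooperating pieces: a WriterThread feeding
// serialized input into the worker's socket, a ReaderIterator consuming the worker's
// output, a MonitorThread that destroys the worker if the task is interrupted, and a
// WriterMonitorThread that breaks writer/worker deadlocks (SPARK-38677).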
protected def newWriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[IN],
partitionIndex: Int,
context: TaskContext): WriterThread
protected def newReaderIterator(
stream: DataInputStream,
writerThread: WriterThread,
startTime: Long,
env: SparkEnv,
worker: Socket,
pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[OUT]
/**
* The thread responsible for writing the data from the PythonRDD's parent iterator to the
* Python process.
*/
abstract class WriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[IN],
partitionIndex: Int,
context: TaskContext)
extends Thread(s"stdout writer for $pythonExec") {
@volatile private var _exception: Throwable = null
private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
setDaemon(true)
/** Contains the throwable thrown while writing the parent iterator to the Python process. */
def exception: Option[Throwable] = Option(_exception)
/**
* Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur
* due to cleanup.
*/
def shutdownOnTaskCompletion(): Unit = {
assert(context.isCompleted)
this.interrupt()
// Task completion listeners that run after this method returns may invalidate
// `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized
// reader, a task completion listener will free the underlying off-heap buffers. If the writer
// thread is still running when `inputIterator` is invalidated, it can cause a use-after-free
// bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer
// thread to exit before returning.
this.join()
}
/**
* Writes a command section to the stream connected to the Python worker.
*/
protected def writeCommand(dataOut: DataOutputStream): Unit
/**
* Writes input data to the stream connected to the Python worker.
*/
protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
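// The run() method below writes one handshake-plus-payload sequence to the worker:
// partition index, Python version, barrier flag + bound port + auth secret, task
// context info (stage/partition/attempt IDs, cpus, resources, local properties),
// sparkFilesDir, Python includes, broadcast add/remove entries, the eval type, the
// serialized command, the input data, and finally END_OF_STREAM.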
override def run(): Unit = Utils.logUncaughtExceptions {
try {
TaskContext.setTaskContext(context)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
dataOut.writeInt(partitionIndex)
// Python version of driver
PythonRDD.writeUTF(pythonVer, dataOut)
// Init a ServerSocket to accept method calls from Python side.
val isBarrier = context.isInstanceOf[BarrierTaskContext]
if (isBarrier) {
serverSocket = Some(new ServerSocket(/* port */ 0,
/* backlog */ 1,
InetAddress.getByName("localhost")))
// A call to accept() for ServerSocket shall block infinitely.
serverSocket.foreach(_.setSoTimeout(0))
new Thread("accept-connections") {
setDaemon(true)
override def run(): Unit = {
while (!serverSocket.get.isClosed()) {
var sock: Socket = null
try {
sock = serverSocket.get.accept()
// Wait for function call from python side.
sock.setSoTimeout(10000)
authHelper.authClient(sock)
val input = new DataInputStream(sock.getInputStream())
val requestMethod = input.readInt()
// The BarrierTaskContext function may wait infinitely, socket shall not timeout
// before the function finishes.
sock.setSoTimeout(0)
requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
barrierAndServe(requestMethod, sock)
case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
val length = input.readInt()
val message = new Array[Byte](length)
input.readFully(message)
barrierAndServe(requestMethod, sock, new String(message, UTF_8))
case _ =>
val out = new DataOutputStream(new BufferedOutputStream(
sock.getOutputStream))
writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
}
} catch {
case e: SocketException if e.getMessage.contains("Socket closed") =>
// It is possible that the ServerSocket is not closed, but the native socket
// has already been closed; in that case we catch and silently ignore the error.
} finally {
if (sock != null) {
sock.close()
}
}
}
}
}.start()
}
val secret = if (isBarrier) {
authHelper.secret
} else {
""
}
// Close ServerSocket on task completion.
serverSocket.foreach { server =>
context.addTaskCompletionListener[Unit](_ => server.close())
}
val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
if (boundPort == -1) {
val message = "ServerSocket failed to bind to Java side."
logError(message)
throw new SparkException(message)
} else if (isBarrier) {
logDebug(s"Started ServerSocket on port $boundPort.")
}
// Write out the TaskContextInfo
dataOut.writeBoolean(isBarrier)
dataOut.writeInt(boundPort)
val secretBytes = secret.getBytes(UTF_8)
dataOut.writeInt(secretBytes.length)
dataOut.write(secretBytes, 0, secretBytes.length)
dataOut.writeInt(context.stageId())
dataOut.writeInt(context.partitionId())
dataOut.writeInt(context.attemptNumber())
dataOut.writeLong(context.taskAttemptId())
dataOut.writeInt(context.cpus())
val resources = context.resources()
dataOut.writeInt(resources.size)
resources.foreach { case (k, v) =>
PythonRDD.writeUTF(k, dataOut)
PythonRDD.writeUTF(v.name, dataOut)
dataOut.writeInt(v.addresses.size)
v.addresses.foreach { case addr =>
PythonRDD.writeUTF(addr, dataOut)
}
}
val localProps = context.getLocalProperties.asScala
dataOut.writeInt(localProps.size)
localProps.foreach { case (k, v) =>
PythonRDD.writeUTF(k, dataOut)
PythonRDD.writeUTF(v, dataOut)
}
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.size)
for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
val addedBids = newBids.diff(oldBids)
val cnt = toRemove.size + addedBids.size
val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
dataOut.writeBoolean(needsDecryptionServer)
dataOut.writeInt(cnt)
def sendBidsToRemove(): Unit = {
for (bid <- toRemove) {
// remove the broadcast from worker
dataOut.writeLong(-bid - 1) // bid >= 0
oldBids.remove(bid)
}
}
if (needsDecryptionServer) {
// if there is encryption, we setup a server which reads the encrypted files, and sends
// the decrypted data to python
val idsAndFiles = broadcastVars.flatMap { broadcast =>
if (!oldBids.contains(broadcast.id)) {
Some((broadcast.id, broadcast.value.path))
} else {
None
}
}
val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
dataOut.writeInt(server.port)
logTrace(s"broadcast decryption server setup on ${server.port}")
PythonRDD.writeUTF(server.secret, dataOut)
sendBidsToRemove()
idsAndFiles.foreach { case (id, _) =>
// send new broadcast
dataOut.writeLong(id)
oldBids.add(id)
}
dataOut.flush()
logTrace("waiting for python to read decrypted broadcast data from server")
server.waitTillBroadcastDataSent()
logTrace("done sending decrypted data to python")
} else {
sendBidsToRemove()
for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
PythonRDD.writeUTF(broadcast.value.path, dataOut)
oldBids.add(broadcast.id)
}
}
}
dataOut.flush()
dataOut.writeInt(evalType)
writeCommand(dataOut)
writeIteratorToStream(dataOut)
dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
} catch {
case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) =>
if (context.isCompleted || context.isInterrupted) {
logDebug("Exception/NonFatal Error thrown after task completion (likely due to " +
"cleanup)", t)
if (!worker.isClosed) {
Utils.tryLog(worker.shutdownOutput())
}
} else {
// We must avoid throwing exceptions/NonFatals here, because the thread uncaught
// exception handler will kill the whole executor (see
// org.apache.spark.executor.Executor).
_exception = t
if (!worker.isClosed) {
Utils.tryLog(worker.shutdownOutput())
}
}
}
}
/**
* Gateway to call BarrierTaskContext methods.
*/
def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = {
require(
serverSocket.isDefined,
"No available ServerSocket to redirect the BarrierTaskContext method call."
)
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
try {
val messages = requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
context.asInstanceOf[BarrierTaskContext].barrier()
Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS)
case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
context.asInstanceOf[BarrierTaskContext].allGather(message)
}
out.writeInt(messages.length)
messages.foreach(writeUTF(_, out))
} catch {
case e: SparkException =>
writeUTF(e.getMessage, out)
} finally {
out.close()
}
}
def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
val bytes = str.getBytes(UTF_8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
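// Framing example: writeUTF("abc", out) emits the 4-byte big-endian length 3
// followed by the UTF-8 bytes 0x61 0x62 0x63; the Python side is expected to read
// the length first and then exactly that many bytes.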
}
abstract class ReaderIterator(
stream: DataInputStream,
writerThread: WriterThread,
startTime: Long,
env: SparkEnv,
worker: Socket,
pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext)
extends Iterator[OUT] {
private var nextObj: OUT = _
private var eos = false
override def hasNext: Boolean = nextObj != null || {
if (!eos) {
nextObj = read()
hasNext
} else {
false
}
}
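// Note that read() returns null once it has consumed the end-of-data section (which
// sets eos), so the recursive hasNext call above re-evaluates nextObj/eos instead of
// yielding a null element.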
override def next(): OUT = {
if (hasNext) {
val obj = nextObj
nextObj = null.asInstanceOf[OUT]
obj
} else {
Iterator.empty.next()
}
}
/**
* Reads the next object from the stream. When the stream reaches the end of the
* data section, this method must consume the remaining sections and then return null.
*/
protected def read(): OUT
protected def handleTimingData(): Unit = {
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
val finishTime = stream.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
}
protected def handlePythonException(): PythonException = {
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
new PythonException(new String(obj, StandardCharsets.UTF_8),
writerThread.exception.orNull)
}
protected def handleEndOfDataSection(): Unit = {
// We've finished the data section of the output, but we can still
// read some accumulator updates:
val numAccumulatorUpdates = stream.readInt()
(1 to numAccumulatorUpdates).foreach { _ =>
val updateLen = stream.readInt()
val update = new Array[Byte](updateLen)
stream.readFully(update)
maybeAccumulator.foreach(_.add(update))
}
// Check whether the worker is ready to be re-used.
if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) {
env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
}
}
eos = true
}
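// So at the end of the data the worker's output stream looks like:
// END_OF_DATA_SECTION, an int count of accumulator updates, that many
// length-prefixed update payloads, and then END_OF_STREAM, which also signals that
// the worker may be returned to the reuse pool.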
protected val handleException: PartialFunction[Throwable, OUT] = {
case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
logError("This may have been caused by a prior exception:", writerThread.exception.get)
throw writerThread.exception.get
case eof: EOFException if faultHandlerEnabled && pid.isDefined &&
JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) =>
val path = BasePythonRunner.faultHandlerLogPath(pid.get)
val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n"
JavaFiles.deleteIfExists(path)
throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof)
case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
}
/**
* It is necessary to have a monitor thread for python workers if the user cancels with
* interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
* threads can block indefinitely.
*/
class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
extends Thread(s"Worker Monitor for $pythonExec") {
/** How long to wait before killing the python worker if a task cannot be interrupted. */
private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT)
setDaemon(true)
private def monitorWorker(): Unit = {
// Kill the worker if it is interrupted, checking until task completion.
// TODO: This has a race condition if interruption occurs, as completed may still become true.
while (!context.isInterrupted && !context.isCompleted) {
Thread.sleep(2000)
}
if (!context.isCompleted) {
Thread.sleep(taskKillTimeout)
if (!context.isCompleted) {
try {
// Mimic the task name used in `Executor` to help the user find out the task to blame.
val taskName = s"${context.partitionId}.${context.attemptNumber} " +
s"in stage ${context.stageId} (TID ${context.taskAttemptId})"
logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
}
}
}
}
override def run(): Unit = {
try {
monitorWorker()
} finally {
if (reuseWorker) {
val key = (worker, context.taskAttemptId)
PythonRunner.runningMonitorThreads.remove(key)
}
}
}
}
/**
* This thread monitors the WriterThread and kills it in case of deadlock.
*
* A deadlock can arise if the task completes while the writer thread is sending input to the
* Python process (e.g. due to the use of `take()`), and the Python process is still producing
* output. When the inputs are sufficiently large, this can result in a deadlock due to the use of
* blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the socket.
*/
class WriterMonitorThread(
env: SparkEnv, worker: Socket, writerThread: WriterThread, context: TaskContext)
extends Thread(s"Writer Monitor for $pythonExec (writer thread id ${writerThread.getId})") {
/**
* How long to wait before closing the socket if the writer thread has not exited after the task
* ends.
*/
private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT)
setDaemon(true)
override def run(): Unit = {
// Wait until the task is completed (or the writer thread exits, in which case this thread has
// nothing to do).
while (!context.isCompleted && writerThread.isAlive) {
Thread.sleep(2000)
}
if (writerThread.isAlive) {
Thread.sleep(taskKillTimeout)
// If the writer thread continues running, this indicates a deadlock. Kill the worker to
// resolve the deadlock.
if (writerThread.isAlive) {
try {
// Mimic the task name used in `Executor` to help the user find out the task to blame.
val taskName = s"${context.partitionId}.${context.attemptNumber} " +
s"in stage ${context.stageId} (TID ${context.taskAttemptId})"
logWarning(
s"Detected deadlock while completing task $taskName: " +
"Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
}
}
}
}
}
}
private[spark] object PythonRunner {
// worker monitor threads that are already running, keyed by (worker socket, task attempt ID)
val runningMonitorThreads = ConcurrentHashMap.newKeySet[(Socket, Long)]()
def apply(func: PythonFunction): PythonRunner = {
new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))))
}
}
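// A minimal usage sketch (illustrative; `func` stands for a PythonFunction built by
// the PySpark driver):
//   PythonRunner(func).compute(inputIterator, partitionIndex, context)
// returns an iterator of serialized result byte arrays for the partition.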
/**
* A helper class to run Python mapPartition in Spark.
*/
private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
extends BasePythonRunner[Array[Byte], Array[Byte]](
funcs, PythonEvalType.NON_UDF, Array(Array(0))) {
protected override def newWriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[Array[Byte]],
partitionIndex: Int,
context: TaskContext): WriterThread = {
new WriterThread(env, worker, inputIterator, partitionIndex, context) {
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command.toArray)
}
protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
}
}
}
protected override def newReaderIterator(
stream: DataInputStream,
writerThread: WriterThread,
startTime: Long,
env: SparkEnv,
worker: Socket,
pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
new ReaderIterator(
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
protected override def read(): Array[Byte] = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
try {
stream.readInt() match {
case length if length > 0 =>
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
case 0 => Array.emptyByteArray
case SpecialLengths.TIMING_DATA =>
handleTimingData()
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
throw handlePythonException()
case SpecialLengths.END_OF_DATA_SECTION =>
handleEndOfDataSection()
null
}
} catch handleException
}
}
}
}
private[spark] object SpecialLengths {
val END_OF_DATA_SECTION = -1
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
val END_OF_STREAM = -4
val NULL = -5
val START_ARROW_STREAM = -6
}
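// All of these markers are negative so a reader can distinguish them from real
// payload lengths, which are always >= 0 (see the read() match in newReaderIterator
// above).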
private[spark] object BarrierTaskContextMessageProtocol {
val BARRIER_FUNCTION = 1
val ALL_GATHER_FUNCTION = 2
val BARRIER_RESULT_SUCCESS = "success"
val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side."
}