spark Inbox 源码
spark Inbox 代码
文件路径:/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.rpc.netty
import javax.annotation.concurrent.GuardedBy
import scala.util.control.NonFatal
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}
/**
 * Base type of every message that can be queued in an [[Inbox]]. The trait is sealed, so the
 * variants below are the complete set that `Inbox.process` has to handle.
 */
private[netty] sealed trait InboxMessage

/**
 * A message from `senderAddress` that carries no reply context. `Inbox.process` hands its
 * `content` to the endpoint's `receive` partial function.
 */
private[netty] case class OneWayMessage(senderAddress: RpcAddress, content: Any)
  extends InboxMessage

/**
 * A message from `senderAddress` that comes with a [[NettyRpcCallContext]]. `Inbox.process`
 * hands its `content` to the endpoint's `receiveAndReply(context)` partial function and reports
 * any failure through `context.sendFailure`.
 */
private[netty] case class RpcMessage(
    senderAddress: RpcAddress,
    content: Any,
    context: NettyRpcCallContext)
  extends InboxMessage

/** Queued by [[Inbox]] at construction time; processing it invokes the endpoint's `onStart`. */
private[netty] case object OnStart extends InboxMessage

/** Queued by `Inbox.stop()`; processing it invokes the endpoint's `onStop`. */
private[netty] case object OnStop extends InboxMessage

/** A message to tell all endpoints that a remote process has connected. */
private[netty] case class RemoteProcessConnected(remoteAddress: RpcAddress)
  extends InboxMessage

/** A message to tell all endpoints that a remote process has disconnected. */
private[netty] case class RemoteProcessDisconnected(remoteAddress: RpcAddress)
  extends InboxMessage

/** A message to tell all endpoints that a network error has happened. */
private[netty] case class RemoteProcessConnectionError(
    cause: Throwable,
    remoteAddress: RpcAddress)
  extends InboxMessage
/**
 * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
 *
 * All mutable state (`messages`, `stopped`, `enableConcurrent`, `numActiveThreads`) is guarded by
 * this object's monitor. Endpoint callbacks are always invoked outside of that monitor.
 *
 * @param endpointName name of the endpoint, used in log messages
 * @param endpoint the endpoint that receives the messages stored in this inbox
 */
private[netty] class Inbox(val endpointName: String, val endpoint: RpcEndpoint)
  extends Logging {

  inbox =>  // Give this an alias so we can use it more clearly in closures.

  @GuardedBy("this")
  protected val messages = new java.util.LinkedList[InboxMessage]()

  /** True if the inbox (and its associated endpoint) is stopped. */
  @GuardedBy("this")
  private var stopped = false

  /** Allow multiple threads to process messages at the same time. */
  @GuardedBy("this")
  private var enableConcurrent = false

  /** The number of threads processing messages for this inbox. */
  @GuardedBy("this")
  private var numActiveThreads = 0

  // OnStart should be the first message to process
  inbox.synchronized {
    messages.add(OnStart)
  }

  /**
   * Process stored messages.
   *
   * The calling thread drains the queue until it is empty. It returns immediately if another
   * thread is already processing and concurrent processing is not enabled.
   *
   * @param dispatcher used to remove the endpoint's reference when `OnStop` is processed
   */
  def process(dispatcher: Dispatcher): Unit = {
    var message: InboxMessage = null
    inbox.synchronized {
      if (!enableConcurrent && numActiveThreads != 0) {
        return
      }
      message = messages.poll()
      if (message != null) {
        numActiveThreads += 1
      } else {
        return
      }
    }
    while (true) {
      safelyCall(endpoint) {
        message match {
          case RpcMessage(_sender, content, context) =>
            try {
              endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
                throw new SparkException(s"Unsupported message $message from ${_sender}")
              })
            } catch {
              // Catching Throwable is deliberate: the caller must be told about every failure
              // before the exception is re-thrown.
              case e: Throwable =>
                context.sendFailure(e)
                // Throw the exception -- this exception will be caught by the safelyCall function.
                // The endpoint's onError function will be called.
                throw e
            }

          case OneWayMessage(_sender, content) =>
            endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })

          case OnStart =>
            endpoint.onStart()
            // Only endpoints that are not ThreadSafeRpcEndpoint may be processed concurrently,
            // and only once onStart has finished and the inbox has not been stopped meanwhile.
            if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
              inbox.synchronized {
                if (!stopped) {
                  enableConcurrent = true
                }
              }
            }

          case OnStop =>
            val activeThreads = inbox.synchronized { inbox.numActiveThreads }
            assert(activeThreads == 1,
              s"There should be only a single active thread but found $activeThreads threads.")
            dispatcher.removeRpcEndpointRef(endpoint)
            endpoint.onStop()
            assert(isEmpty, "OnStop should be the last message")

          case RemoteProcessConnected(remoteAddress) =>
            endpoint.onConnected(remoteAddress)

          case RemoteProcessDisconnected(remoteAddress) =>
            endpoint.onDisconnected(remoteAddress)

          case RemoteProcessConnectionError(cause, remoteAddress) =>
            endpoint.onNetworkError(cause, remoteAddress)
        }
      }

      inbox.synchronized {
        // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
        // every time.
        if (!enableConcurrent && numActiveThreads != 1) {
          // If we are not the only one worker, exit
          numActiveThreads -= 1
          return
        }
        message = messages.poll()
        if (message == null) {
          numActiveThreads -= 1
          return
        }
      }
    }
  }

  /**
   * Enqueue a message for the endpoint. If the inbox is already stopped the message is not
   * stored and is passed to `onDrop` instead.
   *
   * @param message the message to enqueue
   */
  def post(message: InboxMessage): Unit = inbox.synchronized {
    if (stopped) {
      // We already put "OnStop" into "messages", so we should drop further messages
      onDrop(message)
    } else {
      messages.add(message)
    }
  }

  /**
   * Mark the inbox as stopped and enqueue `OnStop`. Calling it again after the first call has no
   * further effect.
   */
  def stop(): Unit = inbox.synchronized {
    // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last
    // message
    if (!stopped) {
      // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only
      // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources
      // safely.
      enableConcurrent = false
      stopped = true
      messages.add(OnStop)
      // Note: The concurrent events in messages will be processed one by one.
    }
  }

  /** Returns true if no messages are currently queued. */
  def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }

  /**
   * Called when we are dropping a message. Test cases override this to test message dropping.
   * Exposed for testing.
   */
  protected def onDrop(message: InboxMessage): Unit = {
    logWarning(s"Drop $message because endpoint $endpointName is stopped")
  }

  /**
   * Calls action closure, and calls the endpoint's onError function in the case of exceptions.
   *
   * Non-fatal exceptions thrown by `onError` itself are logged and swallowed. Fatal throwables
   * from either `action` or `onError` decrement the active-thread count, are logged, and are
   * then re-thrown.
   */
  private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
    def dealWithFatalError(fatal: Throwable): Unit = {
      inbox.synchronized {
        assert(numActiveThreads > 0, "The number of active threads should be positive.")
        // Should reduce the number of active threads before throw the error.
        numActiveThreads -= 1
      }
      logError(s"An error happened while processing message in the inbox for $endpointName", fatal)
      throw fatal
    }

    try action catch {
      case NonFatal(e) =>
        try endpoint.onError(e) catch {
          case NonFatal(ee) =>
            // `stopped` is guarded by the inbox monitor, so read it under the lock to be sure we
            // observe a value written by `stop()` on another thread. This method is never called
            // while the monitor is held, so taking it here cannot deadlock.
            val alreadyStopped = inbox.synchronized { stopped }
            if (alreadyStopped) {
              logDebug("Ignoring error", ee)
            } else {
              logError("Ignoring error", ee)
            }
          case fatal: Throwable =>
            dealWithFatalError(fatal)
        }
      case fatal: Throwable =>
        dealWithFatalError(fatal)
    }
  }

  // exposed only for testing
  def getNumActiveThreads: Int = {
    inbox.synchronized {
      inbox.numActiveThreads
    }
  }
}
相关信息
相关文章
0
赞
- 所属分类: 前端技术
- 本文标签:
热门推荐
-
2、 - 优质文章
-
3、 gate.io
-
8、 golang
-
9、 openharmony
-
10、 Vue中input框自动聚焦