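Preface
These are my study notes on the RPC layer in the Spark source code. This article is based on the Spark 3 source.
Three important classes
RpcEnv, RpcEndpoint, and RpcEndpointRef. The main goal is to understand how these three relate to one another.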
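As a rough mental model of how the three roles fit together, here is a minimal sketch in plain Scala. ToyEndpoint, ToyEndpointRef, and ToyRpcEnv are hypothetical stand-ins invented for illustration, not Spark's real classes: the env owns a registry of named endpoints, the endpoint holds the message-handling logic, and the ref is a lightweight handle that callers use to send messages through the env.

```scala
import scala.collection.mutable

// Toy stand-ins for the three roles; these are NOT Spark's real classes.
trait ToyEndpoint {
  def receive: PartialFunction[Any, Unit]
}

// A reference only knows the endpoint's name and the env that can route to it.
final class ToyEndpointRef(val name: String, env: ToyRpcEnv) {
  def send(message: Any): Unit = env.deliver(name, message)
}

// The env owns the name -> endpoint registry and does the routing.
final class ToyRpcEnv {
  private val endpoints = mutable.Map.empty[String, ToyEndpoint]

  def setupEndpoint(name: String, endpoint: ToyEndpoint): ToyEndpointRef = {
    require(!endpoints.contains(name), s"There is already an endpoint called $name")
    endpoints(name) = endpoint
    new ToyEndpointRef(name, this)
  }

  // Messages that the endpoint's receive does not handle are dropped in this toy.
  def deliver(name: String, message: Any): Unit =
    endpoints.get(name).foreach(_.receive.applyOrElse[Any, Unit](message, _ => ()))
}

object ToyRpcDemo {
  def run(): List[String] = {
    val received = mutable.ListBuffer.empty[String]
    val env = new ToyRpcEnv
    val ref = env.setupEndpoint("hello", new ToyEndpoint {
      override def receive: PartialFunction[Any, Unit] = {
        case msg: String => received += msg
      }
    })
    ref.send("hello")
    ref.send(42) // not a String, so receive does not handle it
    received.toList
  }

  def main(args: Array[String]): Unit = println(run())
}
```

The sketch delivers messages synchronously on the caller's thread, which the real implementation does not do; it is only meant to show who holds what.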
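Subclasses
In Spark 1.6.3, Netty was already the default underlying implementation, but the Akka dependency was still present. In Spark 2.1.0 the only underlying implementation is Netty, which leaves users free to use whatever version of Akka they like, or some better underlying implementation in the future.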
NettyRpcEnv
private[netty] class NettyRpcEnv(
    val conf: SparkConf,
    javaSerializerInstance: JavaSerializerInstance,
    host: String,
    securityManager: SecurityManager,
    numUsableCores: Int) extends RpcEnv(conf) with Logging {
NettyRpcEndpointRef
private[netty] class NettyRpcEndpointRef(
    @transient private val conf: SparkConf,
    private val endpointAddress: RpcEndpointAddress,
    @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {
RpcEndpoint has two sub-traits: ThreadSafeRpcEndpoint and IsolatedRpcEndpoint.
The real implementations live in concrete inner classes, for example
DriverEndpoint inside CoarseGrainedSchedulerBackend:
class DriverEndpoint extends IsolatedRpcEndpoint with Logging {
  override val rpcEnv: RpcEnv = CoarseGrainedSchedulerBackend.this.rpcEnv
  protected val addressToExecutorId = new HashMap[RpcAddress, String]
  // Spark configuration sent to executors. This is a lazy val so that subclasses of the
  // scheduler can modify the SparkConf object before this view is created.
  private lazy val sparkProperties = scheduler.sc.conf.getAll
    .filter { case (k, _) => k.startsWith("spark.") }
    .toSeq
  private val logUrlHandler: ExecutorLogUrlHandler = new ExecutorLogUrlHandler(
    conf.get(UI.CUSTOM_EXECUTOR_LOG_URL))
  override def onStart(): Unit = {
  ......
It can also be an anonymous inner class, as in the test suite RpcEnvSuite:
val rpcEndpointRef = env.setupEndpoint("send-locally", new RpcEndpoint {
  override val rpcEnv = env
  override def receive = {
    case msg: String => {
      println(msg)
      message = msg
    }
  }
})
Registration flow
Create the RpcEndpoint (subclass)
new DriverEndpoint()
Create the RpcEnv
The RpcEnv is actually created by the create method of NettyRpcEnvFactory, which returns a NettyRpcEnv.
During SparkContext initialization:
Call order: createSparkEnv -> SparkEnv.createDriverEnv -> create -> RpcEnv.create
// Create the Spark execution environment (cache, map output tracker, etc)
......
The RpcEnv.create method:
def create(
    name: String,
    bindAddress: String,
    advertiseAddress: String,
    port: Int,
    conf: SparkConf,
    securityManager: SecurityManager,
    numUsableCores: Int,
    clientMode: Boolean): RpcEnv = {
  val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
    numUsableCores, clientMode)
  new NettyRpcEnvFactory().create(config)
}
What actually runs is the create method of NettyRpcEnvFactory, which returns a NettyRpcEnv:
def create(config: RpcEnvConfig): RpcEnv = {
......
Below is how the test class RpcEnvSuite creates its RpcEnv:
override def createRpcEnv(
......
Register the RpcEndpoint with the RpcEnv and get back an RpcEndpointRef
rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint())
This actually calls NettyRpcEnv's setupEndpoint method:
override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
  dispatcher.registerRpcEndpoint(name, endpoint)
}
That in turn calls dispatcher.registerRpcEndpoint, which creates and returns a NettyRpcEndpointRef:
def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
  val addr = RpcEndpointAddress(nettyEnv.address, name)
  val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
  synchronized {
    if (stopped) {
      throw new IllegalStateException("RpcEnv has been stopped")
    }
    if (endpoints.containsKey(name)) {
      throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
    }
    // This must be done before assigning RpcEndpoint to MessageLoop, as MessageLoop sets Inbox be
    // active when registering, and endpointRef must be put into endpointRefs before onStart is
    // called.
    endpointRefs.put(endpoint, endpointRef)
    var messageLoop: MessageLoop = null
    try {
      messageLoop = endpoint match {
        case e: IsolatedRpcEndpoint =>
          new DedicatedMessageLoop(name, e, this)
        case _ =>
          sharedLoop.register(name, endpoint)
          sharedLoop
      }
      endpoints.put(name, messageLoop)
    } catch {
      case NonFatal(e) =>
        endpointRefs.remove(endpoint)
        throw e
    }
  }
  endpointRef
}
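The two guards and the loop selection in registerRpcEndpoint can be sketched in isolation. This is a simplified model under stated assumptions: PlainToyEndpoint, IsolatedToyEndpoint, ToyLoop, and ToyDispatcher are hypothetical names, and the "loops" here are plain values rather than thread pools. It only shows that registration is serialized, that duplicate names and a stopped dispatcher are rejected, and that a marker trait decides which kind of loop an endpoint gets.

```scala
import scala.collection.mutable

// Hypothetical marker traits, standing in for RpcEndpoint / IsolatedRpcEndpoint.
trait PlainToyEndpoint
trait IsolatedToyEndpoint extends PlainToyEndpoint

// Hypothetical loop descriptions; real message loops own threads and inboxes.
sealed trait ToyLoop
final case class DedicatedToyLoop(owner: String) extends ToyLoop
case object SharedToyLoop extends ToyLoop

final class ToyDispatcher {
  private val endpoints = mutable.Map.empty[String, ToyLoop]
  private var stopped = false

  def register(name: String, endpoint: PlainToyEndpoint): ToyLoop = synchronized {
    if (stopped) {
      throw new IllegalStateException("RpcEnv has been stopped")
    }
    if (endpoints.contains(name)) {
      throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
    }
    // The more specific trait must be matched first; every isolated endpoint
    // is also a plain endpoint.
    val loop = endpoint match {
      case _: IsolatedToyEndpoint => DedicatedToyLoop(name)
      case _ => SharedToyLoop
    }
    endpoints(name) = loop
    loop
  }

  def stop(): Unit = synchronized { stopped = true }
}
```

Registering an endpoint that mixes in IsolatedToyEndpoint yields its own DedicatedToyLoop, while every other endpoint ends up on the single SharedToyLoop.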
RpcEndpointRef.send
This actually calls NettyRpcEndpointRef.send:
override def send(message: Any): Unit = {
  require(message != null, "Message is null")
  nettyEnv.send(new RequestMessage(nettyEnv.address, this, message))
}
That then calls NettyRpcEnv.send:
private[netty] def send(message: RequestMessage): Unit = {
  val remoteAddr = message.receiver.address
  if (remoteAddr == address) {
    // Message to a local RPC endpoint.
    try {
      dispatcher.postOneWayMessage(message)
    } catch {
      case e: RpcEnvStoppedException => logDebug(e.getMessage)
    }
  } else {
    // Message to a remote RPC endpoint.
    postToOutbox(message.receiver, OneWayOutboxMessage(message.serialize(this)))
  }
}
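The branch in send comes down to comparing the receiver's address with the env's own address. A minimal sketch of just that decision, where ToyAddress, ToyRoute, and ToyRouter are hypothetical names used only for illustration:

```scala
// Hypothetical address type; a case class gives structural equality for free.
final case class ToyAddress(host: String, port: Int)

sealed trait ToyRoute
case object ToLocalInbox extends ToyRoute                         // hand to the dispatcher directly
final case class ToRemoteOutbox(target: ToyAddress) extends ToyRoute // serialize, then ship

object ToyRouter {
  def route(self: ToyAddress, receiver: ToyAddress): ToyRoute =
    if (receiver == self) ToLocalInbox else ToRemoteOutbox(receiver)
}
```

Only the remote branch needs to serialize the message, which matches the send code above: the local branch passes the message object to the dispatcher as is.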
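The address decides whether this is a local or a remote call. The rest of this walkthrough follows the local path, which calls the Dispatcher's postOneWayMessage method; that method in turn goes through postMessage:
private def postMessage(
......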
The loop used here is the one chosen in registerRpcEndpoint above, which picks one of two kinds depending on whether the endpoint is a subclass of IsolatedRpcEndpoint:
var messageLoop: MessageLoop = null
try {
  messageLoop = endpoint match {
    case e: IsolatedRpcEndpoint =>
      new DedicatedMessageLoop(name, e, this)
    case _ =>
      sharedLoop.register(name, endpoint)
      sharedLoop
  }
  endpoints.put(name, messageLoop)
Let's follow the sharedLoop case first, where
private lazy val sharedLoop = new SharedMessageLoop(nettyEnv.conf, this, numUsableCores)
loop.post:
override def post(endpointName: String, message: InboxMessage): Unit = {
  val inbox = endpoints.get(endpointName)
  inbox.post(message)
  setActive(inbox)
}
This puts an InboxMessage into the inbox and marks that inbox as active.
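To make "post, then mark active" concrete, here is a single-threaded sketch. It assumes that marking an inbox active means offering it to a queue of inboxes that worker threads drain; ToyInbox, ToySharedLoop, and drain are hypothetical names, and drain performs on the calling thread what a pool of worker threads would otherwise do.

```scala
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
import scala.collection.mutable

// Hypothetical inbox: just a guarded FIFO of messages for one endpoint.
final class ToyInbox(val name: String) {
  private val messages = new java.util.LinkedList[String]()
  def post(m: String): Unit = synchronized { messages.add(m) }
  def poll(): Option[String] = synchronized { Option(messages.poll()) }
}

final class ToySharedLoop {
  private val endpoints = new ConcurrentHashMap[String, ToyInbox]()
  // Inboxes that may have pending messages; an inbox can appear more than once.
  private val active = new LinkedBlockingQueue[ToyInbox]()

  def register(name: String): Unit = endpoints.put(name, new ToyInbox(name))

  def post(endpointName: String, message: String): Unit = {
    val inbox = endpoints.get(endpointName)
    inbox.post(message)
    active.offer(inbox) // "set active"
  }

  // Single-threaded stand-in for the worker threads: take each active inbox
  // and process everything it currently holds.
  def drain(): List[(String, String)] = {
    val out = mutable.ListBuffer.empty[(String, String)]
    var inbox = active.poll()
    while (inbox != null) {
      var msg = inbox.poll()
      while (msg.isDefined) {
        out += ((inbox.name, msg.get))
        msg = inbox.poll()
      }
      inbox = active.poll()
    }
    out.toList
  }
}
```

Posting m1 to a, m2 to b, and m3 to a and then draining yields (a,m1), (a,m3), (b,m2): the first visit to a empties it, so its second appearance in the active queue finds nothing to do. Messages stay ordered per inbox, not globally.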
receive
In SharedMessageLoop:
override protected val threadpool: ThreadPoolExecutor = {
  val numThreads = getNumOfThreads(conf)
  val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
  for (i <- 0 until numThreads) {
    pool.execute(receiveLoopRunnable)
  }
  pool
}
protected val receiveLoopRunnable = new Runnable() {
......
private def receiveLoop(): Unit = {
......
The receive loop calls inbox.process over and over:
/**
 * Process stored messages.
 */
def process(dispatcher: Dispatcher): Unit = {
  var message: InboxMessage = null
  inbox.synchronized {
    if (!enableConcurrent && numActiveThreads != 0) {
      return
    }
    message = messages.poll()
    if (message != null) {
      numActiveThreads += 1
    } else {
      return
    }
  }
  while (true) {
    safelyCall(endpoint) {
      message match {
        case RpcMessage(_sender, content, context) =>
          try {
            endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })
          } catch {
            case e: Throwable =>
              context.sendFailure(e)
              // Throw the exception -- this exception will be caught by the safelyCall function.
              // The endpoint's onError function will be called.
              throw e
          }
        case OneWayMessage(_sender, content) =>
          endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
            throw new SparkException(s"Unsupported message $message from ${_sender}")
          })
        case OnStart =>
          endpoint.onStart()
          if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
            inbox.synchronized {
              if (!stopped) {
                enableConcurrent = true
              }
            }
          }
        case OnStop =>
          val activeThreads = inbox.synchronized { inbox.numActiveThreads }
          assert(activeThreads == 1,
            s"There should be only a single active thread but found $activeThreads threads.")
          dispatcher.removeRpcEndpointRef(endpoint)
          endpoint.onStop()
          assert(isEmpty, "OnStop should be the last message")
        case RemoteProcessConnected(remoteAddress) =>
          endpoint.onConnected(remoteAddress)
        case RemoteProcessDisconnected(remoteAddress) =>
          endpoint.onDisconnected(remoteAddress)
        case RemoteProcessConnectionError(cause, remoteAddress) =>
          endpoint.onNetworkError(cause, remoteAddress)
      }
    }
    inbox.synchronized {
      // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
      // every time.
      if (!enableConcurrent && numActiveThreads != 1) {
        // If we are not the only one worker, exit
        numActiveThreads -= 1
        return
      }
      message = messages.poll()
      if (message == null) {
        numActiveThreads -= 1
        return
      }
    }
  }
}
inbox.process matches on the message type and invokes the corresponding endpoint method. Here the message type is OneWayMessage, so endpoint.receive is what runs:
case OneWayMessage(_sender, content) =>
  endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
    throw new SparkException(s"Unsupported message $message from ${_sender}")
  })
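The applyOrElse call is what turns "receive has no case for this message" into an explicit error. A small standalone illustration using only the Scala standard library; ApplyOrElseDemo and its handle method are hypothetical, and IllegalArgumentException stands in for SparkException so the snippet runs without Spark:

```scala
object ApplyOrElseDemo {
  // A partial function, like an endpoint's receive: defined only for some inputs.
  val receive: PartialFunction[Any, String] = {
    case s: String => s"got string: $s"
    case n: Int => s"got int: $n"
  }

  // applyOrElse runs the matching case if there is one; otherwise it calls the
  // fallback, which throws here just as the OneWayMessage branch does.
  def handle(content: Any): Either[String, String] =
    try {
      Right(receive.applyOrElse[Any, String](content, { msg =>
        throw new IllegalArgumentException(s"Unsupported message $msg")
      }))
    } catch {
      case e: IllegalArgumentException => Left(e.getMessage)
    }

  def main(args: Array[String]): Unit = {
    println(handle("hi"))
    println(handle(3.14))
  }
}
```

Compared with calling isDefinedAt and then apply, applyOrElse evaluates the pattern match only once.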
Finally, a look at DedicatedMessageLoop. It is much the same as SharedMessageLoop:
private val inbox = new Inbox(name, endpoint)
// Mark active to handle the OnStart message.
setActive(inbox)
override def post(endpointName: String, message: InboxMessage): Unit = {
  require(endpointName == name)
  inbox.post(message)
  setActive(inbox)
}
It likewise has a thread pool that runs receiveLoopRunnable; from there on everything works the same as in SharedMessageLoop:
override protected val threadpool = if (endpoint.threadCount() > 1) {
  ThreadUtils.newDaemonCachedThreadPool(s"dispatcher-$name", endpoint.threadCount())
} else {
  ThreadUtils.newDaemonSingleThreadExecutor(s"dispatcher-$name")
}
(1 to endpoint.threadCount()).foreach { _ =>
  threadpool.submit(receiveLoopRunnable)
}
onStart
The onStart method deserves a mention, because the comment on RpcEndpoint says:
 * An end point for the RPC that defines what functions to trigger given a message.
 *
 * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence.
 *
 * The life-cycle of an endpoint is:
 *
 * { constructor -> onStart -> receive* -> onStop}
......
/**
 * Invoked before [[RpcEndpoint]] starts to handle any message.
 */
def onStart(): Unit = {
  // By default, do nothing.
}
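In other words, the life cycle of an RpcEndpoint is: construct -> onStart -> receive -> onStop, so onStart runs before any receive. But nothing in the RpcEndpoint class itself shows how onStart gets to run first, and the blog posts I found online do not explain it clearly either. With that question in mind I debugged and read the source carefully, and this is what I found: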
When the inbox above is initialized, the following code runs first:
@GuardedBy("this")
protected val messages = new java.util.LinkedList[InboxMessage]()
// OnStart should be the first message to process
inbox.synchronized {
  messages.add(OnStart)
}
So in the process method shown earlier, the first message polled from messages is OnStart, and when the pattern match hits OnStart, endpoint.onStart() is executed:
message = messages.poll()
......
case OnStart =>
  endpoint.onStart()
  if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
    inbox.synchronized {
      if (!stopped) {
        enableConcurrent = true
      }
    }
  }
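The ordering guarantee follows purely from queue order: OnStart is enqueued in the inbox constructor, so it sits ahead of anything posted later. A minimal sketch of that idea, with hypothetical names throughout (ToyMessage, LifecycleEndpoint, LifecycleInbox) and processing done on the calling thread instead of a message loop:

```scala
import scala.collection.mutable

sealed trait ToyMessage
case object ToyOnStart extends ToyMessage
final case class ToyOneWay(content: Any) extends ToyMessage

trait LifecycleEndpoint {
  def onStart(): Unit = ()
  def receive: PartialFunction[Any, Unit]
}

final class LifecycleInbox(endpoint: LifecycleEndpoint) {
  private val messages = new java.util.LinkedList[ToyMessage]()

  // Runs as part of construction, so OnStart is always at the head of the queue.
  synchronized { messages.add(ToyOnStart) }

  def post(m: ToyMessage): Unit = synchronized { messages.add(m) }

  // Drain the queue in FIFO order on the calling thread.
  def process(): Unit = {
    var message = synchronized { messages.poll() }
    while (message != null) {
      message match {
        case ToyOnStart => endpoint.onStart()
        case ToyOneWay(content) => endpoint.receive.applyOrElse[Any, Unit](content, _ => ())
      }
      message = synchronized { messages.poll() }
    }
  }
}

object LifecycleDemo {
  def run(): List[String] = {
    val events = mutable.ListBuffer.empty[String]
    val inbox = new LifecycleInbox(new LifecycleEndpoint {
      override def onStart(): Unit = events += "start hello endpoint"
      override def receive: PartialFunction[Any, Unit] = {
        case msg: String => events += msg
      }
    })
    inbox.post(ToyOneWay("hello")) // posted before any processing has happened
    inbox.process()
    events.toList
  }

  def main(args: Array[String]): Unit = run().foreach(println)
}
```

Even though "hello" is posted before process is ever called, onStart still runs first, because its message was queued when the inbox was built.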
DriverEndpoint in CoarseGrainedSchedulerBackend overrides onStart to do some preparatory work up front:
override def onStart(): Unit = {
  // Periodically revive offers to allow delay scheduling to work
  // The interval at which the scheduler revives worker resource offers in order to run tasks.
  val reviveIntervalMs = conf.get(SCHEDULER_REVIVE_INTERVAL).getOrElse(1000L)
  // Every second by default, send a ReviveOffers message to itself; it is handled by receive.
  reviveThread.scheduleAtFixedRate(() => Utils.tryLogNonFatalError {
    Option(self).foreach(_.send(ReviveOffers))
  }, 0, reviveIntervalMs, TimeUnit.MILLISECONDS)
}
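The periodic self-message in this onStart is built on the JDK's ScheduledExecutorService.scheduleAtFixedRate. A standalone sketch of the same pattern, where the "send to self" is reduced to incrementing a counter; ReviveTimerDemo and countTicks are hypothetical names, and the latch exists only so the demo can wait deterministically for a given number of firings:

```scala
import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

object ReviveTimerDemo {
  // Fires a task at a fixed rate and returns how many times it ran by the time
  // `ticks` firings were observed, or -1 if that did not happen within 5 seconds.
  def countTicks(ticks: Int, intervalMs: Long): Int = {
    val scheduler = Executors.newSingleThreadScheduledExecutor()
    val fired = new AtomicInteger(0)
    val latch = new CountDownLatch(ticks)
    val task = new Runnable {
      // Stand-in for "send ReviveOffers to myself".
      override def run(): Unit = {
        fired.incrementAndGet()
        latch.countDown()
      }
    }
    // Initial delay 0: the first firing happens right away, as in onStart above.
    scheduler.scheduleAtFixedRate(task, 0, intervalMs, TimeUnit.MILLISECONDS)
    val done = latch.await(5, TimeUnit.SECONDS)
    scheduler.shutdownNow()
    if (done) fired.get() else -1
  }

  def main(args: Array[String]): Unit = println(countTicks(3, 100L))
}
```

Note that scheduleAtFixedRate suppresses all later executions if one run throws, which is presumably why the Spark code wraps the body in Utils.tryLogNonFatalError.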
Test demo
import org.apache.spark.{SecurityManager, SparkConf, SparkEnv}
......
Output:
start hello endpoint
hello