Skip to content

Commit

Permalink
Added connection pooling.
Browse files Browse the repository at this point in the history
  • Loading branch information
rxin authored and aarondav committed Oct 10, 2014
1 parent d23ed7b commit b2f3281
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin

private[this] val serverAddr = cf.channel().remoteAddress().toString

/** Whether the channel backing this client currently reports itself as active. */
def isActive: Boolean = {
  val channel = cf.channel()
  channel.isActive
}

/**
* Ask the remote server for a sequence of blocks, and execute the callback.
*
Expand All @@ -55,7 +57,7 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin
def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = {
var startTime: Long = 0
logTrace {
startTime = System.nanoTime
startTime = System.nanoTime()
s"Sending request $blockIds to $serverAddr"
}

Expand All @@ -67,7 +69,7 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin
override def operationComplete(future: ChannelFuture): Unit = {
if (future.isSuccess) {
logTrace {
val timeTaken = (System.nanoTime - startTime).toDouble / 1000000
val timeTaken = (System.nanoTime() - startTime).toDouble / 1000000
s"Sending request $blockIds to $serverAddr took $timeTaken ms"
}
} else {
Expand All @@ -84,9 +86,6 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin
})
}

/**
 * Wait on the close future of the underlying channel: the calling thread blocks in `sync()`
 * until that future completes.
 */
def waitForClose(): Unit = {
cf.channel().closeFuture().sync()
}

/**
 * Request that the underlying channel be closed. The call returns right away; it does NOT
 * block until the connection has actually been closed.
 */
def close(): Unit = {
  cf.channel().close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.network.netty

import java.util.concurrent.TimeoutException
import java.util.concurrent.{ConcurrentHashMap, TimeoutException}

import io.netty.bootstrap.Bootstrap
import io.netty.buffer.PooledByteBufAllocator
Expand All @@ -35,20 +35,26 @@ import org.apache.spark.util.Utils


/**
* Factory for creating [[BlockClient]] by using createClient. This factory reuses
* the worker thread pool for Netty.
* Factory for creating [[BlockClient]] by using createClient.
*
* The factory maintains a connection pool to other hosts and should return the same [[BlockClient]]
* for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s.
*/
private[netty]
class BlockClientFactory(val conf: NettyConfig) {

def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf))

/** A thread factory so the threads are named (for debugging). */
private[netty] val threadFactory = Utils.namedThreadFactory("spark-netty-client")
private[this] val threadFactory = Utils.namedThreadFactory("spark-netty-client")

/** Socket channel type, initialized by [[init]] depending on ioMode. */
private[this] var socketChannelClass: Class[_ <: Channel] = _

/** The following two are instantiated by the [[init]] method, depending on ioMode. */
private[netty] var socketChannelClass: Class[_ <: Channel] = _
private[netty] var workerGroup: EventLoopGroup = _
/** Thread pool shared by all clients. */
private[this] var workerGroup: EventLoopGroup = _

private[this] val connectionPool = new ConcurrentHashMap[(String, Int), BlockClient]

// The encoders are stateless and can be shared among multiple clients.
private[this] val encoder = new ClientRequestEncoder
Expand Down Expand Up @@ -88,6 +94,16 @@ class BlockClientFactory(val conf: NettyConfig) {
* Concurrency: This method is safe to call from multiple threads.
*/
def createClient(remoteHost: String, remotePort: Int): BlockClient = {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
val cachedClient = connectionPool.get((remoteHost, remotePort))
if (cachedClient != null && cachedClient.isActive) {
return cachedClient
}

// There is a chance two threads are creating two different clients connecting to the same host.
// But that's probably ok ...

val handler = new BlockClientHandler

val bootstrap = new Bootstrap
Expand Down Expand Up @@ -118,10 +134,20 @@ class BlockClientFactory(val conf: NettyConfig) {
s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)")
}

new BlockClient(cf, handler)
val client = new BlockClient(cf, handler)
connectionPool.put((remoteHost, remotePort), client)
client
}

/** Close all connections in the connection pool, and shutdown the worker thread pool. */
def stop(): Unit = {
val iter = connectionPool.entrySet().iterator()
while (iter.hasNext) {
val entry = iter.next()
entry.getValue.close()
connectionPool.remove(entry.getKey)
}

if (workerGroup != null) {
workerGroup.shutdownGracefully()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.netty

import scala.concurrent.{Await, future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.spark.SparkConf


/**
 * Tests for the connection pooling behaviour of [[BlockClientFactory]]: clients for the same
 * remote address are reused, inactive clients are never handed out, and stopping the factory
 * closes every pooled client.
 *
 * Two [[BlockServer]]s are started once for the whole suite so that tests can connect to two
 * distinct remote addresses.
 */
class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll {

  private val conf = new SparkConf
  private var server1: BlockServer = _
  private var server2: BlockServer = _

  /** Upper bound on how long a test waits for a closed client to report itself inactive. */
  private val closeTimeout = 3.seconds

  override def beforeAll(): Unit = {
    server1 = new BlockServer(new NettyConfig(conf), null)
    server2 = new BlockServer(new NettyConfig(conf), null)
  }

  override def afterAll(): Unit = {
    // Stop server2 even if stopping server1 throws, so no server is left running.
    try {
      if (server1 != null) {
        server1.stop()
      }
    } finally {
      if (server2 != null) {
        server2.stop()
      }
    }
  }

  /**
   * Run `body` with a fresh factory and always stop the factory afterwards, so that a failing
   * assertion does not leak the factory's connections and worker threads into later tests.
   */
  private def withFactory(body: BlockClientFactory => Unit): Unit = {
    val factory = new BlockClientFactory(conf)
    try {
      body(factory)
    } finally {
      factory.stop()
    }
  }

  /**
   * Poll until `client` reports itself inactive, failing the test if that does not happen
   * within [[closeTimeout]].
   *
   * Closing a client does not block until the connection is closed, so callers must not assert
   * on `isActive` right after requesting a close. The loop is bounded by a deadline and runs on
   * the test thread, so nothing keeps spinning in the background after a timeout.
   */
  private def awaitInactive(client: BlockClient): Unit = {
    val deadline = System.nanoTime() + closeTimeout.toNanos
    // Compare via subtraction: System.nanoTime values are only meaningful as differences.
    while (client.isActive && System.nanoTime() - deadline < 0) {
      Thread.sleep(10)
    }
    assert(!client.isActive, s"client was still active after $closeTimeout")
  }

  test("BlockClients created are active and reused") {
    withFactory { factory =>
      val c1 = factory.createClient(server1.hostName, server1.port)
      val c2 = factory.createClient(server1.hostName, server1.port)
      val c3 = factory.createClient(server2.hostName, server2.port)
      assert(c1.isActive)
      assert(c3.isActive)
      // Same remote address => same pooled client; different address => different client.
      assert(c1 === c2)
      assert(c1 !== c3)
    }
  }

  test("never return inactive clients") {
    withFactory { factory =>
      val c1 = factory.createClient(server1.hostName, server1.port)
      c1.close()

      // Block until c1 is no longer active
      awaitInactive(c1)

      // Create c2, which should be different from c1
      val c2 = factory.createClient(server1.hostName, server1.port)
      assert(c1 !== c2)
    }
  }

  test("BlockClients are close when BlockClientFactory is stopped") {
    val factory = new BlockClientFactory(conf)
    // stop() is the operation under test, so it is called exactly once, in the finally block,
    // whether or not the set-up assertions succeed.
    val clients = try {
      val c1 = factory.createClient(server1.hostName, server1.port)
      val c2 = factory.createClient(server2.hostName, server2.port)
      assert(c1.isActive)
      assert(c2.isActive)
      Seq(c1, c2)
    } finally {
      factory.stop()
    }
    // The close is not synchronous, so wait (bounded) for each client to become inactive.
    clients.foreach(awaitInactive)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.network._

class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester {

/** Helper method to get num. outstanding requests from a private field using reflection. */
private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = {
val f = handler.getClass.getDeclaredField(
"org$apache$spark$network$netty$BlockClientHandler$$outstandingRequests")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ import scala.collection.JavaConversions._
import io.netty.buffer.Unpooled

import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.Span
import org.scalatest.time.Seconds

import org.apache.spark.SparkConf
import org.apache.spark.network._
Expand Down Expand Up @@ -156,4 +159,10 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference))
assert(failBlockIds === Set("random-block"))
}

// Connects a client, stops the server, and then checks that the client stops reporting itself
// as active.
test("shutting down server should also close client") {
val client = clientFactory.createClient(server.hostName, server.port)
server.stop()
// The client may not observe the shutdown immediately, so retry the assertion for up to
// 5 seconds instead of checking once.
eventually(timeout(Span(5, Seconds))) { assert(!client.isActive) }
}
}

0 comments on commit b2f3281

Please sign in to comment.