diff --git a/rpyc/utils/server.py b/rpyc/utils/server.py
index 93dd38bf..17df6e24 100644
--- a/rpyc/utils/server.py
+++ b/rpyc/utils/server.py
@@ -13,14 +13,10 @@ from rpyc.utils.registry import UDPRegistryClient
 from rpyc.utils.authenticators import AuthenticationError
 from rpyc.lib import safe_import
+from rpyc.lib.compat import poll
 signal = safe_import("signal")
 
 
-class ThreadPoolFull(Exception):
-    """raised when the ThreadPoolServer's is overloaded (all threads in the
-    thread pool are used)"""
-    pass
-
 
 
 class Server(object):
     """Base server implementation
@@ -248,8 +244,17 @@ def _accept_method(self, sock):
 
 
 class ThreadPoolServer(Server):
     """This server is threaded like the ThreadedServer but reuses threads so that
     recreation is not necessary for each request. The pool of threads has a fixed
-    size that can be set with the 'nbThreads' argument. Otherwise, the default is 20"""
+    size that can be set with the 'nbThreads' argument. The default size is 20.
+    The server dispatches requests to threads in batches; that is, a given thread may process
+    up to request_batch_size requests from the same connection in one go before it moves on to
+    the next connection with pending requests. By default, self.request_batch_size
+    is set to 10 and it can be overridden via the constructor arguments.
+
+    Contributed by *@sponce*
+
+    Parameters: see :class:`Server`
+    """
+
     def __init__(self, *args, **kwargs):
         '''Initializes a ThreadPoolServer. In particular, instantiate the thread pool.'''
@@ -257,44 +262,182 @@ def __init__(self, *args, **kwargs):
         if 'nbThreads' in kwargs:
             nbthreads = kwargs['nbThreads']
             del kwargs['nbThreads']
+        # get the request batch size
+        self.request_batch_size = 10
+        if 'requestBatchSize' in kwargs:
+            self.request_batch_size = kwargs['requestBatchSize']
+            del kwargs['requestBatchSize']
         # init the parent
         Server.__init__(self, *args, **kwargs)
-        # create a queue where requests will be pending until a thread is ready
-        self._client_queue = Queue.Queue(nbthreads)
+        # a queue of connections that have something to process
+        self._active_connection_queue = Queue.Queue()
         # declare the pool as already active
         self.active = True
-        # setup the thread pool
-        for i in range(nbthreads):
-            t = threading.Thread(target = self._authenticate_and_serve_clients, args=(self._client_queue,))
+        # setup the thread pool for handling requests
+        self.workers = []
+        for _ in range(nbthreads):
+            t = threading.Thread(target = self._serve_clients)
+            t.setName('ThreadPoolWorker')
             t.daemon = True
             t.start()
-
-    def _authenticate_and_serve_clients(self, queue):
-        '''Main method run by the threads of the thread pool. It gets work from the
-        internal queue and calls the _authenticate_and_serve_client method'''
+            self.workers.append(t)
+        # a polling object to be used by the polling thread
+        self.poll_object = poll()
+        # a dictionary fd -> connection
+        self.fd_to_conn = {}
+        # setup a thread for polling inactive connections
+        self.polling_thread = threading.Thread(target = self._poll_inactive_clients)
+        self.polling_thread.setName('PollingThread')
+        self.polling_thread.daemon = True
+        self.polling_thread.start()
+
+    def close(self):
+        '''closes a ThreadPoolServer. In particular, joins the thread pool.'''
+        # close parent server
+        Server.close(self)
+        # stop producer thread
+        self.polling_thread.join()
+        # cleanup thread pool : first fill the pool with None fds so that all threads exit
+        # the blocking get on the queue of active connections. Then join the threads
+        for _ in range(len(self.workers)):
+            self._active_connection_queue.put(None)
+        for w in self.workers:
+            w.join()
+
+    def _remove_from_inactive_connection(self, fd):
+        '''removes a connection from the set of inactive ones'''
+        # unregister the connection in the polling object
+        try:
+            self.poll_object.unregister(fd)
+        except KeyError:
+            # the connection has already been unregistered
+            pass
+
+    def _drop_connection(self, fd):
+        '''removes a connection by closing it and removing it from internal structs'''
+        # cleanup the fd_to_conn dictionary
+        try:
+            conn = self.fd_to_conn[fd]
+            del self.fd_to_conn[fd]
+        except KeyError:
+            # the connection has already been removed, nothing left to close
+            return
+        # close connection
+        conn.close()
+
+    def _add_inactive_connection(self, fd):
+        '''adds a connection to the set of inactive ones'''
+        self.poll_object.register(fd, "rw")
+
+    def _handle_poll_result(self, connlist):
+        '''handles the (fd, event) pairs returned by poll: drops closed connections
+        and queues the ones that have data to process'''
+        for fd, evt in connlist:
+            try:
+                # remove connection from the inactive ones
+                self._remove_from_inactive_connection(fd)
+                # Is it an error ?
+                if "e" in evt or "n" in evt or "h" in evt:
+                    # it was an error, connection was closed. Do the same on our side
+                    self._drop_connection(fd)
+                else:
+                    # connection has data, let's add it to the active queue
+                    self._active_connection_queue.put(fd)
+            except KeyError:
+                # the connection has already been dropped. Give up
+                pass
+
+    def _poll_inactive_clients(self):
+        '''Main method run by the polling thread of the thread pool.
+        Check whether inactive clients have become active'''
         while self.active:
             try:
-                sock = queue.get(True, 1)
-                self._authenticate_and_serve_client(sock)
+                # the actual poll, with a timeout of 1s so that we can exit in case
+                # we are not active anymore
+                active_clients = self.poll_object.poll(1)
+                # for each client that became active, put them in the active queue
+                self._handle_poll_result(active_clients)
+            except Exception, e:
+                # "Caught exception in Worker thread" message
+                self.logger.warning("failed to poll clients, caught exception : %s", str(e))
+                # wait a bit so that we do not loop too fast in case of error
+                time.sleep(.2)
+
+    def _serve_requests(self, fd):
+        '''Serves a batch of requests from the given connection and puts it back in the appropriate queue'''
+        # serve a maximum of request_batch_size requests for this connection
+        for _ in range(self.request_batch_size):
+            try:
+                if not self.fd_to_conn[fd].poll(): # note that poll serves the request
+                    # we could not find a request, so we put this connection back to the inactive set
+                    self._add_inactive_connection(fd)
+                    return
+            except EOFError:
+                # the connection has been closed by the remote end. Close it on our side and return
+                self._drop_connection(fd)
+                return
+            except Exception:
+                # when in doubt, put the connection back in the active queue and re-raise
+                # the exception to the upper level
+                self._active_connection_queue.put(fd)
+                raise
+        # we've processed the maximum number of requests. Put back the connection in the active queue
+        self._active_connection_queue.put(fd)
+
+    def _serve_clients(self):
+        '''Main method run by the processing threads of the thread pool.
+        Loops forever, handling requests read from the connections present in the active_queue'''
+        while self.active:
+            try:
+                # note that we do not use a timeout here. This is because the implementation of
+                # the timeout version performs badly. So we block forever, and exit by filling
+                # the queue with None fds
+                fd = self._active_connection_queue.get(True)
+                # fd may be None (case where we want to exit the blocking get to close the service)
+                if fd:
+                    # serve the requests of this connection
+                    self._serve_requests(fd)
             except Queue.Empty:
                 # we've timed out, let's just retry. We only use the timeout so that this
                 # thread can stop even if there is nothing in the queue
                 pass
             except Exception, e:
                 # "Caught exception in Worker thread" message
-                self.logger.info("failed to serve client, caught exception : %s", str(e))
+                self.logger.warning("failed to serve client, caught exception : %s", str(e))
                 # wait a bit so that we do not loop too fast in case of error
                 time.sleep(.2)
-
+
+    def _authenticate_and_build_connection(self, sock):
+        '''Authenticate a client and, if it succeeds, wrap the socket in a connection object.
+        Note that this code is duplicated from the rpyc internals and may have to be
+        changed if rpyc evolves'''
+        # authenticate
+        if self.authenticator:
+            h, p = sock.getpeername()
+            try:
+                sock, credentials = self.authenticator(sock)
+            except AuthenticationError:
+                self.logger.info("%s:%s failed to authenticate, rejecting connection", h, p)
+                return None
+        else:
+            credentials = None
+        # build a connection
+        h, p = sock.getpeername()
+        config = dict(self.protocol_config, credentials=credentials, connid="%s:%d"%(h, p))
+        return Connection(self.service, Channel(SocketStream(sock)), config=config)
+
     def _accept_method(self, sock):
         '''Implementation of the accept method : only pushes the work to the internal queue.
         In case the queue is full, raises an AsynResultTimeout error'''
         try:
-            # try to put the request in the queue
-            self._client_queue.put_nowait(sock)
-        except Queue.Full:
-            # queue was full, reject request
-            raise ThreadPoolFull("server is overloaded")
+            # authenticate and build connection object
+            conn = self._authenticate_and_build_connection(sock)
+            # put the connection in the active queue
+            if conn:
+                fd = conn.fileno()
+                self.fd_to_conn[fd] = conn
+                self._add_inactive_connection(fd)
+                self.clients.clear()
+        except Exception, e:
+            self.logger.warning("failed to serve client, caught exception : %s", str(e))
 
 
 class ForkingServer(Server):
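
For context, a minimal usage sketch of the reworked ThreadPoolServer, assuming a build that includes this patch; the service class, port and pool sizes below are illustrative choices, not defaults taken from the code:

    from rpyc.core.service import SlaveService
    from rpyc.utils.server import ThreadPoolServer

    # illustrative sizes: a pool of 30 worker threads, each serving up to
    # 5 consecutive requests per connection before moving to the next one
    server = ThreadPoolServer(SlaveService, port=18812,
                              nbThreads=30, requestBatchSize=5)
    server.start()  # blocks, accepting connections and dispatching them to the pool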