diff --git a/portal/batching.py b/portal/batching.py index af603fc..f15586a 100644 --- a/portal/batching.py +++ b/portal/batching.py @@ -53,8 +53,8 @@ def start(self, block=True): def close(self, timeout=None): assert self.started self.running.clear() - self.server.close(timeout and 0.5 * timeout) - self.batcher.join(timeout and 0.5 * timeout) + self.server.close(timeout) + self.batcher.join(timeout) self.batcher.kill() def stats(self): @@ -160,12 +160,12 @@ def send_error(addr, reqnum, status, message): if errors: raise RuntimeError(message) - outer = server_socket.ServerSocket(outer_port, f'{name}Server', **kwargs) - inner = client.Client('localhost', inner_port, f'{name}Client', **kwargs) - batches = {} # {method: ([addr], [reqnum], structure, [array])} - jobs = [] - shutdown = False try: + outer = server_socket.ServerSocket(outer_port, f'{name}Server', **kwargs) + inner = client.Client('localhost', inner_port, f'{name}Client', **kwargs) + batches = {} # {method: ([addr], [reqnum], structure, [array])} + jobs = [] + shutdown = False while running.is_set() or jobs: if running.is_set(): maybe_recv(outer, inner, jobs, batches) diff --git a/portal/client.py b/portal/client.py index fd2d29e..337fec9 100644 --- a/portal/client.py +++ b/portal/client.py @@ -80,7 +80,6 @@ def call(self, method, *data): name = method.encode('utf-8') strlen = len(name).to_bytes(8, 'little', signed=False) sendargs = (reqnum, strlen, name, *packlib.pack(data)) - # self.socket.send(reqnum, strlen, name, *packlib.pack(data)) rai = [False] future = Future(rai) future.sendargs = sendargs diff --git a/portal/server.py b/portal/server.py index 5bb684a..44e995a 100644 --- a/portal/server.py +++ b/portal/server.py @@ -83,7 +83,7 @@ def __exit__(self, *e): self.close() def _loop(self): - while self.running or self.jobs or self.postfn_out: + while self.running or self.jobs or self.postfn_inp or self.postfn_out: while True: # Loop syntax used to break on error. if not self.running: # Do not accept further requests. break diff --git a/tests/test_server.py b/tests/test_server.py index 65b80c8..c4687eb 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -151,16 +151,19 @@ def test_postfn_no_backlog(self, repeat, Server, workers): lock = threading.Lock() work_calls = [0] done_calls = [0] + def workfn(x): with lock: work_calls[0] += 1 print(work_calls[0], done_calls[0]) assert work_calls[0] <= done_calls[0] + workers + 1 return x, x + def postfn(x): with lock: done_calls[0] += 1 time.sleep(0.01) + server = Server(port, workers=workers) server.bind('fn', workfn, postfn) server.start(block=False) @@ -173,11 +176,14 @@ def postfn(x): @pytest.mark.parametrize('repeat', range(3)) @pytest.mark.parametrize('Server', SERVERS) def test_shared_pool(self, repeat, Server): + def slow(x): time.sleep(0.2) return x + def fast(x): return x + port = portal.free_port() server = Server(port, workers=1) server.bind('slow', slow) @@ -197,11 +203,14 @@ def fast(x): @pytest.mark.parametrize('repeat', range(3)) @pytest.mark.parametrize('Server', SERVERS) def test_separate_pools(self, repeat, Server): + def slow(x): time.sleep(0.1) return x + def fast(x): return x + port = portal.free_port() server = Server(port) server.bind('slow', slow, workers=1)