Skip to content

Commit

Permalink
Fix rare server hang under postfn
Browse files (browse the repository at this point in the history)
  • Loading branch information
danijar committed Sep 18, 2024
1 parent 4af847b commit a16681e
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 17 deletions.
2 changes: 1 addition & 1 deletion perf/server_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def fn(x):

def client(port):
data = bytearray(size)
client = portal.Client('localhost', port)
client = portal.Client(port)
futures = collections.deque()
durations = collections.deque(maxlen=50)
while True:
Expand Down
2 changes: 1 addition & 1 deletion perf/server_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def fn(x):

def client(port):
data = bytearray(size)
client = portal.Client('localhost', port, maxinflight=prefetch + 1)
client = portal.Client(port, maxinflight=prefetch + 1)
futures = collections.deque()
for _ in range(prefetch):
futures.append(client.call('foo', data))
Expand Down
2 changes: 1 addition & 1 deletion perf/socket_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def server(port):

def client(port):
data = [bytearray(size // parts) for _ in range(parts)]
client = portal.ClientSocket('localhost', port)
client = portal.ClientSocket(port)
durations = collections.deque(maxlen=10)
while True:
start = time.perf_counter()
Expand Down
9 changes: 4 additions & 5 deletions perf/socket_proxy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import collections
import queue
import time

import portal
Expand Down
Expand Up @@ -28,24 +27,24 @@ def server(port1):

def proxy(port1, port2):
server = portal.ServerSocket(port2)
client = portal.ClientSocket('localhost', port1)
client = portal.ClientSocket(port1)
addrs = collections.deque()
while True:
try:
addr, data = server.recv(timeout=0.0001)
addrs.append(addr)
client.send(data)
except queue.Empty:
except TimeoutError:
pass
try:
data = client.recv(timeout=0.0001)
server.send(addrs.popleft(), data)
except queue.Empty:
except TimeoutError:
pass

def client(port2):
data = [bytearray(size // parts) for _ in range(parts)]
client = portal.ClientSocket('localhost', port2)
client = portal.ClientSocket(port2)
for _ in range(prefetch):
client.send(*data)
while True:
Expand Down
2 changes: 1 addition & 1 deletion perf/socket_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def server(port):

def client(port):
data = [bytearray(size // parts) for _ in range(parts)]
client = portal.ClientSocket('localhost', port)
client = portal.ClientSocket(port)
for _ in range(prefetch):
client.send(*data)
durations = collections.deque(maxlen=50)
Expand Down
2 changes: 1 addition & 1 deletion portal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '3.4.1'
__version__ = '3.4.2'

import multiprocessing as mp
try:
Expand Down
13 changes: 6 additions & 7 deletions portal/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,12 @@ def _loop(self):
finally:
if not job.postfn:
job.active.release()
if completed:
while self.postfn_inp and self.postfn_inp[0].done():
job = self.postfn_inp.popleft()
data, info = job.result()
postjob = self.postfn_pool.submit(job.postfn, info)
postjob.active = job.active
self.postfn_out.append(postjob)
while self.postfn_inp and self.postfn_inp[0].done():
job = self.postfn_inp.popleft()
data, info = job.result()
postjob = self.postfn_pool.submit(job.postfn, info)
postjob.active = job.active
self.postfn_out.append(postjob)
while self.postfn_out and self.postfn_out[0].done():
postjob = self.postfn_out.popleft()
postjob.active.release()
Expand Down
19 changes: 19 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,25 @@ def postfn(x):
assert completed != list(range(10))
assert logged == list(range(10))

@pytest.mark.parametrize('repeat', range(10))
@pytest.mark.parametrize('Server', SERVERS)
def test_postfn_no_hang(self, repeat, Server):
    """Regression test: collecting results must not hang when a postfn is bound.

    The scenario runs inside a `portal.Thread` that is joined with a
    10 second timeout; the test then asserts that the exit code reported
    for that thread is zero. `repeat` is not used by the body, it only
    makes pytest run the same scenario ten times.
    """

    def wrapper():
        port = portal.free_port()

        def workfn(x):
            # Returns a pair. NOTE(review): presumably the first item is the
            # client result and the second is passed to postfn -- confirm
            # against Server.bind.
            return x, x

        def postfn(x):
            # Small delay so the post-processing step takes some time.
            time.sleep(0.01)

        srv = Server(port, workers=4)
        srv.bind('fn', workfn, postfn)
        srv.start(block=False)
        conn = portal.Client(port)

        # Issue all 20 requests first, then wait for every result in order.
        pending = [conn.fn(i) for i in range(20)]
        for future in pending:
            future.result()  # Used to hang here.

        conn.close()
        srv.close()

    thread = portal.Thread(wrapper, start=True)
    joined = thread.join(timeout=10)
    assert joined.exitcode == 0

@pytest.mark.parametrize('repeat', range(5))
@pytest.mark.parametrize('Server', SERVERS)
@pytest.mark.parametrize('workers', (1, 4))
Expand Down

0 comments on commit a16681e

Please sign in to comment.