Skip to content

Commit

Permalink
Small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed Sep 12, 2024
1 parent 4cdf974 commit 9bebe67
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 9 deletions.
14 changes: 7 additions & 7 deletions portal/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def start(self, block=True):
def close(self, timeout=None):
assert self.started
self.running.clear()
self.server.close(timeout and 0.5 * timeout)
self.batcher.join(timeout and 0.5 * timeout)
self.server.close(timeout)
self.batcher.join(timeout)
self.batcher.kill()

def stats(self):
Expand Down Expand Up @@ -160,12 +160,12 @@ def send_error(addr, reqnum, status, message):
if errors:
raise RuntimeError(message)

outer = server_socket.ServerSocket(outer_port, f'{name}Server', **kwargs)
inner = client.Client('localhost', inner_port, f'{name}Client', **kwargs)
batches = {} # {method: ([addr], [reqnum], structure, [array])}
jobs = []
shutdown = False
try:
outer = server_socket.ServerSocket(outer_port, f'{name}Server', **kwargs)
inner = client.Client('localhost', inner_port, f'{name}Client', **kwargs)
batches = {} # {method: ([addr], [reqnum], structure, [array])}
jobs = []
shutdown = False
while running.is_set() or jobs:
if running.is_set():
maybe_recv(outer, inner, jobs, batches)
Expand Down
1 change: 0 additions & 1 deletion portal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def call(self, method, *data):
name = method.encode('utf-8')
strlen = len(name).to_bytes(8, 'little', signed=False)
sendargs = (reqnum, strlen, name, *packlib.pack(data))
# self.socket.send(reqnum, strlen, name, *packlib.pack(data))
rai = [False]
future = Future(rai)
future.sendargs = sendargs
Expand Down
2 changes: 1 addition & 1 deletion portal/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __exit__(self, *e):
self.close()

def _loop(self):
while self.running or self.jobs or self.postfn_out:
while self.running or self.jobs or self.postfn_inp or self.postfn_out:
while True: # Loop syntax used to break on error.
if not self.running: # Do not accept further requests.
break
Expand Down
9 changes: 9 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,19 @@ def test_postfn_no_backlog(self, repeat, Server, workers):
lock = threading.Lock()
work_calls = [0]
done_calls = [0]

def workfn(x):
with lock:
work_calls[0] += 1
print(work_calls[0], done_calls[0])
assert work_calls[0] <= done_calls[0] + workers + 1
return x, x

def postfn(x):
with lock:
done_calls[0] += 1
time.sleep(0.01)

server = Server(port, workers=workers)
server.bind('fn', workfn, postfn)
server.start(block=False)
Expand All @@ -173,11 +176,14 @@ def postfn(x):
@pytest.mark.parametrize('repeat', range(3))
@pytest.mark.parametrize('Server', SERVERS)
def test_shared_pool(self, repeat, Server):

def slow(x):
time.sleep(0.2)
return x

def fast(x):
return x

port = portal.free_port()
server = Server(port, workers=1)
server.bind('slow', slow)
Expand All @@ -197,11 +203,14 @@ def fast(x):
@pytest.mark.parametrize('repeat', range(3))
@pytest.mark.parametrize('Server', SERVERS)
def test_separate_pools(self, repeat, Server):

def slow(x):
time.sleep(0.1)
return x

def fast(x):
return x

port = portal.free_port()
server = Server(port)
server.bind('slow', slow, workers=1)
Expand Down

0 comments on commit 9bebe67

Please sign in to comment.