Skip to content

Commit

Permalink
Wean off IOLoop.run_sync
Browse files Browse the repository at this point in the history
  • Loading branch information
dwoz authored and Ch3LL committed Mar 4, 2021
1 parent 8fff8d8 commit 4cf62fb
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 24 deletions.
17 changes: 14 additions & 3 deletions salt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,25 @@
sys.stderr.flush()


USE_VENDORED_TORNADO = True


class TornadoImporter:
def find_module(self, module_name, package_path=None):
if module_name.startswith("tornado"):
return self
if USE_VENDORED_TORNADO:
if module_name.startswith("tornado"):
return self
else:
if module_name.startswith("salt.ext.tornado"):
return self
return None

def load_module(self, name):
mod = importlib.import_module("salt.ext.{}".format(name))
if USE_VENDORED_TORNADO:
mod = importlib.import_module("salt.ext.{}".format(name))
else:
# Remove 'salt.ext.' from the module
mod = importlib.import_module(name[9:])
sys.modules[name] = mod
return mod

Expand Down
45 changes: 38 additions & 7 deletions salt/transport/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import salt.ext.tornado
import salt.ext.tornado.concurrent
import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.ext.tornado.netutil
import salt.transport.client
import salt.transport.frame
Expand Down Expand Up @@ -112,7 +113,7 @@ def __init__(self, socket_path, io_loop=None, payload_handler=None):

# Placeholders for attributes to be populated by method calls
self.sock = None
self.io_loop = io_loop or IOLoop.current()
self.io_loop = io_loop or salt.ext.tornado.ioloop.IOLoop.current()
self._closing = False

def start(self):
Expand Down Expand Up @@ -279,6 +280,7 @@ def __init__(self, socket_path, io_loop=None):
else:
msgpack_kwargs = {"encoding": "utf-8"}
self.unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs)
self._connecting_future = None

def connected(self):
return self.stream is not None and not self.stream.closed()
Expand All @@ -287,17 +289,15 @@ def connect(self, callback=None, timeout=None):
"""
Connect to the IPC socket
"""
# pylint: disable=access-member-before-definition
if hasattr(self, "_connecting_future") and not self._connecting_future.done():
if self._connecting_future is not None and not self._connecting_future.done():
future = self._connecting_future
# pylint: enable=access-member-before-definition
else:
if hasattr(self, "_connecting_future"):
if self._connecting_future is not None:
# read previous future result to prevent the "unhandled future exception" error
self._connecting_future.exception() # pylint: disable=E0203
future = salt.ext.tornado.concurrent.Future()
self._connecting_future = future
self._connect(timeout=timeout)
self._connect(timeout)

if callback is not None:

Expand Down Expand Up @@ -360,6 +360,7 @@ def close(self):
return

self._closing = True
self._connecting_future = None

log.debug("Closing %s instance", self.__class__.__name__)

Expand Down Expand Up @@ -556,7 +557,6 @@ def publish(self, msg):
return

pack = salt.transport.frame.frame_msg_ipc(msg, raw_body=True)

for stream in self.streams:
self.io_loop.spawn_callback(self._write, stream, pack)

Expand Down Expand Up @@ -637,6 +637,13 @@ class IPCMessageSubscriber(IPCClient):
package = ipc_subscriber.read_sync()
"""

async_methods = [
"read",
]
close_methods = [
"close",
]

def __init__(self, socket_path, io_loop=None):
super().__init__(socket_path, io_loop=io_loop)
self._read_stream_future = None
Expand Down Expand Up @@ -703,6 +710,30 @@ def _read(self, timeout, callback=None):
raise exc_to_raise # pylint: disable=E0702
raise salt.ext.tornado.gen.Return(ret)

@salt.ext.tornado.gen.coroutine
def read(self, timeout):
    """
    Asynchronously read a single message from the IPC socket.

    A message already buffered in ``self._saved_data`` is handed back first.
    Otherwise the subscriber keeps trying to connect (5 second connect
    timeout, 1 second pause after a failed attempt) until ``connected()``
    is true, and then delegates to ``self._read``.

    :param timeout: Passed through unchanged to ``self._read``
    :return: The buffered message, or the result of ``self._read``
    """
    if self._saved_data:
        raise salt.ext.tornado.gen.Return(self._saved_data.pop(0))

    while not self.connected():
        try:
            yield self.connect(timeout=5)
        except StreamClosedError:
            log.trace(
                "Subscriber closed stream on IPC %s before connect",
                self.socket_path,
            )
        except Exception as exc:  # pylint: disable=broad-except
            log.error("Exception occurred while Subscriber connecting: %s", exc)
        else:
            # Connect attempt finished cleanly; re-check the loop condition
            continue
        # Back off briefly after any failed attempt before retrying
        yield salt.ext.tornado.gen.sleep(1)

    message = yield self._read(timeout)
    raise salt.ext.tornado.gen.Return(message)

def read_sync(self, timeout=None):
"""
Read a message from an IPC socket
Expand Down
83 changes: 69 additions & 14 deletions salt/utils/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,13 +358,14 @@ def connect_pub(self, timeout=None):
if self._run_io_loop_sync:
with salt.utils.asynchronous.current_ioloop(self.io_loop):
if self.subscriber is None:
self.subscriber = salt.transport.ipc.IPCMessageSubscriber(
self.puburi, io_loop=self.io_loop
self.subscriber = salt.utils.asynchronous.SyncWrapper(
salt.transport.ipc.IPCMessageSubscriber,
args=(self.puburi,),
kwargs={"io_loop": self.io_loop},
loop_kwarg="io_loop",
)
try:
self.io_loop.run_sync(
lambda: self.subscriber.connect(timeout=timeout)
)
self.subscriber.connect(timeout=timeout)
self.cpub = True
except Exception: # pylint: disable=broad-except
pass
Expand Down Expand Up @@ -402,11 +403,14 @@ def connect_pull(self, timeout=1):
if self._run_io_loop_sync:
with salt.utils.asynchronous.current_ioloop(self.io_loop):
if self.pusher is None:
self.pusher = salt.transport.ipc.IPCMessageClient(
self.pulluri, io_loop=self.io_loop
self.pusher = salt.utils.asynchronous.SyncWrapper(
salt.transport.ipc.IPCMessageClient,
args=(self.pulluri,),
kwargs={"io_loop": self.io_loop},
loop_kwarg="io_loop",
)
try:
self.io_loop.run_sync(lambda: self.pusher.connect(timeout=timeout))
self.pusher.connect(timeout=timeout)
self.cpush = True
except Exception: # pylint: disable=broad-except
pass
Expand Down Expand Up @@ -547,11 +551,9 @@ def _get_event(self, wait, tag, match_func=None, no_block=False):
# Trigger that at least a single iteration has gone through
run_once = True
try:
# salt.ext.tornado.ioloop.IOLoop.run_sync() timeouts are in seconds.
# IPCMessageSubscriber.read_sync() uses this type of timeout.
if not self.cpub and not self.connect_pub(timeout=wait):
break
raw = self.subscriber.read_sync(timeout=wait)
raw = self.subscriber.read(timeout=wait)
if raw is None:
break
mtag, data = self.unpack(raw, self.serial)
Expand Down Expand Up @@ -674,7 +676,7 @@ def get_event_noblock(self):
if not self.cpub:
if not self.connect_pub():
return None
raw = self.subscriber.read_sync(timeout=0)
raw = self.subscriber._read(timeout=0)
if raw is None:
return None
mtag, data = self.unpack(raw, self.serial)
Expand All @@ -690,7 +692,7 @@ def get_event_block(self):
if not self.cpub:
if not self.connect_pub():
return None
raw = self.subscriber.read_sync(timeout=None)
raw = self.subscriber._read(timeout=None)
if raw is None:
return None
mtag, data = self.unpack(raw, self.serial)
Expand All @@ -708,6 +710,59 @@ def iter_events(self, tag="", full=False, match_type=None, auto_reconnect=False)
continue
yield data

@salt.ext.tornado.gen.coroutine
def fire_event_async(self, data, tag, cb=None, timeout=1000):
    """
    Send a single event into the publisher with payload dict "data" and
    event identifier "tag"

    :param data: Mutable mapping with the event payload. It is updated in
        place with "_stamp" and, when the "subproxy" option is set, with
        "proxy_target".
    :param tag: Event identifier; must not stringify to an empty string
    :param cb: Optional callable invoked with the result of the send
    :param timeout: Milliseconds allowed for connecting the pusher when it
        is not connected yet. The default is 1000 ms; ``None`` is passed
        through to ``connect_pull`` as is.
    :raises ValueError: If the tag is empty or data is not a mutable mapping
    :return: The coroutine resolves to ``False`` when the pusher could not
        be connected, otherwise to ``None``
    """
    if not str(tag):  # no empty tags allowed
        raise ValueError("Empty tag.")

    if not isinstance(data, MutableMapping):  # data must be dict
        raise ValueError("Dict object expected, not '{}'.".format(data))

    # Mutate the payload only after it has been validated, so invalid input
    # is rejected with ValueError and is never indexed into or modified.
    if self.opts.get("subproxy", False):
        data["proxy_target"] = self.opts["id"]

    if not self.cpush:
        # connect_pull takes seconds, callers pass milliseconds
        timeout_s = None if timeout is None else float(timeout) / 1000
        if not self.connect_pull(timeout=timeout_s):
            return False

    data["_stamp"] = datetime.datetime.utcnow().isoformat()

    # Since the pack / unpack logic here is for local events only,
    # it is safe to change the wire protocol. The mechanism
    # that sends events from minion to master is outside this
    # file.
    dump_data = self.serial.dumps(data, use_bin_type=True)

    serialized_data = salt.utils.dicttrim.trim_dict(
        dump_data,
        self.opts["max_event_size"],
        is_msgpacked=True,
        use_bin_type=True,
    )
    log.debug("Sending event: tag = %s; data = %s", tag, data)
    event = b"".join(
        [
            salt.utils.stringutils.to_bytes(tag),
            salt.utils.stringutils.to_bytes(TAGEND),
            serialized_data,
        ]
    )
    msg = salt.utils.stringutils.to_bytes(event, "utf-8")
    ret = yield self.pusher.send(msg)
    if cb is not None:
        cb(ret)

def fire_event(self, data, tag, timeout=1000):
"""
Send a single event into the publisher with payload dict "data" and
Expand Down Expand Up @@ -759,7 +814,7 @@ def fire_event(self, data, tag, timeout=1000):
if self._run_io_loop_sync:
with salt.utils.asynchronous.current_ioloop(self.io_loop):
try:
self.io_loop.run_sync(lambda: self.pusher.send(msg))
self.pusher.send(msg)
except Exception as ex: # pylint: disable=broad-except
log.debug(ex)
raise
Expand Down
115 changes: 115 additions & 0 deletions tests/pytests/functional/transport/ipc/test_subscriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import sys

import attr
import pytest
import salt.ext.tornado.gen
import salt.transport.client
import salt.transport.ipc
import salt.transport.server
from salt.ext.tornado import locks

pytestmark = [
# Windows does not support POSIX IPC
pytest.mark.skip_on_windows,
pytest.mark.skipif(
sys.version_info < (3, 6), reason="The IOLoop blocks under Py3.5 on these tests"
),
]


@attr.s(frozen=True, slots=True)
class PayloadHandler:
    """
    Collects every payload it is handed and echoes it back to the sender.

    Usable as a context manager; the collected payloads are dropped on exit.
    """

    # Not an init argument: every instance starts with its own empty list
    payloads = attr.ib(init=False, default=attr.Factory(list))

    async def handle_payload(self, payload, reply_func):
        """
        Record ``payload`` and then await ``reply_func`` with it.
        """
        self.payloads.append(payload)
        await reply_func(payload)

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Empty the list in place; the frozen instance cannot rebind it
        del self.payloads[:]


@attr.s(frozen=True, slots=True)
class IPCTester:
    """
    Test harness pairing an IPC publisher with a subscriber on one socket.

    Entering the context starts the publisher and schedules the subscriber's
    ``connect`` on the IO loop; leaving the context closes both.
    """

    io_loop = attr.ib()
    socket_path = attr.ib()
    publisher = attr.ib()
    subscriber = attr.ib()
    payloads = attr.ib(default=attr.Factory(list))
    payload_ack = attr.ib(default=attr.Factory(locks.Condition))

    @subscriber.default
    def _subscriber_default(self):
        # Built lazily so it can use the socket path and loop given at init
        return salt.transport.ipc.IPCMessageSubscriber(
            self.socket_path, io_loop=self.io_loop,
        )

    @publisher.default
    def _publisher_default(self):
        return salt.transport.ipc.IPCMessagePublisher(
            {"ipc_write_buffer": 0}, self.socket_path, io_loop=self.io_loop,
        )

    async def handle_payload(self, payload, reply_func):
        """
        Record ``payload``, await ``reply_func`` with it, then notify
        ``payload_ack``.
        """
        self.payloads.append(payload)
        await reply_func(payload)
        self.payload_ack.notify()

    def new_client(self):
        """
        Return another tester on the same loop and socket path.

        The new instance shares this one's publisher, payload list and
        condition, and gets a freshly built default subscriber.
        """
        # This class has no ``server`` attribute; forwarding one raised
        # AttributeError. Share the publisher instead.
        return IPCTester(
            io_loop=self.io_loop,
            socket_path=self.socket_path,
            publisher=self.publisher,
            payloads=self.payloads,
            payload_ack=self.payload_ack,
        )

    async def publish(self, payload, timeout=60):
        """
        Publish ``payload`` through the publisher.

        ``timeout`` is accepted but not used by this method.
        """
        self.publisher.publish(payload)

    async def read(self, timeout=60):
        """
        Await and return the next message read by the subscriber.
        """
        return await self.subscriber.read(timeout)

    def __enter__(self):
        self.publisher.start()
        # connect is only scheduled here; it runs once the loop is running
        self.io_loop.add_callback(self.subscriber.connect)
        return self

    def __exit__(self, *args):
        self.subscriber.close()
        self.publisher.close()


@pytest.fixture
def ipc_socket_path(tmp_path):
    """
    Yield a socket path inside ``tmp_path``; on teardown remove the file if
    it exists.
    """
    socket_file = tmp_path.joinpath("ipc-test.ipc")
    try:
        yield socket_file
    finally:
        if socket_file.exists():
            socket_file.unlink()


@pytest.fixture
def channel(io_loop, ipc_socket_path):
    """
    Yield an entered ``IPCTester`` bound to the test loop and socket path;
    the tester's context is exited on teardown.
    """
    with IPCTester(io_loop=io_loop, socket_path=str(ipc_socket_path)) as tester:
        yield tester


async def test_basic_send(channel):
    """
    A message published on the channel is read back unchanged.
    """
    subscriber = channel.subscriber
    expected = {"foo": "bar", "stop": True}
    # XXX: IPCClient connect and connected methods need to be cleaned up as
    #      this should not be needed.
    # Poll until the scheduled connect has finished ...
    while not subscriber._connecting_future.done():
        await salt.ext.tornado.gen.sleep(0.01)
    # ... and until the subscriber reports itself connected.
    while not subscriber.connected():
        await salt.ext.tornado.gen.sleep(0.01)
    assert subscriber.connected()
    await channel.publish(expected)
    received = await channel.read()
    assert received == expected

1 comment on commit 4cf62fb

@szjur
Copy link

@szjur szjur commented on 4cf62fb May 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dwoz Does the only way to wean off IOLoop.run_sync() lead through breaking get_event_block()? _read() will actually return a future which then will break in self.unpack() where it tries to access raw.partition. get_event() uses _get_event() which uses read() instead of _read() and the former somehow makes it work. I'm not as proficient in the old-style yield-based async code so I can't fully explain why but IPCMessageSubscriber.read() yields the result of _read() which combined with the fact it is marked as a coroutine (even though it is called directly from non-async code) makes the result contain the actual result of the future. I hope it makes sense, anyway you rendered get_event_block() useless with this change, at least in 3003.4.
My colleague opened #62015 for this some time ago but it got little attention. There's some simple code there showing that it just breaks get_event_block(). That code used to work from 2016.* through 3000 with no problems.

I understand you are weaning off IOLoop.run_sync(), which was used by IPCMessageSubscriber.read_sync(), and that you wrote the new read() function which, using the aforementioned "hocus pocus", makes coroutines called from synchronous code work. But then you can't just start calling _read() directly, because that won't work. My gut feeling is that using IPCMessageSubscriber.read_sync() to run the async coroutine _read() was the right way to do it, but if it works the way the new read() function does it, that's also fine. However, as I said, it does not work if _read() is called directly, as it now is from get_event_block() and also from get_event_noblock(), which is most likely broken too; we just don't use that one so we don't care xD

Obviously, it doesn't make sense to create that new read() function and then replace read_sync() with read() in _get_event(), while replacing the same read_sync() with the existing function _read() in two other places, which can't possibly work in either 3003 or 3000. read_sync() wrapped _read() in IOLoop.run_sync() for a reason, right? And you seem to have partly discerned this reason by writing the read() function as a replacement.

I don't think doing it the way you did will even be portable to new-style asyncio - you can't easily mix sync and async code there like that. I don't think it's possible to call a coroutine there just like a normal function and get the result, which ultimately is the case in your code (_get_event() calls your new read() function, which is a coroutine). As I said, I haven't fully figured out why wrapping the _read() coroutine in your read() coroutine makes it return the actual result instead of a Future object (which is what _read() returns, and hence the bug), but that seems to be the case. I tried to do that in native Python 3 async but that won't work - you'll always get a coroutine object, plus eventually a warning that it was never awaited. It normally takes asyncio.run() to run a coroutine, which is roughly the equivalent of IOLoop.run_sync(), the very call this PR started weaning off.

@Ch3LL are you also happy with this PR introducing bugs like this?

Please sign in to comment.