diff --git a/modin/experimental/cloud/rpyc_proxy.py b/modin/experimental/cloud/rpyc_proxy.py index 4236b8cb2df..fa6281f7e74 100644 --- a/modin/experimental/cloud/rpyc_proxy.py +++ b/modin/experimental/cloud/rpyc_proxy.py @@ -15,6 +15,7 @@ import time import threading import collections +import os import rpyc from rpyc.lib.compat import pickle @@ -34,6 +35,7 @@ def _tuplize(arg): """turns any sequence or iterator into a flat tuple""" return tuple(arg) +_TRACE_RPYC = os.environ.get('MODIN_TRACE_RPYC', '').title() == 'True' _msg_to_name = collections.defaultdict(list) for name in dir(consts): @@ -52,51 +54,6 @@ def __init__(self, *a, **kw): self._remote_dumps = None self._remote_tuplize = None - self.logLock = threading.RLock() - self.timings = {} - with open("rpyc-trace.log", "a") as out: - out.write(f"------------[new trace at {time.asctime()}]----------\n") - self.logfiles = set(["rpyc-trace.log"]) - - def _send(self, msg, seq, args): - """tracing only""" - str_args = str(args).replace("\r", "").replace("\n", "\tNEWLINE\t") - if msg == consts.MSG_REQUEST: - handler, _ = args - str_handler = f":req={_msg_to_name['HANDLE'][handler]}" - else: - str_handler = "" - with self.logLock: - for logfile in self.logfiles: - with open(logfile, "a") as out: - out.write( - f"send:msg={_msg_to_name['MSG'][msg]}:seq={seq}{str_handler}:args={str_args}\n" - ) - self.timings[seq] = time.time() - return super()._send(msg, seq, args) - - def _dispatch(self, data): - """tracing only""" - got1 = time.time() - try: - return super()._dispatch(data) - finally: - got2 = time.time() - msg, seq, args = brine.load(data) - sent = self.timings.pop(seq, got1) - if msg == consts.MSG_REQUEST: - handler, args = args - str_handler = f":req={_msg_to_name['HANDLE'][handler]}" - else: - str_handler = "" - str_args = str(args).replace("\r", "").replace("\n", "\tNEWLINE\t") - with self.logLock: - for logfile in self.logfiles: - with open(logfile, "a") as out: - out.write( - f"recv:timing={got1 - sent}+{got2 - got1}:msg={_msg_to_name['MSG'][msg]}:seq={seq}{str_handler}:args={str_args}\n" - ) - def __wrap(self, local_obj): while True: # unwrap magic wrappers first; keep unwrapping in case it's a wrapper-in-a-wrapper @@ -275,6 +232,61 @@ def _box(self, obj): break return super()._box(obj) + def _init_deliver(self): + self._remote_batch_loads = self.modules[ + "modin.experimental.cloud.rpyc_proxy" + ]._batch_loads + self._remote_dumps = self.modules["rpyc.lib.compat"].pickle.dumps + self._remote_tuplize = self.modules[ + "modin.experimental.cloud.rpyc_proxy" + ]._tuplize + +class TracingWrappingConnection(WrappingConnection): + def __init__(self, *a, **kw): + super().__init__(*a, **kw) + self.logLock = threading.RLock() + self.timings = {} + with open("rpyc-trace.log", "a") as out: + out.write(f"------------[new trace at {time.asctime()}]----------\n") + self.logfiles = set(["rpyc-trace.log"]) + + def _send(self, msg, seq, args): + str_args = str(args).replace("\r", "").replace("\n", "\tNEWLINE\t") + if msg == consts.MSG_REQUEST: + handler, _ = args + str_handler = f":req={_msg_to_name['HANDLE'][handler]}" + else: + str_handler = "" + with self.logLock: + for logfile in self.logfiles: + with open(logfile, "a") as out: + out.write( + f"send:msg={_msg_to_name['MSG'][msg]}:seq={seq}{str_handler}:args={str_args}\n" + ) + self.timings[seq] = time.time() + return super()._send(msg, seq, args) + + def _dispatch(self, data): + """tracing only""" + got1 = time.time() + try: + return super()._dispatch(data) + finally: + got2 = time.time() + msg, seq, args = brine.load(data) + sent = self.timings.pop(seq, got1) + if msg == consts.MSG_REQUEST: + handler, args = args + str_handler = f":req={_msg_to_name['HANDLE'][handler]}" + else: + str_handler = "" + str_args = str(args).replace("\r", "").replace("\n", "\tNEWLINE\t") + with self.logLock: + for logfile in self.logfiles: + with open(logfile, "a") as out: + out.write( + f"recv:timing={got1 - sent}+{got2 - got1}:msg={_msg_to_name['MSG'][msg]}:seq={seq}{str_handler}:args={str_args}\n" + ) class _Logger: def __init__(self, conn, logname): self.conn = conn @@ -296,18 +308,9 @@ def __exit__(self, *a, **kw): def _logmore(self, logname): return self._Logger(self, logname) - def _init_deliver(self): - self._remote_batch_loads = self.modules[ - "modin.experimental.cloud.rpyc_proxy" - ]._batch_loads - self._remote_dumps = self.modules["rpyc.lib.compat"].pickle.dumps - self._remote_tuplize = self.modules[ - "modin.experimental.cloud.rpyc_proxy" - ]._tuplize - class WrappingService(rpyc.ClassicService): - _protocol = WrappingConnection + _protocol = TracingWrappingConnection if _TRACE_RPYC else WrappingConnection def on_connect(self, conn): super().on_connect(conn)