"""
The RPyC protocol
"""
import sys
import select
import weakref
import itertools
import cPickle as pickle
from threading import Lock
from rpyc.utils.lib import WeakValueDict, RefCountingColl
from rpyc.core import consts, brine, vinegar, netref
from rpyc.core.async import AsyncResult
class PingError(Exception):
pass
DEFAULT_CONFIG = dict(
# ATTRIBUTES
allow_safe_attrs = True,
allow_exposed_attrs = True,
allow_public_attrs = False,
allow_all_attrs = False,
safe_attrs = set(['__abs__', '__add__', '__and__', '__cmp__', '__contains__',
'__delitem__', '__delslice__', '__div__', '__divmod__', '__doc__',
'__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__',
'__getslice__', '__gt__', '__hash__', '__hex__', '__iadd__', '__iand__',
'__idiv__', '__ifloordiv__', '__ilshift__', '__imod__', '__imul__',
'__index__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__',
'__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__',
'__long__', '__lshift__', '__lt__', '__mod__', '__mul__', '__ne__',
'__neg__', '__new__', '__nonzero__', '__oct__', '__or__', '__pos__',
'__pow__', '__radd__', '__rand__', '__rdiv__', '__rdivmod__', '__repr__',
'__rfloordiv__', '__rlshift__', '__rmod__', '__rmul__', '__ror__',
'__rpow__', '__rrshift__', '__rshift__', '__rsub__', '__rtruediv__',
'__rxor__', '__setitem__', '__setslice__', '__str__', '__sub__',
'__truediv__', '__xor__', 'next', '__length_hint__', '__enter__',
'__exit__', ]),
exposed_prefix = "exposed_",
allow_getattr = True,
allow_setattr = False,
allow_delattr = False,
# EXCEPTIONS
include_local_traceback = True,
instantiate_custom_exceptions = False,
import_custom_exceptions = False,
instantiate_oldstyle_exceptions = False, # which don't derive from Exception
propagate_SystemExit_locally = False, # whether to propagate SystemExit locally or to the other party
# MISC
allow_pickle = False,
connid = None,
credentials = None,
)
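# Individual connections override these defaults through the ``config`` argument
# of their constructor or factory function.  A minimal sketch, assuming the
# common ``rpyc.connect`` factory (the host, port and option values below are
# illustrative, not defined in this module):
#
#     import rpyc
#     conn = rpyc.connect("localhost", 18861,
#                         config = {"allow_public_attrs": True,
#                                   "allow_pickle": False})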
_connection_id_generator = itertools.count(1)
class Connection(object):
"""The RPyC connection (also know as the RPyC protocol).
* service: the service to expose
* channel: the channcel over which messages are passed
* config: this connection's config dict (overriding parameters from the
default config dict)
* _lazy: whether or not to initialize the service with the creation of the
connection. default is True. if set to False, you will need to call
_init_service manually later
"""
def __init__(self, service, channel, config = {}, _lazy = False):
self._closed = True
self._config = DEFAULT_CONFIG.copy()
self._config.update(config)
if self._config["connid"] is None:
self._config["connid"] = "conn%d" % (_connection_id_generator.next(),)
self._channel = channel
self._seqcounter = itertools.count()
self._recvlock = Lock()
self._sendlock = Lock()
self._sync_replies = {}
self._async_callbacks = {}
self._local_objects = RefCountingColl()
self._last_traceback = None
self._proxy_cache = WeakValueDict()
self._netref_classes_cache = {}
self._remote_root = None
self._local_root = service(weakref.proxy(self))
if not _lazy:
self._init_service()
self._closed = False
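# A construction sketch: the connection factories normally do this wiring, but a
# Connection can also be built directly from a service class and a channel.  The
# import paths below are assumptions about the surrounding package layout:
#
#     from rpyc.core.stream import SocketStream
#     from rpyc.core.channel import Channel
#     from rpyc.core.service import VoidService
#     conn = Connection(VoidService,
#                       Channel(SocketStream.connect("localhost", 18861)))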
def _init_service(self):
self._local_root.on_connect()
def __del__(self):
self.close()
def __enter__(self):
return self
def __exit__(self, t, v, tb):
self.close()
def __repr__(self):
a, b = object.__repr__(self).split(" object ")
return "%s %r object %s" % (a, self._config["connid"], b)
#
# IO
#
def _cleanup(self, _anyway = True):
if self._closed and not _anyway:
return
self._closed = True
self._channel.close()
self._local_root.on_disconnect()
self._sync_replies.clear()
self._async_callbacks.clear()
self._local_objects.clear()
self._proxy_cache.clear()
self._netref_classes_cache.clear()
self._last_traceback = None
self._remote_root = None
self._local_root = None
#self._seqcounter = None
#self._config.clear()
def close(self, _catchall = True):
if self._closed:
return
self._closed = True
try:
try:
self._async_request(consts.HANDLE_CLOSE)
except EOFError:
pass
except Exception:
if not _catchall:
raise
finally:
self._cleanup(_anyway = True)
@property
def closed(self):
return self._closed
def fileno(self):
return self._channel.fileno()
def ping(self, data = "the world is a vampire!" * 20, timeout = 3):
"""assert that the other party is functioning properly"""
res = self.async_request(consts.HANDLE_PING, data, timeout = timeout)
if res.value != data:
raise PingError("echo mismatches sent data")
def _send(self, msg, seq, args):
data = brine.dump((msg, seq, args))
self._sendlock.acquire()
try:
self._channel.send(data)
finally:
self._sendlock.release()
def _send_request(self, handler, args):
seq = self._seqcounter.next()
self._send(consts.MSG_REQUEST, seq, (handler, self._box(args)))
return seq
def _send_reply(self, seq, obj):
self._send(consts.MSG_REPLY, seq, self._box(obj))
def _send_exception(self, seq, exctype, excval, exctb):
exc = vinegar.dump(exctype, excval, exctb,
include_local_traceback = self._config["include_local_traceback"])
self._send(consts.MSG_EXCEPTION, seq, exc)
#
# boxing
#
def _box(self, obj):
"""store a local object in such a way that it could be recreated on
the remote party either by-value or by-reference"""
if brine.dumpable(obj):
return consts.LABEL_VALUE, obj
if type(obj) is tuple:
return consts.LABEL_TUPLE, tuple(self._box(item) for item in obj)
elif isinstance(obj, netref.BaseNetref) and obj.____conn__() is self:
return consts.LABEL_LOCAL_REF, obj.____oid__
else:
self._local_objects.add(obj)
cls = getattr(obj, "__class__", type(obj))
return consts.LABEL_REMOTE_REF, (id(obj), cls.__name__, cls.__module__)
def _unbox(self, package):
"""recreate a local object representation of the remote object: if the
object is passed by value, just return it; if the object is passed by
reference, create a netref to it"""
label, value = package
if label == consts.LABEL_VALUE:
return value
if label == consts.LABEL_TUPLE:
return tuple(self._unbox(item) for item in value)
if label == consts.LABEL_LOCAL_REF:
return self._local_objects[value]
if label == consts.LABEL_REMOTE_REF:
oid, clsname, modname = value
if oid in self._proxy_cache:
return self._proxy_cache[oid]
proxy = self._netref_factory(oid, clsname, modname)
self._proxy_cache[oid] = proxy
return proxy
raise ValueError("invalid label %r" % (label,))
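# Boxing in a nutshell (values below are illustrative): brine-dumpable objects
# travel by value, everything else travels as a reference into the sender's
# _local_objects table, which _unbox turns back into a netref proxy:
#
#     self._box(17)             ->  (consts.LABEL_VALUE, 17)
#     self._box(some_instance)  ->  (consts.LABEL_REMOTE_REF,
#                                    (id(some_instance), "SomeClass", "some.module"))
#
# Tuples that mix both kinds are boxed element by element under LABEL_TUPLE.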
def _netref_factory(self, oid, clsname, modname):
typeinfo = (clsname, modname)
if typeinfo in self._netref_classes_cache:
cls = self._netref_classes_cache[typeinfo]
elif typeinfo in netref.builtin_classes_cache:
cls = netref.builtin_classes_cache[typeinfo]
else:
info = self.sync_request(consts.HANDLE_INSPECT, oid)
cls = netref.class_factory(clsname, modname, info)
self._netref_classes_cache[typeinfo] = cls
return cls(weakref.ref(self), oid)
#
# dispatching
#
def _dispatch_request(self, seq, raw_args):
try:
handler, args = raw_args
args = self._unbox(args)
res = self._HANDLERS[handler](self, *args)
except KeyboardInterrupt:
raise
except:
t, v, tb = sys.exc_info()
self._last_traceback = tb
if t is SystemExit and self._config["propagate_SystemExit_locally"]:
raise
self._send_exception(seq, t, v, tb)
else:
self._send_reply(seq, res)
def _dispatch_reply(self, seq, raw):
obj = self._unbox(raw)
if seq in self._async_callbacks:
self._async_callbacks.pop(seq)(False, obj)
else:
self._sync_replies[seq] = (False, obj)
def _dispatch_exception(self, seq, raw):
obj = vinegar.load(raw,
import_custom_exceptions = self._config["import_custom_exceptions"],
instantiate_custom_exceptions = self._config["instantiate_custom_exceptions"],
instantiate_oldstyle_exceptions = self._config["instantiate_oldstyle_exceptions"])
if seq in self._async_callbacks:
self._async_callbacks.pop(seq)(True, obj)
else:
self._sync_replies[seq] = (True, obj)
#
# serving
#
def _recv(self, timeout, wait_for_lock):
if not self._recvlock.acquire(wait_for_lock):
return None
try:
try:
if self._channel.poll(timeout):
data = self._channel.recv()
else:
data = None
except EOFError:
self.close()
raise
finally:
self._recvlock.release()
return data
def _dispatch(self, data):
msg, seq, args = brine.load(data)
if msg == consts.MSG_REQUEST:
self._dispatch_request(seq, args)
elif msg == consts.MSG_REPLY:
self._dispatch_reply(seq, args)
elif msg == consts.MSG_EXCEPTION:
self._dispatch_exception(seq, args)
else:
raise ValueError("invalid message type: %r" % (msg,))
def poll(self, timeout = 0):
"""serve a single transaction, should one arrives in the given
interval. note that handling a request/reply may trigger nested
requests, which are all part of the transaction.
returns True if one was served, False otherwise"""
data = self._recv(timeout, wait_for_lock = False)
if not data:
return False
self._dispatch(data)
return True
def serve(self, timeout = 1):
"""serve a single request or reply that arrives within the given
time frame (default is 1 sec). note that the dispatching of a request
might trigger multiple (nested) requests, thus this function may be
reentrant. returns True if a request or reply was received, False
otherwise."""
data = self._recv(timeout, wait_for_lock = True)
if not data:
return False
self._dispatch(data)
return True
def serve_all(self):
"""serve all requests and replies while the connection is alive"""
try:
try:
while True:
self.serve(0.1)
except select.error:
if not self.closed:
raise
except EOFError:
pass
finally:
self.close()
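# serve_all is typically the body of a per-connection server loop; on the client
# side it can also be pushed to a background thread so asynchronous replies and
# callbacks get processed without explicit serve() calls.  A minimal sketch
# (``conn`` is an established Connection; plain threading used for illustration):
#
#     from threading import Thread
#     t = Thread(target = conn.serve_all)
#     t.setDaemon(True)
#     t.start()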
def poll_all(self, timeout = 0):
"""serve all requests and replies that arrive within the given interval.
returns True if at least one was served, False otherwise"""
at_least_once = False
try:
while self.poll(timeout):
at_least_once = True
except EOFError:
pass
return at_least_once
#
# requests
#
def sync_request(self, handler, *args):
"""send a request and wait for the reply to arrive"""
seq = self._send_request(handler, args)
while seq not in self._sync_replies:
self.serve(0.1)
isexc, obj = self._sync_replies.pop(seq)
if isexc:
raise obj
else:
return obj
def _async_request(self, handler, args = (), callback = (lambda a, b: None)):
seq = self._send_request(handler, args)
self._async_callbacks[seq] = callback
def async_request(self, handler, *args, **kwargs):
"""send a request and return an AsyncResult object, which will
eventually hold the reply"""
timeout = kwargs.pop("timeout", None)
if kwargs:
raise TypeError("got unexpected keyword argument %r" % (kwargs.keys()[0],))
res = AsyncResult(weakref.proxy(self))
self._async_request(handler, args, res)
if timeout is not None:
res.set_expiry(timeout)
return res
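# Request usage sketch: sync_request blocks (while serving) until the reply
# arrives, whereas async_request returns immediately with an AsyncResult.  The
# handler constant is real, but the ``oid`` value is illustrative:
#
#     rep = conn.sync_request(consts.HANDLE_REPR, oid)          # blocks
#     ares = conn.async_request(consts.HANDLE_REPR, oid, timeout = 5)
#     ...                                                       # do other work
#     rep = ares.value                                          # waits if needed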
@property
def root(self):
"""fetch the root object of the other party"""
if self._remote_root is None:
self._remote_root = self.sync_request(consts.HANDLE_GETROOT)
return self._remote_root
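# Typical entry point from user code: ``conn.root`` is a netref to the other
# party's service root, so remote exposed methods are invoked through it.  The
# method name below is an assumption for illustration:
#
#     remote_service = conn.root
#     print remote_service.get_answer()    # runs exposed_get_answer remotely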
#
# attribute access
#
def _check_attr(self, obj, name):
if self._config["allow_exposed_attrs"]:
if name.startswith(self._config["exposed_prefix"]):
name2 = name
else:
name2 = self._config["exposed_prefix"] + name
if hasattr(obj, name2):
return name2
if self._config["allow_all_attrs"]:
return name
if self._config["allow_safe_attrs"] and name in self._config["safe_attrs"]:
return name
if self._config["allow_public_attrs"] and not name.startswith("_"):
return name
return False
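# Attribute policy sketch: with the default config, a request for "foo" on a
# served object first looks for "exposed_foo" (allow_exposed_attrs), then falls
# back to the safe_attrs whitelist.  The service below is illustrative only:
#
#     class MyService(Service):
#         def exposed_get_answer(self):    # reachable as conn.root.get_answer()
#             return 42
#         def _internal(self):             # blocked unless allow_all_attrs is set
#             pass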
def _access_attr(self, oid, name, args, overrider, param, default):
if type(name) is not str:
raise TypeError("attr name must be a string")
obj = self._local_objects[oid]
accessor = getattr(type(obj), overrider, None)
if accessor is None:
name2 = self._check_attr(obj, name)
if not self._config[param] or not name2:
raise AttributeError("cannot access %r" % (name,))
accessor = default
name = name2
return accessor(obj, name, *args)
#
# handlers
#
def _handle_ping(self, data):
return data
def _handle_close(self):
self._cleanup()
def _handle_getroot(self):
return self._local_root
def _handle_del(self, oid):
self._local_objects.decref(oid)
def _handle_repr(self, oid):
return repr(self._local_objects[oid])
def _handle_str(self, oid):
return str(self._local_objects[oid])
def _handle_cmp(self, oid, other):
# cmp() might enter recursive resonance... yet another workaround
#return cmp(self._local_objects[oid], other)
obj = self._local_objects[oid]
try:
return type(obj).__cmp__(obj, other)
except TypeError:
return NotImplemented
def _handle_hash(self, oid):
return hash(self._local_objects[oid])
def _handle_call(self, oid, args, kwargs):
return self._local_objects[oid](*args, **dict(kwargs))
def _handle_dir(self, oid):
return tuple(dir(self._local_objects[oid]))
def _handle_inspect(self, oid):
return tuple(netref.inspect_methods(self._local_objects[oid]))
def _handle_getattr(self, oid, name):
return self._access_attr(oid, name, (), "_rpyc_getattr", "allow_getattr", getattr)
def _handle_delattr(self, oid, name):
return self._access_attr(oid, name, (), "_rpyc_delattr", "allow_delattr", delattr)
def _handle_setattr(self, oid, name, value):
return self._access_attr(oid, name, (value,), "_rpyc_setattr", "allow_setattr", setattr)
def _handle_callattr(self, oid, name, args, kwargs):
return self._handle_getattr(oid, name)(*args, **dict(kwargs))
def _handle_pickle(self, oid, proto):
if not self._config["allow_pickle"]:
raise ValueError("pickling is disabled")
return pickle.dumps(self._local_objects[oid], proto)
def _handle_buffiter(self, oid, count):
items = []
obj = self._local_objects[oid]
for i in xrange(count):
try:
items.append(obj.next())
except StopIteration:
break
return tuple(items)
# collect handlers
_HANDLERS = {}
for name, obj in locals().items():
if name.startswith("_handle_"):
name2 = "HANDLE_" + name[8:].upper()
if hasattr(consts, name2):
_HANDLERS[getattr(consts, name2)] = obj
else:
raise NameError("no constant defined for %r" % (name,))
del name, name2, obj
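# The resulting _HANDLERS table maps protocol constants to handler functions,
# e.g. consts.HANDLE_GETATTR -> _handle_getattr, so that _dispatch_request can
# route an incoming (handler, args) pair with a single dictionary lookup.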