From fc84eb7cfabd8f78b650d82b4c707e07f6b35df8 Mon Sep 17 00:00:00 2001 From: Leonardo Gama Date: Mon, 20 Jun 2022 11:28:31 -0300 Subject: [PATCH] pattern matching: optimizations --- dill/_dill.py | 107 +++++++++++++++++++++++++++++++++---------------- dill/_shims.py | 46 --------------------- 2 files changed, 72 insertions(+), 81 deletions(-) diff --git a/dill/_dill.py b/dill/_dill.py index 07bc28a3..9a2f5b92 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -40,6 +40,7 @@ def _trace(boolean): OLDER = (PY3 and sys.hexversion < 0x3040000) or (sys.hexversion < 0x2070ab1) OLD33 = (sys.hexversion < 0x3030000) OLD37 = (sys.hexversion < 0x3070000) +OLD38 = (sys.hexversion < 0x3080000) OLD39 = (sys.hexversion < 0x3090000) OLD310 = (sys.hexversion < 0x30a0000) PY34 = (0x3040000 <= sys.hexversion < 0x3050000) @@ -772,6 +773,44 @@ def _create_function(fcode, fglobals, fname=None, fdefaults=None, # assert id(fglobals) == id(func.__globals__) return func +class match: + """ + Make available a limited structural pattern matching-like syntax for Python < 3.10 + + Patterns can be only (tuples of) types currently. + Inspired by the package pattern-matching-PEP634. 
+ + Usage: + >>> with match(args) as m: + >>> if m.case(int, x=list): + >>> # use m.x + >>> elif m.case(int, x=list, y=(int, float)): + >>> # use m.x and m.y + + Equivalent native code for Python >= 3.10: + >>> match args: + >>> case (int, list(x)): + >>> # use x + >>> case (int, list(x), int()|float() as y): + >>> # use x and y + """ + def __init__(self, value): + self.value = value + def __enter__(self): + return self + def __exit__(self, *exc_info): + return False + def __getattr__(self, item): + return self.vars[item] + def case(self, *args, **kwargs): + """just handles tuple patterns""" + if len(kwargs) != len(self.value): + return False + matches = all(isinstance(arg, pat) for arg, pat in zip(self.value, kwargs.values())) + if matches: + self.vars = dict(zip(kwargs.keys(), self.value)) + return matches + CODE_PARAMS = [ # Version New attribute CodeType parameters ((3,11,'a'), 'co_endlinetable', 'argcount posonlyargcount kwonlyargcount nlocals stacksize flags code consts names varnames filename name qualname firstlineno linetable endlinetable columntable exceptiontable freevars cellvars'), @@ -780,12 +819,17 @@ def _create_function(fcode, fglobals, fname=None, fdefaults=None, ((3,8), 'co_posonlyargcount', 'argcount posonlyargcount kwonlyargcount nlocals stacksize flags code consts names varnames filename name firstlineno lnotab freevars cellvars'), ((3,7), 'co_kwonlyargcount', 'argcount kwonlyargcount nlocals stacksize flags code consts names varnames filename name firstlineno lnotab freevars cellvars'), ] - -for version, new_attr, params_list in CODE_PARAMS: +for version, new_attr, params in CODE_PARAMS: if hasattr(CodeType, new_attr): - CODE_PARAMS = params_list.split() - CODE_VERSION = version + params = params.split() break +CODE_VERSION = version +N_COMMON = params.index('name') + 1 +VERSION_CODE_PARAMS = params[N_COMMON:] +ENCODE_PARAMS = set(VERSION_CODE_PARAMS).intersection( + ['lnotab', 'linetable', 'endlinetable', 'columntable', 'exceptiontable']) 
+if OLD38: + N_COMMON += 1 def _create_code(*args): if not isinstance(args[0], int): # co_lnotab stored from >= 3.10 @@ -794,58 +838,51 @@ def _create_code(*args): else: # from < 3.10 (or pre-LNOTAB storage) LNOTAB = b'' + if not isinstance(args[5], int): + args = args[0], 0, *args[1:] # co_posonlyargcount (from python <= 3.7) + common_args, args = list(args[:N_COMMON]), args[N_COMMON:] + if hasattr(common_args[6], 'encode'): # co_code + common_args[6] = common_args[6].encode() + BYTES = (bytes, str) BYTES_NONE = (bytes, str, type(None)) - encode_bytes = lambda fields, patterns: { - key: val.encode() if pat in (BYTES, BYTES_NONE) and isinstance(val, str) else val \ - for (key, val), pat in zip(fields, patterns) - } - with match(args) as m: # Python 3.11 - if m.case(argcount=int, posonlyargcount=int, kwonlyargcount=int, nlocals=int, stacksize=int, flags=int, - code=BYTES, consts=tuple, names=tuple, varnames=tuple, filename=str, name=str, qualname=str, - firstlineno=int, linetable=BYTES, - exceptiontable=BYTES, freevars=tuple, cellvars=tuple): - fields = encode_bytes(m.vars.items(), m.patterns) + if m.case(qualname=str, firstlineno=int, linetable=BYTES, exceptiontable=BYTES, + freevars=tuple, cellvars=tuple): + fields = m.vars # Python 3.11a - elif m.case(argcount=int, posonlyargcount=int, kwonlyargcount=int, nlocals=int, stacksize=int, flags=int, - code=BYTES, consts=tuple, names=tuple, varnames=tuple, filename=str, name=str, qualname=str, - firstlineno=int, linetable=BYTES, endlinetable=BYTES_NONE, columntable=BYTES_NONE, - exceptiontable=BYTES, freevars=tuple, cellvars=tuple): - fields = encode_bytes(m.vars.items(), m.patterns) + elif m.case(qualname=str, firstlineno=int, linetable=BYTES, endlinetable=BYTES_NONE, + columntable=BYTES_NONE, exceptiontable=BYTES, freevars=tuple, cellvars=tuple): + fields = m.vars # Python 3.10 or 3.9/3.8 - elif m.case(argcount=int, posonlyargcount=int, kwonlyargcount=int, nlocals=int, stacksize=int, flags=int, - code=BYTES, 
consts=tuple, names=tuple, varnames=tuple, filename=str, name=str, - firstlineno=int, LNOTAB_OR_LINETABLE=BYTES, - freevars=tuple, cellvars=tuple): - fields = encode_bytes(m.vars.items(), m.patterns) + elif m.case(firstlineno=int, LNOTAB_OR_LINETABLE=BYTES, freevars=tuple, cellvars=tuple): + fields = m.vars key = 'linetable' if CODE_VERSION >= (3,10) else 'lnotab' - fields[key] = fields['LNOTAB_OR_LINETABLE'] + fields[key] = m.vars['LNOTAB_OR_LINETABLE'] # Python 3.7 - elif m.case(argcount=int, kwonlyargcount=int, nlocals=int, stacksize=int, flags=int, - code=BYTES, consts=tuple, names=tuple, varnames=tuple, filename=str, name=str, - firstlineno=int, lnotab=BYTES, - freevars=tuple, cellvars=tuple): - fields = encode_bytes(m.vars.items(), m.patterns) + elif m.case(firstlineno=int, lnotab=BYTES, freevars=tuple, cellvars=tuple): + fields = m.vars - fields.setdefault('posonlyargcount', 0) # from python <= 3.7 - fields.setdefault('qualname', fields['name']) # from python <= 3.10 + fields = {k: (v.encode() if k in ENCODE_PARAMS and hasattr(v, 'encode') else v) for k, v in fields.items()} + fields.setdefault('qualname', common_args[-1]) # from python <= 3.10 fields.setdefault('exceptiontable', b'') # from python <= 3.10 fields.setdefault('endlinetable', None) # from python != 3.11a fields.setdefault('columntable', None) # from python != 3.11a - # Special case: lnotab and linetable + # Special case: co_lnotab and co_linetable if CODE_VERSION >= (3,10): fields.setdefault('linetable', b'') else: fields.setdefault('lnotab', LNOTAB) - args = (fields[param] for param in CODE_PARAMS) - return CodeType(*args) + args = (fields[param] for param in VERSION_CODE_PARAMS) + if OLD38: + del common_args[1] # co_posonlyargcount + return CodeType(*common_args, *args) def _create_ftype(ftypeobj, func, args, kwds): if kwds is None: diff --git a/dill/_shims.py b/dill/_shims.py index c9cc8525..6bda5136 100644 --- a/dill/_shims.py +++ b/dill/_shims.py @@ -264,49 +264,3 @@ def _delattr(cell, 
name): _setattr = Getattr(_dill, '_setattr', setattr) _delattr = Getattr(_dill, '_delattr', delattr) - -# Structural Pattern Matching for reduce tuples. -@move_to(_dill) -class match: - """ - Make avaialable a limited pattern matching-like syntax for Python < 3.10 - - Patterns can be literals ("example") or (tuples of) types. - Inspired by the package pattern-matching-PEP634. - - Usage: - >>> with match(args) as m: - >>> if m.case(int, x=list): - >>> # use m.x - >>> elif m.case(int, x=list, y=int): - >>> # use m.x and m.y - - Equivalent native code for Python >= 3.10: - >>> match args: - >>> case (int, list(x)): - >>> # use x - >>> case (int, list(x), int(y): - >>> # use x and y - """ - def __init__(self, value): - self.value = value - self.patterns = None - self.vars = None - def __enter__(self): - return self - def __exit__(self, exc_type, exc_val, exc_tb): - return False - def __getattr__(self, item): - return self.vars[item] - def case(self, *args, **kwargs): - """just handles tuple patterns""" - self.patterns = args + tuple(kwargs.values()) - if len(self.value) != len(self.patterns): - return False - is_type = lambda pat: isinstance(pat, type) or \ - isinstance(pat, tuple) and all(isinstance(item, type) for item in pat) - matches = all(isinstance(arg, pat) if is_type(pat) else arg == pat \ - for arg, pat in zip(self.value, self.patterns)) - if matches and kwargs: - self.vars = dict(zip(kwargs.keys(), self.value[len(args):])) - return matches