Skip to content

Commit

Permalink
pattern matching: optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
leogama committed Jun 20, 2022
1 parent 4d1db4a commit fc84eb7
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 81 deletions.
107 changes: 72 additions & 35 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _trace(boolean):
OLDER = (PY3 and sys.hexversion < 0x3040000) or (sys.hexversion < 0x2070ab1)
OLD33 = (sys.hexversion < 0x3030000)
OLD37 = (sys.hexversion < 0x3070000)
OLD38 = (sys.hexversion < 0x3080000)
OLD39 = (sys.hexversion < 0x3090000)
OLD310 = (sys.hexversion < 0x30a0000)
PY34 = (0x3040000 <= sys.hexversion < 0x3050000)
Expand Down Expand Up @@ -772,6 +773,44 @@ def _create_function(fcode, fglobals, fname=None, fdefaults=None,
# assert id(fglobals) == id(func.__globals__)
return func

class match:
"""
Make avaialable a limited structural pattern matching-like syntax for Python < 3.10
Patterns can be only (tuples of) types currently.
Inspired by the package pattern-matching-PEP634.
Usage:
>>> with match(args) as m:
>>> if m.case(int, x=list):
>>> # use m.x
>>> elif m.case(int, x=list, y=(int, float)):
>>> # use m.x and m.y
Equivalent native code for Python >= 3.10:
>>> match args:
>>> case (int, list(x)):
>>> # use x
>>> case (int, list(x), int()|float() as y):
>>> # use x and y
"""
def __init__(self, value):
self.value = value
def __enter__(self):
return self
def __exit__(self, *exc_info):
return False
def __getattr__(self, item):
return self.vars[item]
def case(self, *args, **kwargs):
"""just handles tuple patterns"""
if len(kwargs) != len(self.value):
return False
matches = all(isinstance(arg, pat) for arg, pat in zip(self.value, kwargs.values()))
if matches:
self.vars = dict(zip(kwargs.keys(), self.value))
return matches

CODE_PARAMS = [
# Version New attribute CodeType parameters
((3,11,'a'), 'co_endlinetable', 'argcount posonlyargcount kwonlyargcount nlocals stacksize flags code consts names varnames filename name qualname firstlineno linetable endlinetable columntable exceptiontable freevars cellvars'),
Expand All @@ -780,12 +819,17 @@ def _create_function(fcode, fglobals, fname=None, fdefaults=None,
((3,8), 'co_posonlyargcount', 'argcount posonlyargcount kwonlyargcount nlocals stacksize flags code consts names varnames filename name firstlineno lnotab freevars cellvars'),
((3,7), 'co_kwonlyargcount', 'argcount kwonlyargcount nlocals stacksize flags code consts names varnames filename name firstlineno lnotab freevars cellvars'),
]

for version, new_attr, params_list in CODE_PARAMS:
for version, new_attr, params in CODE_PARAMS:
if hasattr(CodeType, new_attr):
CODE_PARAMS = params_list.split()
CODE_VERSION = version
params = params.split()
break
CODE_VERSION = version
N_COMMON = params.index('name') + 1
VERSION_CODE_PARAMS = params[N_COMMON:]
ENCODE_PARAMS = set(VERSION_CODE_PARAMS).intersection(
['lnotab', 'linetable', 'endlinetable', 'columntable', 'exceptiontable'])
if OLD38:
N_COMMON += 1

def _create_code(*args):
if not isinstance(args[0], int): # co_lnotab stored from >= 3.10
Expand All @@ -794,58 +838,51 @@ def _create_code(*args):
else: # from < 3.10 (or pre-LNOTAB storage)
LNOTAB = b''

if not isinstance(args[5], int):
args = args[0], 0, *args[1:] # co_posonlyargcount (from python <= 3.7)
common_args, args = list(args[:N_COMMON]), args[N_COMMON:]
if hasattr(common_args[6], 'encode'): # co_code
common_args[6] = common_args[6].encode()

BYTES = (bytes, str)
BYTES_NONE = (bytes, str, type(None))
encode_bytes = lambda fields, patterns: {
key: val.encode() if pat in (BYTES, BYTES_NONE) and isinstance(val, str) else val \
for (key, val), pat in zip(fields, patterns)
}

with match(args) as m:
# Python 3.11
if m.case(argcount=int, posonlyargcount=int, kwonlyargcount=int, nlocals=int, stacksize=int, flags=int,
code=BYTES, consts=tuple, names=tuple, varnames=tuple, filename=str, name=str, qualname=str,
firstlineno=int, linetable=BYTES,
exceptiontable=BYTES, freevars=tuple, cellvars=tuple):
fields = encode_bytes(m.vars.items(), m.patterns)
if m.case(qualname=str, firstlineno=int, linetable=BYTES, exceptiontable=BYTES,
freevars=tuple, cellvars=tuple):
fields = m.vars

# Python 3.11a
elif m.case(argcount=int, posonlyargcount=int, kwonlyargcount=int, nlocals=int, stacksize=int, flags=int,
code=BYTES, consts=tuple, names=tuple, varnames=tuple, filename=str, name=str, qualname=str,
firstlineno=int, linetable=BYTES, endlinetable=BYTES_NONE, columntable=BYTES_NONE,
exceptiontable=BYTES, freevars=tuple, cellvars=tuple):
fields = encode_bytes(m.vars.items(), m.patterns)
elif m.case(qualname=str, firstlineno=int, linetable=BYTES, endlinetable=BYTES_NONE,
columntable=BYTES_NONE, exceptiontable=BYTES, freevars=tuple, cellvars=tuple):
fields = m.vars

# Python 3.10 or 3.9/3.8
elif m.case(argcount=int, posonlyargcount=int, kwonlyargcount=int, nlocals=int, stacksize=int, flags=int,
code=BYTES, consts=tuple, names=tuple, varnames=tuple, filename=str, name=str,
firstlineno=int, LNOTAB_OR_LINETABLE=BYTES,
freevars=tuple, cellvars=tuple):
fields = encode_bytes(m.vars.items(), m.patterns)
elif m.case(firstlineno=int, LNOTAB_OR_LINETABLE=BYTES, freevars=tuple, cellvars=tuple):
fields = m.vars
key = 'linetable' if CODE_VERSION >= (3,10) else 'lnotab'
fields[key] = fields['LNOTAB_OR_LINETABLE']
fields[key] = m.vars['LNOTAB_OR_LINETABLE']

# Python 3.7
elif m.case(argcount=int, kwonlyargcount=int, nlocals=int, stacksize=int, flags=int,
code=BYTES, consts=tuple, names=tuple, varnames=tuple, filename=str, name=str,
firstlineno=int, lnotab=BYTES,
freevars=tuple, cellvars=tuple):
fields = encode_bytes(m.vars.items(), m.patterns)
elif m.case(firstlineno=int, lnotab=BYTES, freevars=tuple, cellvars=tuple):
fields = m.vars

fields.setdefault('posonlyargcount', 0) # from python <= 3.7
fields.setdefault('qualname', fields['name']) # from python <= 3.10
fields = {k: (v.encode() if k in ENCODE_PARAMS and hasattr(v, 'encode') else v) for k, v in fields.items()}
fields.setdefault('qualname', common_args[-1]) # from python <= 3.10
fields.setdefault('exceptiontable', b'') # from python <= 3.10
fields.setdefault('endlinetable', None) # from python != 3.11a
fields.setdefault('columntable', None) # from python != 3.11a

# Special case: lnotab and linetable
# Special case: co_lnotab and co_linetable
if CODE_VERSION >= (3,10):
fields.setdefault('linetable', b'')
else:
fields.setdefault('lnotab', LNOTAB)

args = (fields[param] for param in CODE_PARAMS)
return CodeType(*args)
args = (fields[param] for param in VERSION_CODE_PARAMS)
if OLD38:
del common_args[1] # co_posonlyargcount
return CodeType(*common_args, *args)

def _create_ftype(ftypeobj, func, args, kwds):
if kwds is None:
Expand Down
46 changes: 0 additions & 46 deletions dill/_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,49 +264,3 @@ def _delattr(cell, name):

_setattr = Getattr(_dill, '_setattr', setattr)
_delattr = Getattr(_dill, '_delattr', delattr)

# Structural Pattern Matching for reduce tuples.
@move_to(_dill)
class match:
"""
Make avaialable a limited pattern matching-like syntax for Python < 3.10
Patterns can be literals ("example") or (tuples of) types.
Inspired by the package pattern-matching-PEP634.
Usage:
>>> with match(args) as m:
>>> if m.case(int, x=list):
>>> # use m.x
>>> elif m.case(int, x=list, y=int):
>>> # use m.x and m.y
Equivalent native code for Python >= 3.10:
>>> match args:
>>> case (int, list(x)):
>>> # use x
>>> case (int, list(x), int(y):
>>> # use x and y
"""
def __init__(self, value):
self.value = value
self.patterns = None
self.vars = None
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return False
def __getattr__(self, item):
return self.vars[item]
def case(self, *args, **kwargs):
"""just handles tuple patterns"""
self.patterns = args + tuple(kwargs.values())
if len(self.value) != len(self.patterns):
return False
is_type = lambda pat: isinstance(pat, type) or \
isinstance(pat, tuple) and all(isinstance(item, type) for item in pat)
matches = all(isinstance(arg, pat) if is_type(pat) else arg == pat \
for arg, pat in zip(self.value, self.patterns))
if matches and kwargs:
self.vars = dict(zip(kwargs.keys(), self.value[len(args):]))
return matches

0 comments on commit fc84eb7

Please sign in to comment.