Skip to content

Commit

Permalink
Improve PEP 695 implementation
Browse files Browse the repository at this point in the history
Refs   #757.
  • Loading branch information
evhub committed Mar 21, 2024
1 parent 9a9b41e commit 76c956c
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 21 deletions.
59 changes: 48 additions & 11 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3180,7 +3180,16 @@ def classdef_handle(self, original, loc, tokens):
"""Process class definitions."""
decorators, name, paramdefs, classlist_toks, body = tokens

out = "".join(paramdefs) + decorators + "class " + name
out = ""

# paramdefs are type params on >= 3.12 and type var assignments on < 3.12
if paramdefs:
if self.target_info >= (3, 12):
name += "[" + ", ".join(paramdefs) + "]"
else:
out += "".join(paramdefs)

out += decorators + "class " + name

# handle classlist
base_classes = []
Expand Down Expand Up @@ -3210,7 +3219,7 @@ def classdef_handle(self, original, loc, tokens):

base_classes.append(join_args(pos_args, star_args, kwd_args, dubstar_args))

if paramdefs:
if paramdefs and self.target_info < (3, 12):
base_classes.append(self.get_generic_for_typevars())

if not classlist_toks and not self.target.startswith("3"):
Expand Down Expand Up @@ -3442,9 +3451,16 @@ def assemble_data(self, decorators, name, namedtuple_call, inherit, extra_stmts,
IMPORTANT: Any changes to assemble_data must be reflected in the
definition of Expected in header.py_template.
"""
print(paramdefs)
# create class
out = [
"".join(paramdefs),
out = []
if paramdefs:
# paramdefs are type params on >= 3.12 and type var assignments on < 3.12
if self.target_info >= (3, 12):
name += "[" + ", ".join(paramdefs) + "]"
else:
out += ["".join(paramdefs)]
out += [
decorators,
"class ",
name,
Expand All @@ -3453,7 +3469,7 @@ def assemble_data(self, decorators, name, namedtuple_call, inherit, extra_stmts,
]
if inherit is not None:
out += [", ", inherit]
if paramdefs:
if paramdefs and self.target_info < (3, 12):
out += [", ", self.get_generic_for_typevars()]
if not self.target.startswith("3"):
out.append(", _coconut.object")
Expand Down Expand Up @@ -4564,34 +4580,46 @@ def funcname_typeparams_handle(self, tokens):
return name
else:
name, paramdefs = tokens
return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False)
# paramdefs are type params on >= 3.12 and type var assignments on < 3.12
if self.target_info >= (3, 12):
return name + "[" + ", ".join(paramdefs) + "]"
else:
return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False)

funcname_typeparams_handle.ignore_one_token = True

def type_param_handle(self, original, loc, tokens):
"""Compile a type param into an assignment."""
args = ""
raw_bound = None
bound_op = None
bound_op_type = ""
stars = ""
if "TypeVar" in tokens:
TypeVarFunc = "TypeVar"
bound_op_type = "bound"
if len(tokens) == 2:
name_loc, name = tokens
else:
name_loc, name, bound_op, bound = tokens
# raw_bound is for >=3.12, so it is for_py_typedef, but args is for <3.12, so it isn't
raw_bound = self.wrap_typedef(bound, for_py_typedef=True)
args = ", bound=" + self.wrap_typedef(bound, for_py_typedef=False)
elif "TypeVar constraint" in tokens:
TypeVarFunc = "TypeVar"
bound_op_type = "constraint"
name_loc, name, bound_op, constraints = tokens
# for_py_typedef is different in the two cases here as above
raw_bound = ", ".join(self.wrap_typedef(c, for_py_typedef=True) for c in constraints)
args = ", " + ", ".join(self.wrap_typedef(c, for_py_typedef=False) for c in constraints)
elif "TypeVarTuple" in tokens:
TypeVarFunc = "TypeVarTuple"
name_loc, name = tokens
stars = "*"
elif "ParamSpec" in tokens:
TypeVarFunc = "ParamSpec"
name_loc, name = tokens
stars = "**"
else:
raise CoconutInternalException("invalid type_param tokens", tokens)

Expand All @@ -4612,8 +4640,14 @@ def type_param_handle(self, original, loc, tokens):
loc,
)

# on >= 3.12, return a type param
if self.target_info >= (3, 12):
return stars + name + (": " + raw_bound if raw_bound is not None else "")

# on < 3.12, return a type variable assignment

kwargs = ""
# uncomment these lines whenever mypy adds support for infer_variance in TypeVar
# TODO: uncomment these lines whenever mypy adds support for infer_variance in TypeVar
# (and remove the warning about it in the DOCS)
# if TypeVarFunc == "TypeVar":
# kwargs += ", infer_variance=True"
Expand Down Expand Up @@ -4644,6 +4678,7 @@ def type_param_handle(self, original, loc, tokens):

def get_generic_for_typevars(self):
"""Get the Generic instances for the current typevars."""
internal_assert(self.target_info < (3, 12), "get_generic_for_typevars should only be used on targets < 3.12")
typevar_info = self.current_parsing_context("typevars")
internal_assert(typevar_info is not None, "get_generic_for_typevars called with no typevars")
generics = []
Expand Down Expand Up @@ -4677,16 +4712,18 @@ def type_alias_stmt_handle(self, tokens):
paramdefs = ()
else:
name, paramdefs, typedef = tokens
out = "".join(paramdefs)

# paramdefs are type params on >= 3.12 and type var assignments on < 3.12
if self.target_info >= (3, 12):
out += "type " + name + " = " + self.wrap_typedef(typedef, for_py_typedef=True)
if paramdefs:
name += "[" + ", ".join(paramdefs) + "]"
return "type " + name + " = " + self.wrap_typedef(typedef, for_py_typedef=True)
else:
out += self.typed_assign_stmt_handle([
return "".join(paramdefs) + self.typed_assign_stmt_handle([
name,
"_coconut.typing.TypeAlias",
self.wrap_typedef(typedef, for_py_typedef=False),
])
return out

def where_item_handle(self, tokens):
"""Manage where items."""
Expand Down
2 changes: 1 addition & 1 deletion coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE}
reiterables = abc.Sequence, abc.Mapping, abc.Set
fmappables = list, tuple, dict, set, frozenset, bytes, bytearray
abc.Sequence.register(collections.deque)
Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray}
Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, {lstatic}min{rstatic}, {lstatic}max{rstatic}, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray}
@_coconut.functools.wraps(_coconut.functools.partial)
def _coconut_partial(_coconut_func, *args, **kwargs):
partial_func = _coconut.functools.partial(_coconut_func, *args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion coconut/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def get_path_env_var(env_var, default):
PY38
and not WINDOWS
and not PYPY
# disabled until MyPy supports PEP 695
# TODO: disabled until MyPy supports PEP 695
and not PY312
)
XONSH = (
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.1.0"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 5
DEVELOP = 6
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
Expand Down
10 changes: 5 additions & 5 deletions coconut/tests/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,11 +1092,11 @@ def test_bbopt(self):
if not PYPY and PY38 and not PY310:
install_bbopt()

def test_pyprover(self):
with using_paths(pyprover):
comp_pyprover()
if PY38:
run_pyprover()
# def test_pyprover(self):
# with using_paths(pyprover):
# comp_pyprover()
# if PY38:
# run_pyprover()

def test_pyston(self):
with using_paths(pyston):
Expand Down
3 changes: 1 addition & 2 deletions coconut/tests/src/extras.coco
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,7 @@ type Num = int | float""".strip())
assert parse("type L[T] = list[T]").strip().endswith("""
# Compiled Coconut: -----------------------------------------------------------

_coconut_typevar_T_0 = _coconut.typing.TypeVar("_coconut_typevar_T_0")
type L = list[_coconut_typevar_T_0]""".strip())
type L[T] = list[T]""".strip())

setup(line_numbers=False, minify=True)
assert parse("123 # derp", "lenient") == "123# derp"
Expand Down

0 comments on commit 76c956c

Please sign in to comment.