Skip to content

Commit

Permalink
Support fmap of bytes, bytearray
Browse files Browse the repository at this point in the history
Resolves   #826.
  • Loading branch information
evhub committed Jan 27, 2024
1 parent 5ba3d1a commit b697bc8
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 11 deletions.
2 changes: 1 addition & 1 deletion DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3385,7 +3385,7 @@ _Can't be done without a series of method definitions for each data type. See th

In Haskell, `fmap(func, obj)` takes a data type `obj` and returns a new data type with `func` mapped over the contents. Coconut's `fmap` function does the exact same thing for Coconut's [data types](#data).

`fmap` can also be used on the built-in objects `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, and `dict` as a variant of `map` that returns back an object of the same type.
`fmap` can also be used on the built-in objects `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, `bytes`, `bytearray`, and `dict` as a variant of `map` that returns back an object of the same type.

For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the mapping's `.items()` instead of the default iteration through its `.keys()`, with the new mapping reconstructed from the mapped over items. _Deprecated: `fmap$(starmap_over_mappings=True)` will `starmap` over the `.items()` instead of `map` over them._

Expand Down
2 changes: 1 addition & 1 deletion __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1457,7 +1457,7 @@ def fmap(func: _t.Callable[[_T, _U], _t.Tuple[_V, _W]], obj: _t.Mapping[_T, _U],
Supports:
* Coconut data types
* `str`, `dict`, `list`, `tuple`, `set`, `frozenset`
* `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, `bytes`, `bytearray`
* `dict` (maps over .items())
* asynchronous iterables
* numpy arrays (uses np.vectorize)
Expand Down
1 change: 1 addition & 0 deletions _coconut/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ min = _builtins.min
max = _builtins.max
next = _builtins.next
object = _builtins.object
ord = _builtins.ord
print = _builtins.print
property = _builtins.property
range = _builtins.range
Expand Down
9 changes: 9 additions & 0 deletions coconut/compiler/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,15 @@ def __aiter__(self):
{async_def_anext}
'''.format(**format_dict),
),
handle_bytes=pycondition(
(3,),
if_lt='''
if _coconut.isinstance(obj, _coconut.bytes):
return _coconut_base_makedata(_coconut.bytes, [func(_coconut.ord(x)) for x in obj], from_fmap=True, fallback_to_init=fallback_to_init)
''',
indent=1,
newline=True,
),
maybe_bind_lru_cache=pycondition(
(3, 2),
if_lt='''
Expand Down
15 changes: 7 additions & 8 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE}
jax_numpy_modules = {jax_numpy_modules}
tee_type = type(itertools.tee((), 1)[0])
reiterables = abc.Sequence, abc.Mapping, abc.Set
fmappables = list, tuple, dict, set, frozenset
fmappables = list, tuple, dict, set, frozenset, bytes, bytearray
abc.Sequence.register(collections.deque)
Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray}
Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray}
@_coconut.functools.wraps(_coconut.functools.partial)
def _coconut_partial(_coconut_func, *args, **kwargs):
partial_func = _coconut.functools.partial(_coconut_func, *args, **kwargs)
Expand Down Expand Up @@ -1583,8 +1583,6 @@ def _coconut_base_makedata(data_type, args, from_fmap=False, fallback_to_init=Fa
return args
if _coconut.issubclass(data_type, _coconut.str):
return "".join(args)
if _coconut.issubclass(data_type, _coconut.bytes):
return b"".join(args)
if fallback_to_init or _coconut.issubclass(data_type, _coconut.fmappables):
return data_type(args)
if from_fmap:
Expand All @@ -1602,7 +1600,7 @@ def fmap(func, obj, **kwargs):

Supports:
* Coconut data types
* `str`, `dict`, `list`, `tuple`, `set`, `frozenset`
* `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, `bytes`, `bytearray`
* `dict` (maps over .items())
* asynchronous iterables
* numpy arrays (uses np.vectorize)
Expand Down Expand Up @@ -1644,10 +1642,11 @@ def fmap(func, obj, **kwargs):
else:
if aiter is not _coconut.NotImplemented:
return _coconut_amap(func, aiter)
if starmap_over_mappings:
return _coconut_base_makedata(obj.__class__, {_coconut_}starmap(func, obj.items()) if _coconut.isinstance(obj, _coconut.abc.Mapping) else {_coconut_}map(func, obj), from_fmap=True, fallback_to_init=fallback_to_init)
{handle_bytes} if _coconut.isinstance(obj, _coconut.abc.Mapping):
mapped_obj = ({_coconut_}starmap if starmap_over_mappings else {_coconut_}map)(func, obj.items())
else:
return _coconut_base_makedata(obj.__class__, {_coconut_}map(func, obj.items() if _coconut.isinstance(obj, _coconut.abc.Mapping) else obj), from_fmap=True, fallback_to_init=fallback_to_init)
mapped_obj = _coconut_map(func, obj)
return _coconut_base_makedata(obj.__class__, mapped_obj, from_fmap=True, fallback_to_init=fallback_to_init)
def _coconut_memoize_helper(maxsize=None, typed=False):
return maxsize, typed
def memoize(*args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.0.4"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 18
DEVELOP = 19
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
Expand Down
3 changes: 3 additions & 0 deletions coconut/tests/src/cocotest/agnostic/primary_2.coco
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ def primary_test_2() -> bool:
assert bytes(10) == b"\x00" * 10
assert bytes([35, 40]) == b'#('
assert bytes(b"abc") == b"abc" == bytes("abc", "utf-8")
assert b"Abc" |> fmap$(.|32) == b"abc"
assert bytearray(b"Abc") |> fmap$(.|32) == bytearray(b"abc")
assert (bytearray(b"Abc") |> fmap$(.|32)) `isinstance` bytearray

with process_map.multiple_sequential_calls(): # type: ignore
assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore
Expand Down

0 comments on commit b697bc8

Please sign in to comment.