Skip to content

Commit

Permalink
Add or_else, result_or_else to Expected
Browse files Browse the repository at this point in the history
Refs   #691.
  • Loading branch information
evhub committed Dec 25, 2022
1 parent 5bd0e63 commit e7a2316
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 29 deletions.
14 changes: 10 additions & 4 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2769,8 +2769,8 @@ Coconut's `Expected` built-in is a Coconut [`data` type](#data) that represents

`Expected` is effectively equivalent to the following:
```coconut
data Expected[T](result: T?, error: Exception?):
def __new__(cls, result: T?=None, error: Exception?=None) -> Expected[T]:
data Expected[T](result: T?, error: BaseException?):
def __new__(cls, result: T?=None, error: BaseException?=None) -> Expected[T]:
if result is not None and error is not None:
raise TypeError("Expected cannot have both a result and an error")
return makedata(cls, result, error)
Expand All @@ -2787,11 +2787,17 @@ data Expected[T](result: T?, error: Exception?):
if not self:
return self
if not self.result `isinstance` Expected:
raise TypeError("Expected.join() requires an Expected[Expected[T]]")
raise TypeError("Expected.join() requires an Expected[Expected[_]]")
return self.result
def result_or[U](self, default: U) -> Expected(T | U):
def or_else[U](self, func: BaseException -> Expected[U]) -> Expected[T | U]:
"""Return self if no error, otherwise return the result of evaluating func on the error."""
return self if self else func(self.error)
def result_or[U](self, default: U) -> T | U:
"""Return the result if it exists, otherwise return the default."""
return self.result if self else default
def result_or_else[U](self, func: BaseException -> U) -> T | U:
"""Return the result if it exists, otherwise return the result of evaluating func on the error."""
return self.result if self else func(self.error)
def unwrap(self) -> T:
"""Unwrap the result or raise the error."""
if not self:
Expand Down
16 changes: 9 additions & 7 deletions __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ _coconut_tail_call = of = call
@_dataclass(frozen=True, slots=True)
class Expected(_t.Generic[_T], _t.Tuple):
result: _t.Optional[_T]
error: _t.Optional[Exception]
error: _t.Optional[BaseException]
@_t.overload
def __new__(
cls,
Expand All @@ -278,28 +278,30 @@ class Expected(_t.Generic[_T], _t.Tuple):
cls,
result: None = None,
*,
error: Exception,
error: BaseException,
) -> Expected[_t.Any]: ...
@_t.overload
def __new__(
cls,
result: None,
error: Exception,
error: BaseException,
) -> Expected[_t.Any]: ...
def __init__(
self,
result: _t.Optional[_T] = None,
error: _t.Optional[Exception] = None,
error: _t.Optional[BaseException] = None,
): ...
def __fmap__(self, func: _t.Callable[[_T], _U]) -> Expected[_U]: ...
def __iter__(self) -> _t.Iterator[_T | Exception | None]: ...
def __iter__(self) -> _t.Iterator[_T | BaseException | None]: ...
@_t.overload
def __getitem__(self, index: _SupportsIndex) -> _T | Exception | None: ...
def __getitem__(self, index: _SupportsIndex) -> _T | BaseException | None: ...
@_t.overload
def __getitem__(self, index: slice) -> _t.Tuple[_T | Exception | None, ...]: ...
def __getitem__(self, index: slice) -> _t.Tuple[_T | BaseException | None, ...]: ...
def and_then(self, func: _t.Callable[[_T], Expected[_U]]) -> Expected[_U]: ...
def join(self: Expected[Expected[_T]]) -> Expected[_T]: ...
def or_else(self, func: _t.Callable[[BaseException], Expected[_U]]) -> Expected[_T | _U]: ...
def result_or(self, default: _U) -> _T | _U: ...
def result_or_else(self, func: _t.Callable[[BaseException], _U]) -> _T | _U: ...
def unwrap(self) -> _T: ...

_coconut_Expected = Expected
Expand Down
39 changes: 28 additions & 11 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -1521,8 +1521,8 @@ class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){
that may or may not be an error, similar to Haskell's Either.

Effectively equivalent to:
data Expected[T](result: T?, error: Exception?):
def __new__(cls, result: T?=None, error: Exception?=None) -> Expected[T]:
data Expected[T](result: T?, error: BaseException?):
def __new__(cls, result: T?=None, error: BaseException?=None) -> Expected[T]:
if result is not None and error is not None:
raise TypeError("Expected cannot have both a result and an error")
return makedata(cls, result, error)
Expand All @@ -1539,11 +1539,17 @@ class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){
if not self:
return self
if not self.result `isinstance` Expected:
raise TypeError("Expected.join() requires an Expected[Expected[T]]")
raise TypeError("Expected.join() requires an Expected[Expected[_]]")
return self.result
def result_or[U](self, default: U) -> Expected(T | U):
def or_else[U](self, func: BaseException -> Expected[U]) -> Expected[T | U]:
"""Return self if no error, otherwise return the result of evaluating func on the error."""
return self if self else func(self.error)
def result_or[U](self, default: U) -> T | U:
"""Return the result if it exists, otherwise return the default."""
return self.result if self else default
def result_or_else[U](self, func: BaseException -> U) -> T | U:
"""Return the result if it exists, otherwise return the result of evaluating func on the error."""
return self.result if self else func(self.error)
def unwrap(self) -> T:
"""Unwrap the result or raise the error."""
if not self:
Expand All @@ -1569,24 +1575,35 @@ class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){
if result is _coconut_sentinel:
result = None
return _coconut.tuple.__new__(cls, (result, error))
def __fmap__(self, func):
return self if not self else self.__class__(func(self.result))
def __bool__(self):
return self.error is None
def __fmap__(self, func):
return self if not self else self.__class__(func(self.result))
def and_then(self, func):
"""Maps a T -> Expected[U] over an Expected[T] to produce an Expected[U].
Implements a monadic bind. Equivalent to fmap ..> .join()."""
return self.__fmap__(func).join()
def join(self):
"""Monadic join. Converts Expected[Expected[T]] to Expected[T]."""
if not self:
return self
if not _coconut.isinstance(self.result, _coconut_Expected):
raise _coconut.TypeError("Expected.join() requires an Expected[Expected[T]]")
raise _coconut.TypeError("Expected.join() requires an Expected[Expected[_]]")
return self.result
def and_then(self, func):
"""Maps a T -> Expected[U] over an Expected[T] to produce an Expected[U].
Implements a monadic bind. Equivalent to fmap ..> .join()."""
return self.__fmap__(func).join()
def or_else(self, func):
"""Return self if no error, otherwise return the result of evaluating func on the error."""
if self:
return self
got = func(self.error)
if not _coconut.isinstance(got, _coconut_Expected):
raise _coconut.TypeError("Expected.or_else() requires a function that returns an Expected")
return got
def result_or(self, default):
"""Return the result if it exists, otherwise return the default."""
return self.result if self else default
def result_or_else(self, func):
"""Return the result if it exists, otherwise return the result of evaluating func on the error."""
return self.result if self else func(self.error)
def unwrap(self):
"""Unwrap the result or raise the error."""
if not self:
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "2.1.1"
VERSION_NAME = "The Spanish Inquisition"
# False for release, int >= 1 for develop
DEVELOP = 42
DEVELOP = 43
ALPHA = False # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
Expand Down
18 changes: 12 additions & 6 deletions coconut/tests/src/cocotest/agnostic/main.coco
Original file line number Diff line number Diff line change
Expand Up @@ -1267,8 +1267,9 @@ def main_test() -> bool:
ys = (_ for _ in range(2)) :: (_ for _ in range(2))
assert ys |> list == [0, 1, 0, 1]
assert ys |> list == []
assert Expected(10) |> fmap$(.+1) == Expected(11)

some_err = ValueError()
assert Expected(10) |> fmap$(.+1) == Expected(11)
assert Expected(error=some_err) |> fmap$(.+1) == Expected(error=some_err)
res, err = Expected(10)
assert (res, err) == (10, None)
Expand All @@ -1284,6 +1285,16 @@ def main_test() -> bool:
assert Expected(error=some_err).and_then(safe_call$(.*2)) == Expected(error=some_err)
assert Expected(Expected(10)).join() == Expected(10)
assert Expected(error=some_err).join() == Expected(error=some_err)
assert_raises(Expected, TypeError)
assert Expected(10).result_or(0) == 10 == Expected(error=TypeError()).result_or(10)
assert Expected(10).result_or_else(const 0) == 10 == Expected(error=TypeError()).result_or_else(const 10)
assert Expected(error=some_err).result_or_else(ident) is some_err
assert Expected(None)
assert Expected(10).unwrap() == 10
assert_raises(Expected(error=TypeError()).unwrap, TypeError)
assert_raises(Expected(error=KeyboardInterrupt()).unwrap, KeyboardInterrupt)
assert Expected(10).or_else(const <| Expected(20)) == Expected(10) == Expected(error=TypeError()).or_else(const <| Expected(10))

recit = ([1,2,3] :: recit) |> map$(.+1)
assert tee(recit)
rawit = (_ for _ in (0, 1))
Expand Down Expand Up @@ -1399,11 +1410,6 @@ def main_test() -> bool:
hardref = map((.+1), [1,2,3])
assert weakref.ref(hardref)() |> list == [2, 3, 4]
assert parallel_map(ident, [MatchError]) |> list == [MatchError]
assert_raises(Expected, TypeError)
assert Expected(10).result_or(0) == 10 == Expected(error=TypeError()).result_or(10)
assert Expected(None)
assert Expected(10).unwrap() == 10
assert_raises(Expected(error=TypeError()).unwrap, TypeError)
return True

def test_asyncio() -> bool:
Expand Down

0 comments on commit e7a2316

Please sign in to comment.