Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements a Numpy-based ArrayContext #93

Closed
wants to merge 8 commits into from
5 changes: 5 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@

from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoPyOpenCLArrayContext
from .impl.numpy import NumpyArrayContext

from .pytest import (
PytestArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
pytest_generate_tests_for_array_contexts,
pytest_generate_tests_for_pyopencl_array_context)
Expand Down Expand Up @@ -100,8 +102,11 @@

"PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",

"NumpyArrayContext",

"make_loopy_program",

"PytestArrayContextFactory",
"PytestPyOpenCLArrayContextFactory",
"pytest_generate_tests_for_array_contexts",
"pytest_generate_tests_for_pyopencl_array_context"
Expand Down
6 changes: 5 additions & 1 deletion arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,11 @@ def is_array_container(ary: Any) -> bool:
"cheaper option, see is_array_container_type.",
DeprecationWarning, stacklevel=2)
return (serialize_container.dispatch(ary.__class__)
is not serialize_container.__wrapped__) # type:ignore[attr-defined]
is not serialize_container.__wrapped__ # type:ignore[attr-defined]
# numpy values with scalar elements aren't array containers
and not (isinstance(ary, np.ndarray)
and ary.dtype.kind != "O")
)


@singledispatch
Expand Down
32 changes: 21 additions & 11 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,15 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
if rel_comparison is None:
raise TypeError("rel_comparison must be specified")

if bcast_numpy_array:
from warnings import warn
warn("'bcast_numpy_array=True' is deprecated and will be unsupported"
" from December 2021", DeprecationWarning, stacklevel=2)

if _bcast_actx_array_type:
raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'"
" cannot be both set.")

if rel_comparison and eq_comparison is None:
eq_comparison = True

Expand All @@ -216,7 +225,7 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
raise TypeError("bcast_obj_array must be set if bcast_numpy_array is")

if _bcast_actx_array_type is None:
if _cls_has_array_context_attr:
if _cls_has_array_context_attr and (not bcast_numpy_array):
_bcast_actx_array_type = bcast_number
else:
_bcast_actx_array_type = False
Expand Down Expand Up @@ -395,16 +404,17 @@ def {fname}(arg1):
bcast_actx_ary_types = ()

gen(f"""
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg2,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_same_cls_init_args})
if {numpy_pred("arg2")}:
result = np.empty_like(arg2, dtype=object)
for i in np.ndindex(arg2.shape):
result[i] = {op_str.format("arg1", "arg2[i]")}
return result

if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg2,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_same_cls_init_args})
return NotImplemented
""")
gen(f"cls.__{dunder_name}__ = {fname}")
Expand Down Expand Up @@ -436,16 +446,16 @@ def {fname}(arg1):
def {fname}(arg2, arg1):
# assert other.__cls__ is not cls

if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg1,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_init_args})
if {numpy_pred("arg1")}:
result = np.empty_like(arg1, dtype=object)
for i in np.ndindex(arg1.shape):
result[i] = {op_str.format("arg1[i]", "arg2")}
return result
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg1,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_init_args})
return NotImplemented

cls.__r{dunder_name}__ = {fname}""")
Expand Down
124 changes: 124 additions & 0 deletions arraycontext/impl/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
.. currentmodule:: arraycontext


A mod :`numpy`-based array context.

.. autoclass:: NumpyArrayContext
"""
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from arraycontext.context import ArrayContext
import numpy as np
import loopy as lp
from typing import Union, Sequence, Dict
from pytools.tag import Tag


class NumpyArrayContext(ArrayContext):
"""
A :class:`ArrayContext` that uses :mod:`numpy.ndarray` to represent arrays


.. automethod:: __init__
"""
def __init__(self):
super().__init__()
self._loopy_transform_cache: \
Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {}

self.array_types = (np.ndarray,)

def _get_fake_numpy_namespace(self):
from .fake_numpy import NumpyFakeNumpyNamespace
return NumpyFakeNumpyNamespace(self)

# {{{ ArrayContext interface

def clone(self):
return type(self)()

def empty(self, shape, dtype):
return np.empty(shape, dtype=dtype)

def zeros(self, shape, dtype):
return np.zeros(shape, dtype)

def from_numpy(self, np_array: np.ndarray):
# Uh oh...
return np_array

def to_numpy(self, array):
# Uh oh...
return array

def call_loopy(self, t_unit, **kwargs):
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
try:
t_unit = self._loopy_transform_cache[t_unit]
except KeyError:
orig_t_unit = t_unit
t_unit = self.transform_loopy_program(t_unit)
self._loopy_transform_cache[orig_t_unit] = t_unit
del orig_t_unit

_, result = t_unit(**kwargs)

return result

def freeze(self, array):
return array

def thaw(self, array):
return array

# }}}

def transform_loopy_program(self, t_unit):
raise ValueError("NumpyArrayContext does not implement "
"transform_loopy_program. Sub-classes are supposed "
"to implement it.")

def tag(self, tags: Union[Sequence[Tag], Tag], array):
# Numpy doesn't support tagging
return array

def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array):
return array

def einsum(self, spec, *args, arg_names=None, tagged=()):
return np.einsum(spec, *args)

@property
def permits_inplace_modification(self):
return True

@property
def supports_nonscalar_broadcasting(self):
return True

@property
def permits_advanced_indexing(self):
return True
142 changes: 142 additions & 0 deletions arraycontext/impl/numpy/fake_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial, reduce

from arraycontext.fake_numpy import (
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
)
from arraycontext.container import is_array_container
from arraycontext.container.traversal import (
rec_map_array_container,
rec_multimap_array_container,
multimap_reduce_array_container,
rec_map_reduce_array_container,
rec_multimap_reduce_array_container,
)
import numpy as np


class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
# Everything is implemented in the base class for now.
pass


_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan",
"sinh", "cosh", "tanh", "exp", "log", "log10", "isnan",
"sqrt", "exp", "concatenate", "reshape", "transpose",
"ones_like", "maximum", "minimum", "where", "conj", "arctan2",
}


class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace):
"""
A :mod:`numpy` mimic for :class:`NumpyArrayContext`.
"""
def _get_fake_numpy_linalg_namespace(self):
return NumpyFakeNumpyLinalgNamespace(self._array_context)

def __getattr__(self, name):

if name in _NUMPY_UFUNCS:
from functools import partial
return partial(rec_multimap_array_container,
getattr(np, name))

return super().__getattr__(name)

def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(sum, partial(np.sum,
axis=axis,
dtype=dtype),
a)

def min(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, np.minimum), partial(np.amin, axis=axis), a)

def max(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, np.maximum), partial(np.amax, axis=axis), a)

def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: np.stack(arrays=args, axis=axis),
*arrays)

def broadcast_to(self, array, shape):
return rec_map_array_container(partial(np.broadcast_to, shape=shape), array)

# {{{ relational operators

def equal(self, x, y):
return rec_multimap_array_container(np.equal, x, y)

def not_equal(self, x, y):
return rec_multimap_array_container(np.not_equal, x, y)

def greater(self, x, y):
return rec_multimap_array_container(np.greater, x, y)

def greater_equal(self, x, y):
return rec_multimap_array_container(np.greater_equal, x, y)

def less(self, x, y):
return rec_multimap_array_container(np.less, x, y)

def less_equal(self, x, y):
return rec_multimap_array_container(np.less_equal, x, y)

# }}}

def ravel(self, a, order="C"):
return rec_map_array_container(partial(np.ravel, order=order), a)

def vdot(self, x, y, dtype=None):
if dtype is not None:
raise NotImplementedError("only 'dtype=None' supported.")

return rec_multimap_reduce_array_container(sum, np.vdot, x, y)

def any(self, a):
return rec_map_reduce_array_container(partial(reduce, np.logical_or),
lambda subary: np.any(subary), a)

def all(self, a):
return rec_map_reduce_array_container(partial(reduce, np.logical_and),
lambda subary: np.all(subary), a)

def array_equal(self, a, b):
if type(a) != type(b):
return False
elif not is_array_container(a):
if a.shape != b.shape:
return False
else:
return np.all(np.equal(a, b))
else:
return multimap_reduce_array_container(partial(reduce,
np.logical_and),
self.array_equal, a, b)

# vim: fdm=marker
3 changes: 3 additions & 0 deletions arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def to_numpy(self, array):
if np.isscalar(array):
return array

if not isinstance(array, self.array_types):
raise TypeError(f"to_numpy called on {type(array)}.")

return array.get(queue=self.queue)

def call_loopy(self, t_unit, **kwargs):
Expand Down
Loading