Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor evaluating string fields #294

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 44 additions & 20 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,22 @@
"""

from collections.abc import Mapping, Sequence
from dataclasses import Field, fields, is_dataclass
from typing import Union, get_args, get_origin
from dataclasses import fields, is_dataclass
from typing import NamedTuple, Union, get_args, get_origin

from arraycontext.container import is_array_container_type


# {{{ dataclass containers

class _Field(NamedTuple):
inducer marked this conversation as resolved.
Show resolved Hide resolved
"""Small lookalike for :class:`dataclasses.Field`."""

init: bool
name: str
type: type


def is_array_type(tp: type) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)
Expand Down Expand Up @@ -73,7 +81,9 @@ def dataclass_array_container(cls: type) -> type:

assert is_dataclass(cls)

def is_array_field(f: Field, field_type: type) -> bool:
def is_array_field(f: _Field) -> bool:
field_type = f.type

# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
Expand All @@ -96,10 +106,8 @@ def is_array_field(f: Field, field_type: type) -> bool:
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")

if isinstance(field_type, str):
raise TypeError(
f"String annotation on field '{f.name}' not supported. "
"(this may be due to 'from __future__ import annotations')")
# NOTE: this should never happen due to using `inspect.get_annotations`
assert not isinstance(field_type, str)

if __debug__:
if not f.init:
Expand Down Expand Up @@ -127,36 +135,52 @@ def is_array_field(f: Field, field_type: type) -> bool:

return is_array_type(field_type)

from pytools import partition

array_fields = _get_annotated_fields(cls)
array_fields, non_array_fields = partition(is_array_field, array_fields)

if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")

return _inject_dataclass_serialization(cls, array_fields, non_array_fields)


def _get_annotated_fields(cls: type) -> Sequence[_Field]:
inducer marked this conversation as resolved.
Show resolved Hide resolved
"""Get a list of fields in the class *cls* with evaluated types.

If any of the fields in *cls* have type annotations that are strings, e.g.
from using ``from __future__ import annotations``, this function evaluates
them using :func:`inspect.get_annotations`. Note that this requires the class
to live in a module that is importable.

:return: a list of fields.
"""

from inspect import get_annotations

array_fields: list[Field] = []
non_array_fields: list[Field] = []
result = []
cls_ann: Mapping[str, type] | None = None
for field in fields(cls):
field_type_or_str = field.type
if isinstance(field_type_or_str, str):
if cls_ann is None:
cls_ann = get_annotations(cls, eval_str=True)

field_type = cls_ann[field.name]
else:
field_type = field_type_or_str

if is_array_field(field, field_type):
array_fields.append(field)
else:
non_array_fields.append(field)

if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")
result.append(_Field(init=field.init, name=field.name, type=field_type))

return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
return result


def _inject_dataclass_serialization(
cls: type,
array_fields: Sequence[Field],
non_array_fields: Sequence[Field]) -> type:
array_fields: Sequence[_Field],
non_array_fields: Sequence[_Field]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.

Expand Down
Loading