Skip to content

Commit

Permalink
Init field within Field class
Browse files Browse the repository at this point in the history
  • Loading branch information
jl-wynen committed Nov 20, 2024
1 parent 44ebeb6 commit f72aff4
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 38 deletions.
33 changes: 0 additions & 33 deletions src/scippnexus/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,45 +41,12 @@ def is_dataset(obj: H5Group | H5Dataset) -> bool:
return hasattr(obj, 'shape')


_scipp_dtype = {
np.dtype('int8'): sc.DType.int32,
np.dtype('int16'): sc.DType.int32,
np.dtype('uint8'): sc.DType.int32,
np.dtype('uint16'): sc.DType.int32,
np.dtype('uint32'): sc.DType.int32,
np.dtype('uint64'): sc.DType.int64,
np.dtype('int32'): sc.DType.int32,
np.dtype('int64'): sc.DType.int64,
np.dtype('float32'): sc.DType.float32,
np.dtype('float64'): sc.DType.float64,
np.dtype('bool'): sc.DType.bool,
}


def _dtype_fromdataset(dataset: H5Dataset) -> sc.DType:
return _scipp_dtype.get(dataset.dtype, sc.DType.string)


def _squeezed_field_sizes(dataset: H5Dataset) -> dict[str, int]:
if (shape := dataset.shape) == (1,):
return {}
return {f'dim_{i}': size for i, size in enumerate(shape)}


class NXobject:
def _init_field(self, field: Field):
if field.sizes is None:
field.sizes = _squeezed_field_sizes(field.dataset)
field.dtype = _dtype_fromdataset(field.dataset)

def __init__(self, attrs: dict[str, Any], children: dict[str, Field | Group]):
"""Subclasses should call this in their __init__ method, or ensure that they
initialize the fields in `children` with the correct sizes and dtypes."""
self._attrs = attrs
self._children = children
for field in children.values():
if isinstance(field, Field):
self._init_field(field)

@property
def unit(self) -> None | sc.Unit:
Expand Down
31 changes: 31 additions & 0 deletions src/scippnexus/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,31 @@ def _as_datetime(obj: Any):
return None


_scipp_dtype = {
np.dtype('int8'): sc.DType.int32,
np.dtype('int16'): sc.DType.int32,
np.dtype('uint8'): sc.DType.int32,
np.dtype('uint16'): sc.DType.int32,
np.dtype('uint32'): sc.DType.int32,
np.dtype('uint64'): sc.DType.int64,
np.dtype('int32'): sc.DType.int32,
np.dtype('int64'): sc.DType.int64,
np.dtype('float32'): sc.DType.float32,
np.dtype('float64'): sc.DType.float64,
np.dtype('bool'): sc.DType.bool,
}


def _dtype_fromdataset(dataset: H5Dataset) -> sc.DType:
return _scipp_dtype.get(dataset.dtype, sc.DType.string)


def _squeezed_field_sizes(dataset: H5Dataset) -> dict[str, int]:
if (shape := dataset.shape) == (1,):
return {}
return {f'dim_{i}': size for i, size in enumerate(shape)}


@dataclass
class Field:
"""NeXus field.
Expand All @@ -93,6 +118,12 @@ class Field:
dtype: sc.DType | None = None
errors: H5Dataset | None = None

def __post_init__(self) -> None:
if self.sizes is None:
self.sizes = _squeezed_field_sizes(self.dataset)
if self.dtype is None:
self.dtype = _dtype_fromdataset(self.dataset)

@cached_property
def attrs(self) -> dict[str, Any]:
"""The attributes of the dataset.
Expand Down
1 change: 0 additions & 1 deletion src/scippnexus/nxdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,6 @@ def __init__(self, attrs: dict[str, Any], children: dict[str, Field | Group]):
for k in list(children):
if k.startswith(name):
field = children.pop(k)
self._init_field(field)
field.sizes = {
'time' if i == 0 else f'dim_{i}': size
for i, size in enumerate(field.dataset.shape)
Expand Down
4 changes: 0 additions & 4 deletions src/scippnexus/nxtransformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,9 @@ def _locate_depends_on_target(
target = file[target_path.as_posix()]

if is_dataset(target):
from .base import _dtype_fromdataset, _squeezed_field_sizes

res = Field(
target,
parent=Group(target.parent, definitions=definitions),
sizes=_squeezed_field_sizes(target),
dtype=_dtype_fromdataset(target),
)
else:
res = Group(target, definitions=definitions)
Expand Down

0 comments on commit f72aff4

Please sign in to comment.