Merge pull request #253 from scipp/opt-depends_on-resolution
Optimise depends_on chain resolution by using h5py
jl-wynen authored Nov 25, 2024
2 parents 349c8df + 2138887 commit ec20200
Showing 5 changed files with 86 additions and 42 deletions.
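For context: a NeXus depends_on chain is a linked list of transformations. A depends_on dataset (or attribute) names the first transformation; each transformation then names the next one in its own depends_on attribute, and the chain terminates at the literal '.'. This PR follows that chain through raw h5py handles instead of wrapping every intermediate group in an snx.Group. The sketch below is not code from the PR; it only illustrates the traversal with plain h5py, and the file name and paths are hypothetical.

import posixpath

import h5py


def _as_str(value) -> str:
    # h5py may return bytes for string data, depending on how it was written.
    return value.decode() if isinstance(value, bytes) else str(value)


def walk_depends_on(file: h5py.File, depends_on_path: str) -> list[str]:
    """Return the absolute paths of all transformations in a depends_on chain."""
    chain = []
    base = posixpath.dirname(depends_on_path)
    target = _as_str(file[depends_on_path][()])  # value of the depends_on dataset
    while target != '.':
        # Relative targets are resolved against the group that holds the link.
        path = target if target.startswith('/') else posixpath.join(base, target)
        chain.append(path)
        base = posixpath.dirname(path)
        target = _as_str(file[path].attrs['depends_on'])
    return chain


with h5py.File('experiment.nxs', 'r') as f:  # hypothetical file
    print(walk_depends_on(f, '/entry/instrument/detector/depends_on'))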
45 changes: 11 additions & 34 deletions src/scippnexus/base.py
@@ -41,45 +41,12 @@ def is_dataset(obj: H5Group | H5Dataset) -> bool:
return hasattr(obj, 'shape')


_scipp_dtype = {
np.dtype('int8'): sc.DType.int32,
np.dtype('int16'): sc.DType.int32,
np.dtype('uint8'): sc.DType.int32,
np.dtype('uint16'): sc.DType.int32,
np.dtype('uint32'): sc.DType.int32,
np.dtype('uint64'): sc.DType.int64,
np.dtype('int32'): sc.DType.int32,
np.dtype('int64'): sc.DType.int64,
np.dtype('float32'): sc.DType.float32,
np.dtype('float64'): sc.DType.float64,
np.dtype('bool'): sc.DType.bool,
}


def _dtype_fromdataset(dataset: H5Dataset) -> sc.DType:
return _scipp_dtype.get(dataset.dtype, sc.DType.string)


def _squeezed_field_sizes(dataset: H5Dataset) -> dict[str, int]:
if (shape := dataset.shape) == (1,):
return {}
return {f'dim_{i}': size for i, size in enumerate(shape)}


class NXobject:
def _init_field(self, field: Field):
if field.sizes is None:
field.sizes = _squeezed_field_sizes(field.dataset)
field.dtype = _dtype_fromdataset(field.dataset)

def __init__(self, attrs: dict[str, Any], children: dict[str, Field | Group]):
"""Subclasses should call this in their __init__ method, or ensure that they
initialize the fields in `children` with the correct sizes and dtypes."""
self._attrs = attrs
self._children = children
for field in children.values():
if isinstance(field, Field):
self._init_field(field)

@property
def unit(self) -> None | sc.Unit:
@@ -222,7 +189,7 @@ def nx_class(self) -> type | None:
return NXroot

@cached_property
def attrs(self) -> dict[str, Any]:
def attrs(self) -> MappingProxyType[str, Any]:
"""The attributes of the group.
Cannot be used for writing attributes, since they are cached for performance."""
@@ -479,6 +446,16 @@ def dims(self) -> tuple[str, ...]:
def shape(self) -> tuple[int, ...]:
return tuple(self.sizes.values())

@property
def definitions(self) -> MappingProxyType[str, str | type] | None:
return (
None if self._definitions is None else MappingProxyType(self._definitions)
)

@property
def underlying(self) -> H5Group:
return self._group


def _create_field_params_numpy(data: np.ndarray):
return data, None, {}
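The attrs annotation above changes from a plain dict to MappingProxyType, i.e. a read-only view: the attributes are cached, so handing out a mutable dict would invite callers to modify values that are never written back to the file. The new underlying property exposes the raw h5py group, which the depends_on resolution in nxtransformations.py uses to reach the open file handle. A minimal sketch of the MappingProxyType behaviour (standard library only, independent of scippnexus):

from types import MappingProxyType

cached = {'NX_class': 'NXdetector'}
view = MappingProxyType(cached)

print(view['NX_class'])  # reads behave like a normal dict
try:
    view['NX_class'] = 'NXmonitor'  # writes are rejected
except TypeError as err:
    print(err)  # 'mappingproxy' object does not support item assignment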
31 changes: 31 additions & 0 deletions src/scippnexus/field.py
@@ -81,6 +81,31 @@ def _as_datetime(obj: Any):
return None


_scipp_dtype = {
np.dtype('int8'): sc.DType.int32,
np.dtype('int16'): sc.DType.int32,
np.dtype('uint8'): sc.DType.int32,
np.dtype('uint16'): sc.DType.int32,
np.dtype('uint32'): sc.DType.int32,
np.dtype('uint64'): sc.DType.int64,
np.dtype('int32'): sc.DType.int32,
np.dtype('int64'): sc.DType.int64,
np.dtype('float32'): sc.DType.float32,
np.dtype('float64'): sc.DType.float64,
np.dtype('bool'): sc.DType.bool,
}


def _dtype_fromdataset(dataset: H5Dataset) -> sc.DType:
return _scipp_dtype.get(dataset.dtype, sc.DType.string)


def _squeezed_field_sizes(dataset: H5Dataset) -> dict[str, int]:
if (shape := dataset.shape) == (1,):
return {}
return {f'dim_{i}': size for i, size in enumerate(shape)}


@dataclass
class Field:
"""NeXus field.
@@ -93,6 +118,12 @@ class Field:
dtype: sc.DType | None = None
errors: H5Dataset | None = None

def __post_init__(self) -> None:
if self.sizes is None:
self.sizes = _squeezed_field_sizes(self.dataset)
if self.dtype is None:
self.dtype = _dtype_fromdataset(self.dataset)

@cached_property
def attrs(self) -> dict[str, Any]:
"""The attributes of the dataset.
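With the __post_init__ above, a Field infers its sizes and dtype from the underlying HDF5 dataset as soon as it is constructed, so call sites such as NXobject.__init__ and the new depends_on resolution no longer need a separate _init_field step. A toy sketch of the pattern (not the library's class), assuming only the standard library:

from dataclasses import dataclass


@dataclass
class Record:
    data: list[float]
    sizes: dict[str, int] | None = None

    def __post_init__(self) -> None:
        # Fill in defaults that depend on another attribute, mirroring how
        # Field derives sizes and dtype from its HDF5 dataset.
        if self.sizes is None:
            self.sizes = {'dim_0': len(self.data)}


print(Record([1.0, 2.0, 3.0]).sizes)  # {'dim_0': 3}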
1 change: 0 additions & 1 deletion src/scippnexus/nxdata.py
@@ -481,7 +481,6 @@ def __init__(self, attrs: dict[str, Any], children: dict[str, Field | Group]):
for k in list(children):
if k.startswith(name):
field = children.pop(k)
self._init_field(field)
field.sizes = {
'time' if i == 0 else f'dim_{i}': size
for i, size in enumerate(field.dataset.shape)
49 changes: 43 additions & 6 deletions src/scippnexus/nxtransformations.py
@@ -35,15 +35,24 @@

from __future__ import annotations

import posixpath
import warnings
from collections.abc import Mapping
from dataclasses import dataclass, field, replace
from typing import Literal

import h5py
import numpy as np
import scipp as sc
from scipp.scipy import interpolate

from .base import Group, NexusStructureError, NXobject, base_definitions_dict
from .base import (
Group,
NexusStructureError,
NXobject,
base_definitions_dict,
is_dataset,
)
from .field import DependsOn, Field


@@ -266,17 +275,45 @@ def compute(self) -> sc.Variable | sc.DataArray:
return transform


def _locate_depends_on_target(
file: h5py.File,
depends_on: DependsOn,
definitions: Mapping[str, type] | None,
) -> tuple[Field | Group, str]:
"""Find the target of a depends_on link.
The returned object is equivalent to calling ``parent[depends_on]``
in the context of transformations.
This function does not work in general because it does not process any attributes
of parent groups, which would be required to fully load some groups.
"""
target_path = depends_on.absolute_path()
target = file[target_path]

if is_dataset(target):
res = Field(
target,
parent=Group(target.parent, definitions=definitions),
)
else:
res = Group(target, definitions=definitions)
return res, posixpath.dirname(target_path)


def parse_depends_on_chain(
parent: Field | Group, depends_on: DependsOn
) -> TransformationChain | None:
"""Follow a depends_on chain and return the transformations."""
chain = TransformationChain(depends_on.parent, depends_on.value)
depends_on = depends_on.value
# Use raw h5py objects to follow the chain because that avoids constructing
# expensive intermediate snx.Group objects.
file = parent.underlying.file
try:
while depends_on != '.':
transform = parent[depends_on]
parent = transform.parent
depends_on = transform.attrs['depends_on']
while depends_on.value != '.':
transform, base = _locate_depends_on_target(
file, depends_on, parent.definitions
)
depends_on = DependsOn(parent=base, value=transform.attrs['depends_on'])
chain.transformations[transform.name] = transform[()]
except KeyError as e:
warnings.warn(
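From a user's point of view nothing changes except speed: depends_on chains are still resolved while loading, only now through the raw h5py file handle instead of intermediate snx.Group objects. A hedged usage sketch (the file name and entry layout are hypothetical; compute_positions is the existing scippnexus helper that applies the resolved transformation chains):

import scippnexus as snx

with snx.File('experiment.nxs') as f:  # hypothetical file
    detector = f['entry']['instrument']['detector'][()]

# The loaded data group carries the resolved transformations; compute_positions
# applies them to derive absolute pixel positions.
print(snx.compute_positions(detector))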
2 changes: 1 addition & 1 deletion src/scippnexus/typing.py
@@ -17,7 +17,7 @@ def name(self) -> str:
"""Name of dataset or group"""

@property
def file(self) -> list[int]:
def file(self) -> Any:
"""File of dataset or group"""

@property
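The typing.py fix corrects the protocol's file property, whose declared return type (list[int]) did not match what h5py actually returns; Any keeps the protocol agnostic about the backend's file object. A hedged sketch of such a structural protocol (the names here are illustrative, not the full H5Dataset/H5Group interface):

from typing import Any, Protocol


class HasFile(Protocol):
    @property
    def file(self) -> Any:
        """File of dataset or group"""


def filename_of(obj: HasFile) -> str:
    # h5py File objects expose .filename; typing file as Any allows this
    # without committing the protocol to a specific backend class.
    return obj.file.filename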
