From b28d5296ef96a0673c425826cdbdaa3f35dba68c Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 11 May 2024 15:36:38 +0000 Subject: [PATCH 01/19] Update [ghstack-poisoned] --- src/mrpro/data/_Data.py | 46 +--------- src/mrpro/data/_DcfData.py | 3 +- src/mrpro/data/_KNoise.py | 6 +- src/mrpro/data/_KTrajectory.py | 78 +++++++++-------- src/mrpro/data/_MoveDataMixin.py | 109 ++++++++++++++++++++++++ src/mrpro/data/_QData.py | 2 + src/mrpro/data/__init__.py | 1 + src/mrpro/data/_kdata/_KData.py | 141 ++++++++++++++----------------- 8 files changed, 223 insertions(+), 163 deletions(-) create mode 100644 src/mrpro/data/_MoveDataMixin.py diff --git a/src/mrpro/data/_Data.py b/src/mrpro/data/_Data.py index c1e9379fd..370565d0b 100644 --- a/src/mrpro/data/_Data.py +++ b/src/mrpro/data/_Data.py @@ -22,52 +22,12 @@ import torch +from mrpro.data._MoveDataMixin import MoveDataMixin + @dataclasses.dataclass(slots=True, frozen=True) -class Data(ABC): +class Data(MoveDataMixin, ABC): """A general data class with field data and header.""" data: torch.Tensor header: Any - - def to(self, *args, **kwargs) -> Data: - """Perform dtype and/or device conversion of data. - - A torch.dtype and torch.device are inferred from the arguments - of self.to(*args, **kwargs). Please have a look at the - documentation of torch.Tensor.to() for more details. - """ - return Data(header=self.header, data=self.data.to(*args, **kwargs)) - - def cuda( - self, - device: torch.device | None = None, - non_blocking: bool = False, - memory_format: torch.memory_format = torch.preserve_format, - ) -> Data: - """Create copy of object with data in CUDA memory. - - Parameters - ---------- - device - The destination GPU device. Defaults to the current CUDA device. - non_blocking - If True and the source is in pinned memory, the copy will be asynchronous with respect to the host. - Otherwise, the argument has no effect. - memory_format - The desired memory format of returned tensor. - """ - return Data( - header=self.header, - data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] - ) - - def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Data: - """Create copy of object in CPU memory. - - Parameters - ---------- - memory_format - The desired memory format of returned tensor. - """ - return Data(header=self.header, data=self.data.cpu(memory_format=memory_format)) # type: ignore [call-arg] diff --git a/src/mrpro/data/_DcfData.py b/src/mrpro/data/_DcfData.py index 42b79ed65..8c40214ac 100644 --- a/src/mrpro/data/_DcfData.py +++ b/src/mrpro/data/_DcfData.py @@ -29,6 +29,7 @@ from scipy.spatial import Voronoi from mrpro.data._KTrajectory import KTrajectory +from mrpro.data._MoveDataMixin import MoveDataMixin from mrpro.utils import smap if TYPE_CHECKING: @@ -42,7 +43,7 @@ def _volume(v: ArrayLike): @dataclasses.dataclass(slots=True, frozen=False) -class DcfData: +class DcfData(MoveDataMixin): """Density compensation data (DcfData) class.""" data: torch.Tensor diff --git a/src/mrpro/data/_KNoise.py b/src/mrpro/data/_KNoise.py index 0afe7b485..0dad8d2c6 100644 --- a/src/mrpro/data/_KNoise.py +++ b/src/mrpro/data/_KNoise.py @@ -107,7 +107,7 @@ def cuda( non_blocking: bool = False, memory_format: torch.memory_format = torch.preserve_format, ) -> Self: - """Create copy of object with trajectory and data in CUDA memory. + """Create copy of object with data in CUDA memory. Parameters ---------- @@ -120,7 +120,7 @@ def cuda( The desired memory format of returned Tensor. """ return type(self)( - data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] + data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), ) def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Self: @@ -132,5 +132,5 @@ def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Sel The desired memory format of returned Tensor. """ return type(self)( - data=self.data.cpu(memory_format=memory_format), # type: ignore [call-arg] + data=self.data.cpu(memory_format=memory_format), ) diff --git a/src/mrpro/data/_KTrajectory.py b/src/mrpro/data/_KTrajectory.py index 061a47f53..d353d325c 100644 --- a/src/mrpro/data/_KTrajectory.py +++ b/src/mrpro/data/_KTrajectory.py @@ -17,16 +17,19 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Self +from typing import overload import numpy as np import torch +from mrpro.data._MoveDataMixin import MoveDataMixin from mrpro.data.enums import TrajType from mrpro.utils import remove_repeat @dataclass(slots=True, frozen=True) -class KTrajectory: +class KTrajectory(MoveDataMixin): """K-space trajectory. Order of directions is always kz, ky, kx @@ -182,53 +185,48 @@ def as_tensor(self, stack_dim: int = 0) -> torch.Tensor: shape = self.broadcasted_shape return torch.stack([traj.expand(*shape) for traj in (self.kz, self.ky, self.kx)], dim=stack_dim) - def to(self, *args, **kwargs) -> KTrajectory: + @overload + def to( + self, dtype: torch.dtype, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None + ) -> Self: ... + + @overload + def to( + self, + device: str | torch.device | int | None = None, + dtype: torch.dtype | None = None, + non_blocking: bool = False, + *, + memory_format: torch.memory_format | None = None, + ) -> Self: ... + + @overload + def to( + self, other: torch.Tensor, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None + ) -> Self: ... + + def to(self, *args, **kwargs) -> Self: """Perform dtype and/or device conversion of trajectory. + This will always return a new KTrajectory object. + A torch.dtype and torch.device are inferred from the arguments of self.to(*args, **kwargs). Please have a look at the documentation of torch.Tensor.to() for more details. """ - return KTrajectory( + kwargs['copy'] = True + return type(self)( kz=self.kz.to(*args, **kwargs), ky=self.ky.to(*args, **kwargs), kx=self.kx.to(*args, **kwargs), ) - def cuda( - self, - device: torch.device | None = None, - non_blocking: bool = False, - memory_format: torch.memory_format = torch.preserve_format, - ) -> KTrajectory: - """Create copy of trajectory in CUDA memory. - - Parameters - ---------- - device - The destination GPU device. Defaults to the current CUDA device. - non_blocking - If True and the source is in pinned memory, the copy will be asynchronous with respect to the host. - Otherwise, the argument has no effect. - memory_format - The desired memory format of returned Tensor. - """ - return KTrajectory( - kz=self.kz.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] - ky=self.ky.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] - kx=self.kx.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] - ) - - def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> KTrajectory: - """Create copy of trajectory in CPU memory. - - Parameters - ---------- - memory_format - The desired memory format of returned Tensor. - """ - return KTrajectory( - kz=self.kz.cpu(memory_format=memory_format), # type: ignore [call-arg] - ky=self.ky.cpu(memory_format=memory_format), # type: ignore [call-arg] - kx=self.kx.cpu(memory_format=memory_format), # type: ignore [call-arg] - ) + @property + def device(self) -> torch.device: + """Return the device of the trajectory.""" + device_x = self.kx.device + device_y = self.ky.device + device_z = self.kz.device + if device_x != device_y or device_x != device_z: + raise ValueError('Trajectory is on different devices.') + return device_x diff --git a/src/mrpro/data/_MoveDataMixin.py b/src/mrpro/data/_MoveDataMixin.py new file mode 100644 index 000000000..6ac2bd998 --- /dev/null +++ b/src/mrpro/data/_MoveDataMixin.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import dataclasses +from abc import ABC +from copy import deepcopy +from typing import Any +from typing import ClassVar +from typing import Protocol +from typing import Self +from typing import overload + +import torch + + +class DataclassInstance(Protocol): + """An instance of a dataclass.""" + + __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] + + +class MoveDataMixin(ABC, DataclassInstance): + """Move dataclass fields to cpu/gpu and convert dtypes.""" + + data: torch.Tensor + + @overload + def to( + self, dtype: torch.dtype, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None + ) -> Self: ... + + @overload + def to( + self, + device: str | torch.device | int | None = None, + dtype: torch.dtype | None = None, + non_blocking: bool = False, + *, + memory_format: torch.memory_format | None = None, + ) -> Self: ... + + @overload + def to( + self, other: torch.Tensor, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None + ) -> Self: ... + + def to(self, *args, **kwargs) -> Self: + """Perform dtype and/or device conversion of data. + + This will always return a new Data object. + + A torch.dtype and torch.device are inferred from the arguments + of self.to(*args, **kwargs). Please have a look at the + documentation of torch.Tensor.to() for more details. + """ + kwargs_tensors = {**kwargs, 'copy': True} + new_data: dict[str, Any] = {} + for field in dataclasses.fields(self): + name = field.name + data = getattr(self, name) + if isinstance(data, torch.Tensor): + new_data[name] = data.to(*args, **kwargs_tensors) + elif isinstance(data, MoveDataMixin): + new_data[name] = data.to(*args, **kwargs) + else: + new_data[name] = deepcopy(data) + return type(self)(**new_data) + + def cuda( + self, + device: torch.device | str | int | None = None, + non_blocking: bool = False, + memory_format: torch.memory_format = torch.preserve_format, + ) -> Self: + """Create copy of object with data in CUDA memory. + + This will always return a copy. + + + Parameters + ---------- + device + The destination GPU device. Defaults to the current CUDA device. + non_blocking + If True and the source is in pinned memory, the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. + memory_format + The desired memory format of returned tensor. + """ + if device is None: + device = torch.cuda.current_device() + return self.to(device=device, memory_format=memory_format, non_blocking=non_blocking) + + def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Self: + """Create copy of object in CPU memory. + + This will always return a copy. + + + Parameters + ---------- + memory_format + The desired memory format of returned tensor. + """ + return self.to(device='cpu', memory_format=memory_format) + + @property + def device(self) -> torch.device: + """Return the device of the data tensor.""" + return self.data.device diff --git a/src/mrpro/data/_QData.py b/src/mrpro/data/_QData.py index 6fc8e8679..d489a4835 100644 --- a/src/mrpro/data/_QData.py +++ b/src/mrpro/data/_QData.py @@ -52,6 +52,8 @@ def __init__(self, data: torch.Tensor, header: KHeader | IHeader | QHeader) -> N qheader = QHeader.from_iheader(header) elif isinstance(header, QHeader): qheader = header + else: + raise ValueError(f'Invalid header type: {type(header)}') object.__setattr__(self, 'data', data) object.__setattr__(self, 'header', qheader) diff --git a/src/mrpro/data/__init__.py b/src/mrpro/data/__init__.py index ce91c6016..e715ece1f 100644 --- a/src/mrpro/data/__init__.py +++ b/src/mrpro/data/__init__.py @@ -14,6 +14,7 @@ from mrpro.data._KNoise import KNoise from mrpro.data._KTrajectory import KTrajectory from mrpro.data._KTrajectoryRawShape import KTrajectoryRawShape +from mrpro.data._MoveDataMixin import MoveDataMixin from mrpro.data._QData import QData from mrpro.data._QHeader import QHeader from mrpro.data._SpatialDimension import SpatialDimension diff --git a/src/mrpro/data/_kdata/_KData.py b/src/mrpro/data/_kdata/_KData.py index 80c82721c..2b6e23c8b 100644 --- a/src/mrpro/data/_kdata/_KData.py +++ b/src/mrpro/data/_kdata/_KData.py @@ -18,8 +18,13 @@ import dataclasses import datetime +from collections.abc import Sequence +from copy import deepcopy from pathlib import Path +from typing import Any from typing import Protocol +from typing import Self +from typing import overload import h5py import ismrmrd @@ -35,6 +40,7 @@ from mrpro.data._KHeader import KHeader from mrpro.data._KTrajectory import KTrajectory from mrpro.data._KTrajectoryRawShape import KTrajectoryRawShape +from mrpro.data._MoveDataMixin import MoveDataMixin from mrpro.data.enums import AcqFlags from mrpro.data.traj_calculators import KTrajectoryCalculator from mrpro.data.traj_calculators import KTrajectoryIsmrmrd @@ -56,7 +62,7 @@ @dataclasses.dataclass(slots=True, frozen=True) -class KData(KDataSplitMixin, KDataRearrangeMixin, KDataSelectMixin, KDataRemoveOsMixin): +class KData(KDataSplitMixin, KDataRearrangeMixin, KDataSelectMixin, KDataRemoveOsMixin, MoveDataMixin): """MR raw data / k-space data class.""" header: KHeader @@ -258,97 +264,80 @@ def reshape_acq_data(data: torch.Tensor): return cls(kheader, kdata, ktraj) + @overload + def to( + self, dtype: torch.dtype, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None + ) -> Self: ... + + @overload def to( self, - device: torch.device | str | None = None, - dtype: None | torch.dtype = None, + device: str | torch.device | int | None = None, + dtype: torch.dtype | None = None, non_blocking: bool = False, - copy: bool = False, - memory_format: torch.memory_format = torch.preserve_format, - ) -> KData: + *, + memory_format: torch.memory_format | None = None, + ) -> Self: ... + + @overload + def to( + self, other: torch.Tensor, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None + ) -> Self: ... + + def to(self, *args, **kwargs) -> Self: # noqa: D417 """Perform dtype and/or device conversion of trajectory and data. + This will always return a copy. + Parameters ---------- device The destination device. Defaults to the current device. dtype - Dtype of the k-space data, can only be torch.complex64 or torch.complex128. - The dtype of the trajectory (torch.float32 or torch.float64) is then inferred from this. - non_blocking - If True and the source is in pinned memory, the copy will be asynchronous with respect to the host. - Otherwise, the argument has no effect. - copy - If True a new Tensor is created even when the Tensor already matches the desired conversion. - memory_format - The desired memory format of returned Tensor. - """ - # Only complex64 and complex128 supported for kdata. - # This will then lead to a trajectory of float32 and float64, respectively. - if dtype is None: - dtype_traj = None - elif dtype == torch.complex64: - dtype_traj = torch.float32 - elif dtype == torch.complex128: - dtype_traj = torch.float64 - else: - raise ValueError(f'dtype {dtype} not supported. Only torch.complex64 and torch.complex128 is supported.') - - return KData( - header=self.header, - data=self.data.to( - device=device, - dtype=dtype, - non_blocking=non_blocking, - copy=copy, - memory_format=memory_format, - ), - traj=self.traj.to( - device=device, - dtype=dtype_traj, - non_blocking=non_blocking, - copy=copy, - memory_format=memory_format, - ), - ) - - def cuda( - self, - device: torch.device | None = None, - non_blocking: bool = False, - memory_format: torch.memory_format = torch.preserve_format, - ) -> KData: - """Create copy of object with trajectory and data in CUDA memory. - - Parameters - ---------- - device - The destination GPU device. Defaults to the current CUDA device. + Data type. + The trajectory dtype will always be converted to real, + the data dtype will always be converted to complex non_blocking If True and the source is in pinned memory, the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. memory_format The desired memory format of returned Tensor. """ - return KData( - header=self.header, - data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] - traj=self.traj.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), - ) - - def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> KData: - """Create copy of object in CPU memory. - - Parameters - ---------- - memory_format - The desired memory format of returned Tensor. - """ - return KData( - header=self.header, - data=self.data.cpu(memory_format=memory_format), # type: ignore [call-arg] - traj=self.traj.cpu(memory_format=memory_format), - ) + _args: Sequence[Any] = () + _kwargs: dict[str, Any] = {} + dtype = self.data.dtype + device = self.device + match args, kwargs: + case ((dtype, *_args), {**_kwargs}) if isinstance(dtype, torch.dtype): + # overload 1 + ... + case (_args, {'dtype': dtype, **_kwargs}) if isinstance(dtype, torch.dtype): + # dtype as kwarg + ... + case ((other, *_args), {**_kwargs}) | (_args, {'other': other, **_kwargs}) if isinstance( + other, torch.Tensor + ): + # overload 3: use dtype and device from other + dtype = other.dtype + device = other.device + match args, kwargs: + case ((device, dtype, *_args), {**_kwargs}) if isinstance(device, torch.device | str) and isinstance( + dtype, torch.dtype + ): + # overload 2 with device and dtype + ... + case ((device, *_args), {**_kwargs}) if isinstance(device, torch.device | str): + # overload 2, only device + ... + case (_args, {'device': device, **_kwargs}) if isinstance(device, torch.device | str): + # device as kwarg + ... + + # The trajectory dtype will always be real, the data always be complex. + data = self.data.to(*_args, **{**_kwargs, 'dtype': dtype.to_complex(), 'device': device, 'copy': True}) + traj = self.traj.to(*_args, **{**_kwargs, 'dtype': dtype.to_real(), 'device': device}) + header = deepcopy(self.header) # TODO: use header.to + return type(self)(header=header, data=data, traj=traj) class _KDataProtocol(Protocol): From 41c660d989d8114c07ec08205d31fae71ef2b9be Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 11 May 2024 15:36:40 +0000 Subject: [PATCH 02/19] Update [ghstack-poisoned] --- src/mrpro/operators/_ConstraintsOp.py | 6 +- src/mrpro/operators/_EndomorphOperator.py | 223 ++++++++++++++++++++++ src/mrpro/operators/_MagnitudeOp.py | 6 +- src/mrpro/operators/_Operator.py | 10 +- src/mrpro/operators/_PhaseOp.py | 6 +- 5 files changed, 240 insertions(+), 11 deletions(-) create mode 100644 src/mrpro/operators/_EndomorphOperator.py diff --git a/src/mrpro/operators/_ConstraintsOp.py b/src/mrpro/operators/_ConstraintsOp.py index 192f8c6e2..ca839bd67 100644 --- a/src/mrpro/operators/_ConstraintsOp.py +++ b/src/mrpro/operators/_ConstraintsOp.py @@ -19,10 +19,11 @@ import torch import torch.nn.functional as F # noqa: N812 -from mrpro.operators._Operator import Operator +from mrpro.operators._EndomorphOperator import EndomorphOperator +from mrpro.operators._EndomorphOperator import endomorph -class ConstraintsOp(Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]): +class ConstraintsOp(EndomorphOperator): """Transformation to map real-valued tensors to certain ranges.""" def __init__( @@ -75,6 +76,7 @@ def softplus_inverse(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor: """Inverse of 'softplus_transformation.""" return beta * x + torch.log(-torch.expm1(-beta * x)) + @endomorph def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: """Transform tensors to chosen range. diff --git a/src/mrpro/operators/_EndomorphOperator.py b/src/mrpro/operators/_EndomorphOperator.py new file mode 100644 index 000000000..fa832192a --- /dev/null +++ b/src/mrpro/operators/_EndomorphOperator.py @@ -0,0 +1,223 @@ +"""Endomorph Operators.""" + +# Copyright 2024 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable +from typing import ParamSpec +from typing import Protocol +from typing import TypeAlias +from typing import TypeVar +from typing import TypeVarTuple +from typing import cast +from typing import overload + +import torch + +from mrpro.operators._Operator import Operator + +Tin = TypeVarTuple('Tin') +Tout = TypeVar('Tout', bound=tuple[torch.Tensor, ...], covariant=True) +P = ParamSpec('P') +Wrapped: TypeAlias = Callable[P, Tout] +F = TypeVar('F', bound=Wrapped) + + +class _EndomorphCallable(Protocol): + """A callable with the same number of tensor inputs and outputs. + + This is a protocol for a callable that takes a variadic number of tensor inputs + and returns the same number of tensor outputs. + + This is only implemented for up to 10 inputs, if more inputs are given, the return + will be a variadic number of tensors. + + This Protocol is used to decorate functions with the `endomorph` decorator. + """ + + @overload + def __call__(self, /) -> tuple[()]: ... + @overload + def __call__(self, x1: torch.Tensor, /) -> tuple[torch.Tensor]: ... + + @overload + def __call__(self, x1: torch.Tensor, x2: torch.Tensor, /) -> tuple[torch.Tensor, torch.Tensor]: ... + + @overload + def __call__( + self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor, / + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... + + @overload + def __call__( + self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor, x4: torch.Tensor, / + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... + + @overload + def __call__( + self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor, x4: torch.Tensor, x5: torch.Tensor, / + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... + + @overload + def __call__( + self, + x1: torch.Tensor, + x2: torch.Tensor, + x3: torch.Tensor, + x4: torch.Tensor, + x5: torch.Tensor, + x6: torch.Tensor, + /, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... + + @overload + def __call__( + self, + x1: torch.Tensor, + x2: torch.Tensor, + x3: torch.Tensor, + x4: torch.Tensor, + x5: torch.Tensor, + x6: torch.Tensor, + x7: torch.Tensor, + /, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... + + @overload + def __call__( + self, + x1: torch.Tensor, + x2: torch.Tensor, + x3: torch.Tensor, + x4: torch.Tensor, + x5: torch.Tensor, + x6: torch.Tensor, + x7: torch.Tensor, + x8: torch.Tensor, + /, + ) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor + ]: ... + + @overload + def __call__( + self, + x1: torch.Tensor, + x2: torch.Tensor, + x3: torch.Tensor, + x4: torch.Tensor, + x5: torch.Tensor, + x6: torch.Tensor, + x7: torch.Tensor, + x8: torch.Tensor, + x9: torch.Tensor, + /, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: ... + + @overload + def __call__( + self, + x1: torch.Tensor, + x2: torch.Tensor, + x3: torch.Tensor, + x4: torch.Tensor, + x5: torch.Tensor, + x6: torch.Tensor, + x7: torch.Tensor, + x8: torch.Tensor, + x9: torch.Tensor, + x10: torch.Tensor, + /, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: ... + + @overload + def __call__( + self, + x1: torch.Tensor, + x2: torch.Tensor, + x3: torch.Tensor, + x4: torch.Tensor, + x5: torch.Tensor, + x6: torch.Tensor, + x7: torch.Tensor, + x8: torch.Tensor, + x9: torch.Tensor, + x10: torch.Tensor, + /, + *args: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + *tuple[torch.Tensor, ...], + ]: ... + + @overload + def __call__(self, /, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ... + + def __call__(self, /, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ... + + +def endomorph(f: F, /) -> _EndomorphCallable: + """Decorate a function to make it an endomorph callable.""" + return f + + +class EndomorphOperator(Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]): + """Endomorph Operator. + + Endomorph Operators have N tensor inputs and exactly N outputs. + """ + + @endomorph + def __call__(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + return super().__call__(*x) + + @endomorph + def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + return x + + def __matmul__(self, other: Operator[*Tin, Tout]) -> Operator[*Tin, Tout]: + """Operator composition.""" + return cast(Operator[*Tin, Tout], super().__matmul__(other)) diff --git a/src/mrpro/operators/_MagnitudeOp.py b/src/mrpro/operators/_MagnitudeOp.py index 909dd409b..0ce9cc20f 100644 --- a/src/mrpro/operators/_MagnitudeOp.py +++ b/src/mrpro/operators/_MagnitudeOp.py @@ -16,12 +16,14 @@ import torch -from mrpro.operators._Operator import Operator +from mrpro.operators._EndomorphOperator import EndomorphOperator +from mrpro.operators._EndomorphOperator import endomorph -class MagnitudeOp(Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]): +class MagnitudeOp(EndomorphOperator): """Magnitude of input tensors.""" + @endomorph def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: """Magnitude of tensors. diff --git a/src/mrpro/operators/_Operator.py b/src/mrpro/operators/_Operator.py index dee4cc52f..ba41067a4 100644 --- a/src/mrpro/operators/_Operator.py +++ b/src/mrpro/operators/_Operator.py @@ -26,7 +26,7 @@ Tin = TypeVarTuple('Tin') # TODO: bind to torch.Tensors Tin2 = TypeVarTuple('Tin2') # TODO: bind to torch.Tensors -Tout = TypeVar('Tout', bound=tuple) # TODO: bind to torch.Tensor +Tout = TypeVar('Tout', bound=tuple, covariant=True) # TODO: bind to torch.Tensors class Operator(Generic[*Tin, Tout], ABC, torch.nn.Module): @@ -40,7 +40,7 @@ def forward(self, *args: *Tin) -> Tout: def __call__(self, *args: *Tin) -> Tout: return super().__call__(*args) - def __matmul__(self, other: Operator[*Tin2, tuple[*Tin]]) -> Operator[*Tin2, Tout]: + def __matmul__(self: Operator[*Tin, Tout], other: Operator[*Tin2, tuple[*Tin]]) -> Operator[*Tin2, Tout]: """Operator composition.""" return OperatorComposition(self, other) @@ -57,15 +57,15 @@ def __rmul__(self, other: torch.Tensor) -> Operator[*Tin, Tout]: # type: ignore return OperatorElementwiseProductRight(self, other) -class OperatorComposition(Operator[*Tin, Tout]): +class OperatorComposition(Operator[*Tin2, Tout]): """Operator composition.""" - def __init__(self, operator1: Operator[*Tin2, Tout], operator2: Operator[*Tin, tuple[*Tin2]]): + def __init__(self, operator1: Operator[*Tin, Tout], operator2: Operator[*Tin2, tuple[*Tin]]): super().__init__() self._operator1 = operator1 self._operator2 = operator2 - def forward(self, *args: *Tin) -> Tout: + def forward(self, *args: *Tin2) -> Tout: """Operator composition.""" return self._operator1(*self._operator2(*args)) diff --git a/src/mrpro/operators/_PhaseOp.py b/src/mrpro/operators/_PhaseOp.py index 88b71009f..c73a60585 100644 --- a/src/mrpro/operators/_PhaseOp.py +++ b/src/mrpro/operators/_PhaseOp.py @@ -16,12 +16,14 @@ import torch -from mrpro.operators._Operator import Operator +from mrpro.operators._EndomorphOperator import EndomorphOperator +from mrpro.operators._EndomorphOperator import endomorph -class PhaseOp(Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]): +class PhaseOp(EndomorphOperator): """Phase of input tensors.""" + @endomorph def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: """Phase of tensors. From a3fab1b9acbc408ad379e99fa1674248edb33422 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 11 May 2024 15:45:55 +0000 Subject: [PATCH 03/19] Update [ghstack-poisoned] --- pyproject.toml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f9c7e50e6..8db6ac680 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,9 +29,8 @@ classifiers = [ ] dependencies = [ "numpy>=1.23,<2.0", - "torch>=2.2,<3.0", - "ismrmrd", - "xsdata>=22.2,<23", + "torch>=2.3,<3.0", + "ismrmrd>=1.14.1,<2.0", "einops", "pydicom", "pypulseq@git+https://github.com/imr-framework/pypulseq", From 4918ef0a52d4b512403665a551f5c543d874b611 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 11 May 2024 22:33:54 +0000 Subject: [PATCH 04/19] Update [ghstack-poisoned] --- src/mrpro/algorithms/optimizers/_lbfgs.py | 6 ------ src/mrpro/data/_Data.py | 4 ++-- src/mrpro/data/_KNoise.py | 4 ++-- src/mrpro/data/_KTrajectory.py | 12 ++++++------ src/mrpro/data/_kdata/_KData.py | 4 ++-- 5 files changed, 12 insertions(+), 18 deletions(-) diff --git a/src/mrpro/algorithms/optimizers/_lbfgs.py b/src/mrpro/algorithms/optimizers/_lbfgs.py index 0101c8a56..7b08a0979 100644 --- a/src/mrpro/algorithms/optimizers/_lbfgs.py +++ b/src/mrpro/algorithms/optimizers/_lbfgs.py @@ -66,12 +66,6 @@ def lbfgs( ------- list of optimized parameters """ - # TODO: remove after new pytorch release; - if torch.tensor([torch.is_complex(p) for p in initial_parameters]).any(): - raise ValueError( - "at least one tensor in 'params' is complex-valued; \ - \ncomplex-valued tensors will be allowed for lbfgs in future torch versions", - ) parameters = [p.detach().clone().requires_grad_(True) for p in initial_parameters] optim = LBFGS( diff --git a/src/mrpro/data/_Data.py b/src/mrpro/data/_Data.py index c1e9379fd..211692c2f 100644 --- a/src/mrpro/data/_Data.py +++ b/src/mrpro/data/_Data.py @@ -59,7 +59,7 @@ def cuda( """ return Data( header=self.header, - data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] + data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), ) def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Data: @@ -70,4 +70,4 @@ def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Dat memory_format The desired memory format of returned tensor. """ - return Data(header=self.header, data=self.data.cpu(memory_format=memory_format)) # type: ignore [call-arg] + return Data(header=self.header, data=self.data.cpu(memory_format=memory_format)) diff --git a/src/mrpro/data/_KNoise.py b/src/mrpro/data/_KNoise.py index 0afe7b485..d9912dce0 100644 --- a/src/mrpro/data/_KNoise.py +++ b/src/mrpro/data/_KNoise.py @@ -120,7 +120,7 @@ def cuda( The desired memory format of returned Tensor. """ return type(self)( - data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] + data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), ) def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Self: @@ -132,5 +132,5 @@ def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Sel The desired memory format of returned Tensor. """ return type(self)( - data=self.data.cpu(memory_format=memory_format), # type: ignore [call-arg] + data=self.data.cpu(memory_format=memory_format), ) diff --git a/src/mrpro/data/_KTrajectory.py b/src/mrpro/data/_KTrajectory.py index 061a47f53..27f837ab8 100644 --- a/src/mrpro/data/_KTrajectory.py +++ b/src/mrpro/data/_KTrajectory.py @@ -214,9 +214,9 @@ def cuda( The desired memory format of returned Tensor. """ return KTrajectory( - kz=self.kz.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] - ky=self.ky.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] - kx=self.kx.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] + kz=self.kz.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), + ky=self.ky.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), + kx=self.kx.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), ) def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> KTrajectory: @@ -228,7 +228,7 @@ def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> KTr The desired memory format of returned Tensor. """ return KTrajectory( - kz=self.kz.cpu(memory_format=memory_format), # type: ignore [call-arg] - ky=self.ky.cpu(memory_format=memory_format), # type: ignore [call-arg] - kx=self.kx.cpu(memory_format=memory_format), # type: ignore [call-arg] + kz=self.kz.cpu(memory_format=memory_format), + ky=self.ky.cpu(memory_format=memory_format), + kx=self.kx.cpu(memory_format=memory_format), ) diff --git a/src/mrpro/data/_kdata/_KData.py b/src/mrpro/data/_kdata/_KData.py index 80c82721c..5f39fa9e0 100644 --- a/src/mrpro/data/_kdata/_KData.py +++ b/src/mrpro/data/_kdata/_KData.py @@ -332,7 +332,7 @@ def cuda( """ return KData( header=self.header, - data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), # type: ignore [call-arg] + data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), traj=self.traj.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), ) @@ -346,7 +346,7 @@ def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> KDa """ return KData( header=self.header, - data=self.data.cpu(memory_format=memory_format), # type: ignore [call-arg] + data=self.data.cpu(memory_format=memory_format), traj=self.traj.cpu(memory_format=memory_format), ) From ff4d9cbdeb35e80743ed25224b3c2be003355f2b Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 11 May 2024 22:41:21 +0000 Subject: [PATCH 05/19] Update [ghstack-poisoned] --- src/mrpro/algorithms/optimizers/_lbfgs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mrpro/algorithms/optimizers/_lbfgs.py b/src/mrpro/algorithms/optimizers/_lbfgs.py index 7b08a0979..5636aaad1 100644 --- a/src/mrpro/algorithms/optimizers/_lbfgs.py +++ b/src/mrpro/algorithms/optimizers/_lbfgs.py @@ -66,7 +66,6 @@ def lbfgs( ------- list of optimized parameters """ - parameters = [p.detach().clone().requires_grad_(True) for p in initial_parameters] optim = LBFGS( params=parameters, From f4be0f7c9db85b708f1b99b12ac49fdae5bbf101 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sun, 12 May 2024 12:44:55 +0000 Subject: [PATCH 06/19] Update [ghstack-poisoned] --- src/mrpro/operators/_EndomorphOperator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/mrpro/operators/_EndomorphOperator.py b/src/mrpro/operators/_EndomorphOperator.py index fa832192a..1951fbcce 100644 --- a/src/mrpro/operators/_EndomorphOperator.py +++ b/src/mrpro/operators/_EndomorphOperator.py @@ -221,3 +221,7 @@ def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: def __matmul__(self, other: Operator[*Tin, Tout]) -> Operator[*Tin, Tout]: """Operator composition.""" return cast(Operator[*Tin, Tout], super().__matmul__(other)) + + def __rmatmul__(self, other: Operator[*Tin, Tout]) -> Operator[*Tin, Tout]: + """Operator composition.""" + return other.__matmul__(cast(Operator[*Tin, tuple[*Tin]], self)) From 80bae5de08d55ae554282e051a452aa00acd8660 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sun, 12 May 2024 12:50:57 +0000 Subject: [PATCH 07/19] Update [ghstack-poisoned] --- src/mrpro/data/_kdata/_KData.py | 29 ------ src/mrpro/data/_kdata/_KDataProtocol.py | 55 ++++++++++++ src/mrpro/data/_kdata/_KDataRearrangeMixin.py | 11 +-- src/mrpro/data/_kdata/_KDataRemoveOsMixin.py | 10 +-- src/mrpro/data/_kdata/_KDataSelectMixin.py | 13 ++- src/mrpro/data/_kdata/_KDataSplitMixin.py | 21 ++--- src/mrpro/phantoms/_EllipsePhantom.py | 4 +- src/mrpro/utils/_Rotation.py | 90 ++++++++++++------- tests/_RandomGenerator.py | 18 ++-- tests/algorithms/test_optimizers.py | 2 +- tests/data/_IsmrmrdRawTestData.py | 2 +- tests/data/test_csm_data.py | 2 +- tests/data/test_trajectory.py | 2 +- tests/operators/test_operators.py | 8 +- tests/utils/test_filters.py | 6 +- tests/utils/test_rotation.py | 79 ++++++++-------- 16 files changed, 196 insertions(+), 156 deletions(-) create mode 100644 src/mrpro/data/_kdata/_KDataProtocol.py diff --git a/src/mrpro/data/_kdata/_KData.py b/src/mrpro/data/_kdata/_KData.py index 2b6e23c8b..f671f31e5 100644 --- a/src/mrpro/data/_kdata/_KData.py +++ b/src/mrpro/data/_kdata/_KData.py @@ -22,7 +22,6 @@ from copy import deepcopy from pathlib import Path from typing import Any -from typing import Protocol from typing import Self from typing import overload @@ -338,31 +337,3 @@ def to(self, *args, **kwargs) -> Self: # noqa: D417 traj = self.traj.to(*_args, **{**_kwargs, 'dtype': dtype.to_real(), 'device': device}) header = deepcopy(self.header) # TODO: use header.to return type(self)(header=header, data=data, traj=traj) - - -class _KDataProtocol(Protocol): - """Protocol for KData used for type hinting in KData mixins. - - Note that the actual KData class can have more properties and methods than those defined here. - - If you want to use a property or method of KData in a new KDataMixin class, - you must add it to this Protocol to make sure that the type hinting works. - - For more information about Protocols see: - https://typing.readthedocs.io/en/latest/spec/protocol.html#protocols - """ - - @property - def header(self) -> KHeader: ... - - @property - def data(self) -> torch.Tensor: ... - - @property - def traj(self) -> KTrajectory: ... - - def __init__(self, header: KHeader, data: torch.Tensor, traj: KTrajectory): ... - - def _split_k2_or_k1_into_other( - self, split_idx: torch.Tensor, other_label: str, split_dir: str - ) -> _KDataProtocol: ... diff --git a/src/mrpro/data/_kdata/_KDataProtocol.py b/src/mrpro/data/_kdata/_KDataProtocol.py new file mode 100644 index 000000000..0dd5f112c --- /dev/null +++ b/src/mrpro/data/_kdata/_KDataProtocol.py @@ -0,0 +1,55 @@ +"""Protocol for KData.""" + +# Copyright 2024 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Literal +from typing import Protocol +from typing import Self + +import torch +from mrpro.data._KHeader import KHeader +from mrpro.data._KTrajectory import KTrajectory + + +class _KDataProtocol(Protocol): + """Protocol for KData used for type hinting in KData mixins. + + Note that the actual KData class can have more properties and methods than those defined here. + + If you want to use a property or method of KData in a new KDataMixin class, + you must add it to this Protocol to make sure that the type hinting works. + + For more information about Protocols see: + https://typing.readthedocs.io/en/latest/spec/protocol.html#protocols + """ + + @property + def header(self) -> KHeader: ... + + @property + def data(self) -> torch.Tensor: ... + + @property + def traj(self) -> KTrajectory: ... + + def __init__(self, header: KHeader, data: torch.Tensor, traj: KTrajectory): ... + + def _split_k2_or_k1_into_other( + self, + split_idx: torch.Tensor, + other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], + split_dir: Literal['k1', 'k2'], + ) -> Self: ... diff --git a/src/mrpro/data/_kdata/_KDataRearrangeMixin.py b/src/mrpro/data/_kdata/_KDataRearrangeMixin.py index f07ad3dd0..85aa4d06d 100644 --- a/src/mrpro/data/_kdata/_KDataRearrangeMixin.py +++ b/src/mrpro/data/_kdata/_KDataRearrangeMixin.py @@ -16,21 +16,18 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING +from typing import Self from einops import rearrange - -if TYPE_CHECKING: - from mrpro.data._kdata._KData import _KDataProtocol - from mrpro.data._AcqInfo import AcqInfo +from mrpro.data._kdata._KDataProtocol import _KDataProtocol from mrpro.utils import modify_acq_info -class KDataRearrangeMixin: +class KDataRearrangeMixin(_KDataProtocol): """Rearrange KData.""" - def rearrange_k2_k1_into_k1(self: _KDataProtocol) -> _KDataProtocol: + def rearrange_k2_k1_into_k1(self: Self) -> Self: """Rearrange kdata from (... k2 k1 ...) to (... 1 (k2 k1) ...). Parameters diff --git a/src/mrpro/data/_kdata/_KDataRemoveOsMixin.py b/src/mrpro/data/_kdata/_KDataRemoveOsMixin.py index 8ab220a79..d884d7505 100644 --- a/src/mrpro/data/_kdata/_KDataRemoveOsMixin.py +++ b/src/mrpro/data/_kdata/_KDataRemoveOsMixin.py @@ -14,19 +14,17 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING +from typing import Self import torch +from mrpro.data._kdata._KDataProtocol import _KDataProtocol from mrpro.data._KTrajectory import KTrajectory -if TYPE_CHECKING: - from mrpro.data._kdata._KData import _KDataProtocol - -class KDataRemoveOsMixin: +class KDataRemoveOsMixin(_KDataProtocol): """Remove oversampling along readout dimension.""" - def remove_readout_os(self: _KDataProtocol) -> _KDataProtocol: + def remove_readout_os(self: Self) -> Self: """Remove any oversampling along the readout (k0) direction. This function is inspired by https://github.com/gadgetron/gadgetron-python. diff --git a/src/mrpro/data/_kdata/_KDataSelectMixin.py b/src/mrpro/data/_kdata/_KDataSelectMixin.py index 98f8fe662..0a3a1b4fe 100644 --- a/src/mrpro/data/_kdata/_KDataSelectMixin.py +++ b/src/mrpro/data/_kdata/_KDataSelectMixin.py @@ -16,25 +16,22 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING from typing import Literal +from typing import Self import torch - -if TYPE_CHECKING: - from mrpro.data._kdata._KData import _KDataProtocol - +from mrpro.data._kdata._KDataProtocol import _KDataProtocol from mrpro.utils import modify_acq_info -class KDataSelectMixin: +class KDataSelectMixin(_KDataProtocol): """Select subset of KData.""" def select_other_subset( - self: _KDataProtocol, + self: Self, subset_idx: torch.Tensor, subset_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - ) -> _KDataProtocol: + ) -> Self: """Select a subset from the other dimension of KData. Parameters diff --git a/src/mrpro/data/_kdata/_KDataSplitMixin.py b/src/mrpro/data/_kdata/_KDataSplitMixin.py index cedc8426b..78ef062a1 100644 --- a/src/mrpro/data/_kdata/_KDataSplitMixin.py +++ b/src/mrpro/data/_kdata/_KDataSplitMixin.py @@ -17,29 +17,26 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING from typing import Literal +from typing import Self import torch from einops import rearrange from einops import repeat - -if TYPE_CHECKING: - from mrpro.data._kdata._KData import _KDataProtocol - from mrpro.data._EncodingLimits import Limits +from mrpro.data._kdata._KDataProtocol import _KDataProtocol from mrpro.utils import modify_acq_info -class KDataSplitMixin: +class KDataSplitMixin(_KDataProtocol): """Split KData into other subsets.""" def _split_k2_or_k1_into_other( - self: _KDataProtocol, + self, split_idx: torch.Tensor, other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], split_dir: Literal['k2', 'k1'], - ) -> _KDataProtocol: + ) -> Self: """Based on an index tensor, split the data in e.g. phases. Parameters @@ -132,10 +129,10 @@ def reshape_acq_info(info: torch.Tensor): return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) def split_k1_into_other( - self: _KDataProtocol, + self: Self, split_idx: torch.Tensor, other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - ) -> _KDataProtocol: + ) -> Self: """Based on an index tensor, split the data in e.g. phases. Parameters @@ -154,10 +151,10 @@ def split_k1_into_other( return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k1') def split_k2_into_other( - self: _KDataProtocol, + self: Self, split_idx: torch.Tensor, other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - ) -> _KDataProtocol: + ) -> Self: """Based on an index tensor, split the data in e.g. phases. Parameters diff --git a/src/mrpro/phantoms/_EllipsePhantom.py b/src/mrpro/phantoms/_EllipsePhantom.py index bb9004091..167e5142b 100644 --- a/src/mrpro/phantoms/_EllipsePhantom.py +++ b/src/mrpro/phantoms/_EllipsePhantom.py @@ -91,7 +91,7 @@ def image_space(self, image_dimensions: SpatialDimension[int]) -> torch.Tensor: """ # Calculate image representation of phantom ny, nx = image_dimensions.y, image_dimensions.x - ix, iy = torch.meshgrid( + ix, it = torch.meshgrid( torch.linspace(-nx // 2, nx // 2 - 1, nx), torch.linspace(-ny // 2, ny // 2 - 1, ny), indexing='xy', @@ -101,7 +101,7 @@ def image_space(self, image_dimensions: SpatialDimension[int]) -> torch.Tensor: for ellipse in self.ellipses: in_ellipse = ( (ix / nx - ellipse.center_x) ** 2 / ellipse.radius_x**2 - + (iy / ny - ellipse.center_y) ** 2 / ellipse.radius_y**2 + + (it / ny - ellipse.center_y) ** 2 / ellipse.radius_y**2 ) <= 1 idata += ellipse.intensity * in_ellipse diff --git a/src/mrpro/utils/_Rotation.py b/src/mrpro/utils/_Rotation.py index eded0ee1e..773860c03 100644 --- a/src/mrpro/utils/_Rotation.py +++ b/src/mrpro/utils/_Rotation.py @@ -46,6 +46,8 @@ import warnings from collections.abc import Sequence from typing import TYPE_CHECKING +from typing import Literal +from typing import overload import numpy as np import torch @@ -263,7 +265,7 @@ class Rotation(torch.nn.Module): - arbitrary number of batching dimensions """ - def __init__(self, quaternions: torch.Tensor, normalize: bool = True, copy: bool = True): + def __init__(self, quaternions: torch.Tensor | _NestedSequence[float], normalize: bool = True, copy: bool = True): """Initialize a new Rotation. Instead of calling this method, also consider the different @@ -280,33 +282,32 @@ def __init__(self, quaternions: torch.Tensor, normalize: bool = True, copy: bool the quaternions Parameter of this instance will be a view if the quaternions passed in. """ super().__init__() - if not isinstance(quaternions, torch.Tensor): - quaternions = torch.as_tensor(quaternions) # type: ignore[unreachable] - if torch.is_complex(quaternions): + quaternions_ = torch.as_tensor(quaternions) + if torch.is_complex(quaternions_): raise ValueError('quaternions should be real numbers') - if not torch.is_floating_point(quaternions): + if not torch.is_floating_point(quaternions_): # integer or boolean dtypes - quaternions = quaternions.float() - if quaternions.shape[-1] != 4: - raise ValueError('Expected `quaternions` to have shape (..., 4), ' f'got {quaternions.shape}.') + quaternions_ = quaternions_.float() + if quaternions_.shape[-1] != 4: + raise ValueError('Expected `quaternions` to have shape (..., 4), ' f'got {quaternions_.shape}.') # If a single quaternion is given, convert it to a 2D 1 x 4 matrix but # set self._single to True so that we can return appropriate objects # in the `to_...` methods - if quaternions.shape == (4,): - quaternions = quaternions[None, :] + if quaternions_.shape == (4,): + quaternions_ = quaternions_[None, :] self._single = True else: self._single = False if normalize: - norms = torch.linalg.vector_norm(quaternions, dim=-1, keepdim=True) + norms = torch.linalg.vector_norm(quaternions_, dim=-1, keepdim=True) if torch.any(torch.isclose(norms.float(), torch.tensor(0.0))): raise ValueError('Found zero norm quaternion in `quaternions`.') - quaternions = quaternions / norms + quaternions_ = quaternions_ / norms elif copy: - quaternions = quaternions.clone() - self._quaternions = torch.nn.Parameter(quaternions, quaternions.requires_grad) + quaternions_ = quaternions_.clone() + self._quaternions = torch.nn.Parameter(quaternions_, quaternions_.requires_grad) @property def single(self) -> bool: @@ -314,7 +315,7 @@ def single(self) -> bool: return self._single @classmethod - def from_quat(cls, quaternions: torch.Tensor | Sequence[float]) -> Rotation: + def from_quat(cls, quaternions: torch.Tensor | _NestedSequence[float]) -> Rotation: """Initialize from quaternions. 3D rotations can be represented using unit-norm quaternions [1]_. @@ -341,7 +342,7 @@ def from_quat(cls, quaternions: torch.Tensor | Sequence[float]) -> Rotation: return cls(quaternions, normalize=True) @classmethod - def from_matrix(cls, matrix: torch.Tensor) -> Rotation: + def from_matrix(cls, matrix: torch.Tensor | _NestedSequence[float]) -> Rotation: """Initialize from rotation matrix. Rotations in 3 dimensions can be represented with 3 x 3 proper @@ -367,8 +368,7 @@ def from_matrix(cls, matrix: torch.Tensor) -> Rotation: 440-442, 2008. """ if not isinstance(matrix, torch.Tensor): - # this should not happen if following type hints, but we are defensive - matrix = torch.as_tensor(matrix) # type: ignore[unreachable] + matrix = torch.as_tensor(matrix) if matrix.shape[-2:] != (3, 3): raise ValueError(f'Expected `matrix` to have shape (..., 3, 3), got {matrix.shape}') if torch.is_complex(matrix): @@ -381,7 +381,7 @@ def from_matrix(cls, matrix: torch.Tensor) -> Rotation: return cls(quaternions, normalize=True, copy=False) @classmethod - def from_rotvec(cls, rotvec: torch.Tensor | Sequence[float], degrees: bool = False) -> Rotation: + def from_rotvec(cls, rotvec: torch.Tensor | _NestedSequence[float], degrees: bool = False) -> Rotation: if not isinstance(rotvec, torch.Tensor): rotvec = torch.as_tensor(rotvec) if torch.is_complex(rotvec): @@ -401,7 +401,9 @@ def from_rotvec(cls, rotvec: torch.Tensor | Sequence[float], degrees: bool = Fal return cls(quaternions, normalize=False, copy=False) @classmethod - def from_euler(cls, seq: str, angles: torch.Tensor | Sequence[float], degrees: bool = False) -> Rotation: + def from_euler( + cls, seq: str, angles: torch.Tensor | _NestedSequence[float] | float, degrees: bool = False + ) -> Rotation: """Initialize from Euler angles. Rotations in 3-D can be represented by a sequence of 3 @@ -678,7 +680,7 @@ def concatenate(cls, rotations: Sequence[Rotation]) -> Rotation: def forward( self, - vectors: Sequence[float] | torch.Tensor | SpatialDimension[torch.Tensor] | SpatialDimension[float], + vectors: _NestedSequence[float] | torch.Tensor | SpatialDimension[torch.Tensor] | SpatialDimension[float], inverse: bool = False, ) -> torch.Tensor | SpatialDimension[torch.Tensor]: """Apply this rotation to a set of vectors. @@ -994,7 +996,7 @@ def quaternion_x(self) -> torch.Tensor: return self._quaternions[..., axis] @quaternion_x.setter - def quaternion_x(self, quat_x: torch.Tensor): + def quaternion_x(self, quat_x: torch.Tensor | float): """Set x component of the quaternion.""" axis = QUAT_AXIS_ORDER.index('x') self._quaternions[..., axis] = quat_x @@ -1008,7 +1010,7 @@ def quaternion_y(self) -> torch.Tensor: return self._quaternions[..., axis] @quaternion_y.setter - def quaternion_y(self, quat_y: torch.Tensor): + def quaternion_y(self, quat_y: torch.Tensor | float): """Set y component of the quaternion.""" axis = QUAT_AXIS_ORDER.index('y') self._quaternions[..., axis] = quat_y @@ -1022,7 +1024,7 @@ def quaternion_z(self) -> torch.Tensor: return self._quaternions[..., axis] @quaternion_z.setter - def quaternion_z(self, quat_z: torch.Tensor): + def quaternion_z(self, quat_z: torch.Tensor | float): """Set z component of the quaternion.""" axis = QUAT_AXIS_ORDER.index('z') self._quaternions[..., axis] = quat_z @@ -1036,7 +1038,7 @@ def quaternion_w(self) -> torch.Tensor: return self._quaternions[..., axis] @quaternion_w.setter - def quaternion_w(self, quat_w: torch.Tensor): + def quaternion_w(self, quat_w: torch.Tensor | float): """Set w component of the quaternion.""" axis = QUAT_AXIS_ORDER.index('w') self._quaternions[..., axis] = quat_w @@ -1095,12 +1097,35 @@ def identity(cls, shape: int | None | tuple[int, ...] = None) -> Rotation: q[..., -1] = 1 return cls(q, normalize=False) + @overload + @classmethod + def align_vectors( + cls, + a: torch.Tensor | Sequence[torch.Tensor] | Sequence[float] | Sequence[Sequence[float]], + b: torch.Tensor | Sequence[torch.Tensor] | Sequence[float] | Sequence[Sequence[float]], + weights: torch.Tensor | Sequence[float] | Sequence[Sequence[float]] | None = None, + *, + return_sensitivity: Literal[False] = False, + ) -> tuple[Rotation, float]: ... + + @overload + @classmethod + def align_vectors( + cls, + a: torch.Tensor | Sequence[torch.Tensor] | Sequence[float] | Sequence[Sequence[float]], + b: torch.Tensor | Sequence[torch.Tensor] | Sequence[float] | Sequence[Sequence[float]], + weights: torch.Tensor | Sequence[float] | Sequence[Sequence[float]] | None = None, + *, + return_sensitivity: Literal[True], + ) -> tuple[Rotation, float, torch.Tensor]: ... + @classmethod def align_vectors( cls, - a: torch.Tensor | Sequence[torch.Tensor], - b: torch.Tensor | Sequence[torch.Tensor], - weights: torch.Tensor | None = None, + a: torch.Tensor | Sequence[torch.Tensor] | Sequence[float] | Sequence[Sequence[float]], + b: torch.Tensor | Sequence[torch.Tensor] | Sequence[float] | Sequence[Sequence[float]], + weights: torch.Tensor | Sequence[float] | Sequence[Sequence[float]] | None = None, + *, return_sensitivity: bool = False, ) -> tuple[Rotation, float] | tuple[Rotation, float, torch.Tensor]: """Estimate a rotation to optimally align two sets of vectors. @@ -1124,7 +1149,7 @@ def align_vectors( elif isinstance(weights, torch.Tensor): weights_np = weights.numpy(force=True) else: - weights_np = np.asarray(weights) # type: ignore[unreachable] + weights_np = np.asarray(weights) if return_sensitivity: rotation_sp, rssd, sensitivity_np = Rotation_scipy.align_vectors(a_np, b_np, weights_np, True) @@ -1168,7 +1193,12 @@ def __repr__(self): else: return f'{tuple(self.shape)}-Batched Rotation()' - def mean(self, weights: torch.Tensor | None = None, dim: None | int | Sequence[int] = None, keepdim: bool = False): + def mean( + self, + weights: torch.Tensor | _NestedSequence[float] | None = None, + dim: None | int | Sequence[int] = None, + keepdim: bool = False, + ): r"""Get the mean of the rotations. The mean used is the chordal L2 mean (also called the projected or diff --git a/tests/_RandomGenerator.py b/tests/_RandomGenerator.py index 006d29630..c64931ac8 100644 --- a/tests/_RandomGenerator.py +++ b/tests/_RandomGenerator.py @@ -63,36 +63,36 @@ def _rand(self, size, low, high, dtype=torch.float32): low, high = self._clip_bounds(low, high, *self._dtype_bounds(dtype)) return (torch.rand(size, generator=self.generator, dtype=dtype) * (high - low)) + low - def float32_tensor(self, size: Sequence[int] = (1,), low: float = 0.0, high: float = 1.0): + def float32_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0): return self._rand(size, low, high, torch.float32) - def float64_tensor(self, size: Sequence[int] = (1,), low: float = 0.0, high: float = 1.0): + def float64_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0): return self._rand(size, low, high, torch.float64) - def complex64_tensor(self, size: Sequence[int] = (1,), low: float = 0.0, high: float = 1.0): + def complex64_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0): if low < 0: raise ValueError('low/high refer to the amplitude and must be positive') amp = self.float32_tensor(size, low, high) phase = self.float32_tensor(size, -torch.pi, torch.pi) return (amp * torch.exp(1j * phase)).to(dtype=torch.complex64) - def complex128_tensor(self, size: Sequence[int] = (1,), low: float = 0.0, high: float = 1.0): + def complex128_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0): if low < 0: raise ValueError('low/high refer to the amplitude and must be positive') amp = self.float64_tensor(size, low, high) phase = self.float64_tensor(size, -torch.pi, torch.pi) return (amp * torch.exp(1j * phase)).to(dtype=torch.complex128) - def int8_tensor(self, size: Sequence[int] = (1,), low: int = -1 << 7, high: int = 1 << 7): + def int8_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 7, high: int = 1 << 7): return self._randint(size, low, high, dtype=torch.int8) - def int16_tensor(self, size: Sequence[int] = (1,), low: int = -1 << 15, high: int = 1 << 15): + def int16_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 15, high: int = 1 << 15): return self._randint(size, low, high, dtype=torch.int16) - def int32_tensor(self, size: Sequence[int] = (1,), low: int = -1 << 31, high: int = 1 << 31): + def int32_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 31, high: int = 1 << 31): return self._randint(size, low, high, dtype=torch.int32) - def int64_tensor(self, size: Sequence[int] = (1,), low: int = -1 << 63, high: int = 1 << 63): + def int64_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 63, high: int = 1 << 63): return self._randint(size, low, high, dtype=torch.int64) # There is no uint32 in pytorch yet @@ -103,7 +103,7 @@ def int64_tensor(self, size: Sequence[int] = (1,), low: int = -1 << 63, high: in # def uint64_tensor(self, size: Sequence[int] = (1,), low: int = 0, high: int = 1 << 64): # return self._randint(size, low, high, dtype=torch.uint64) # noqa: ERA001 - def uint8_tensor(self, size: Sequence[int] = (1,), low: int = 0, high: int = 1 << 8): + def uint8_tensor(self, size: Sequence[int] | int = (1,), low: int = 0, high: int = 1 << 8): return self._randint(size, low, high, dtype=torch.uint8) def bool(self) -> bool: diff --git a/tests/algorithms/test_optimizers.py b/tests/algorithms/test_optimizers.py index 92501518b..b3bef8705 100644 --- a/tests/algorithms/test_optimizers.py +++ b/tests/algorithms/test_optimizers.py @@ -39,7 +39,7 @@ def test_optimizers_rosenbrock(optimizer, enforce_bounds_on_x1, optimizer_kwargs # save to compare with later as optimization should not change the initial points params_init_before = [i.detach().clone() for i in params_init] - params_init_grad_before = [i.grad.detach().clone() for i in params_init] + params_init_grad_before = [i.grad.detach().clone() if i.grad is not None else None for i in params_init] if enforce_bounds_on_x1: # the analytical solution for x_1 will be a, thus we can limit it into [0,2a] diff --git a/tests/data/_IsmrmrdRawTestData.py b/tests/data/_IsmrmrdRawTestData.py index f9eadfdd4..fdcf06a45 100644 --- a/tests/data/_IsmrmrdRawTestData.py +++ b/tests/data/_IsmrmrdRawTestData.py @@ -222,7 +222,7 @@ def __init__( dataset.write_xml_header(header.toXML('utf-8')) - # Create an acquistion and reuse it + # Create an acquisition and reuse it acq = ismrmrd.Acquisition() acq.resize(n_freq_encoding, self.n_coils, trajectory_dimensions=2) acq.version = 1 diff --git a/tests/data/test_csm_data.py b/tests/data/test_csm_data.py index 7a544473c..5300f2397 100644 --- a/tests/data/test_csm_data.py +++ b/tests/data/test_csm_data.py @@ -43,7 +43,7 @@ def test_CsmData_is_frozen_dataclass(random_test_data, random_kheader): """CsmData inherits frozen dataclass property from QData.""" csm = CsmData(data=random_test_data, header=random_kheader) with pytest.raises(dataclasses.FrozenInstanceError): - csm.data = random_test_data + csm.data = random_test_data # type: ignore[misc] def test_CsmData_iterative_Walsh(ellipse_phantom, random_kheader): diff --git a/tests/data/test_trajectory.py b/tests/data/test_trajectory.py index 2eb5f954b..70816505d 100644 --- a/tests/data/test_trajectory.py +++ b/tests/data/test_trajectory.py @@ -50,7 +50,7 @@ def create_traj(k_shape, nkx, nky, nkz, sx, sy, sz): elif spacing == 'z': k = torch.zeros(nk) k_list.append(k) - trajectory = KTrajectory(*k_list, repeat_detection_tolerance=None) + trajectory = KTrajectory(k_list[0], k_list[1], k_list[2], repeat_detection_tolerance=None) return trajectory diff --git a/tests/operators/test_operators.py b/tests/operators/test_operators.py index 1c2adf734..ef4682e05 100644 --- a/tests/operators/test_operators.py +++ b/tests/operators/test_operators.py @@ -1,5 +1,7 @@ """Tests for the operators module.""" +from typing import cast + import pytest import torch from mrpro.operators import LinearOperator @@ -9,14 +11,14 @@ from tests.helper import dotproduct_adjointness_test -class DummyOperator(Operator[torch.Tensor, torch.Tensor]): +class DummyOperator(Operator[torch.Tensor, tuple[torch.Tensor,]]): """Dummy operator for testing, raises input to the power of value and sums.""" def __init__(self, value: torch.Tensor): super().__init__() self._value = value - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: """Dummy operator.""" return ((x**self._value).sum().unsqueeze(0),) @@ -151,7 +153,7 @@ def test_elementwise_product_operator(value): def test_elementwise_rproduct_operator(value): a = DummyOperator(torch.tensor(2.0)) b = torch.tensor(value) - c = b * a + c = cast(DummyOperator, b * a) x = RandomGenerator(0).complex64_tensor(10) (y1,) = c(x) y2 = b * a(x)[0] diff --git a/tests/utils/test_filters.py b/tests/utils/test_filters.py index 59836ddd2..7dde39c50 100644 --- a/tests/utils/test_filters.py +++ b/tests/utils/test_filters.py @@ -48,7 +48,7 @@ def test_spatial_uniform_filter_wrong_width(data): """Test spatial_uniform_filter_3d with wrong width.""" with pytest.raises(ValueError, match='Invalid filter width'): - uniform_filter_3d(data, (3, 3)) + uniform_filter_3d(data, (3, 3)) # type: ignore[arg-type] def test_gaussian_filter_int_axis(data): @@ -134,8 +134,8 @@ def test_uniform_invalid_width(data): with pytest.raises(ValueError, match='positive'): uniform_filter(data, width=torch.tensor(-1.0)) with pytest.raises(ValueError, match='positive'): - uniform_filter(data, width=torch.nan) + uniform_filter(data, width=torch.nan) # type: ignore[arg-type] with pytest.warns(UserWarning, match='odd'): uniform_filter(data, width=2) with pytest.raises(ValueError, match='length'): - uniform_filter(data, width=(3.0, 3.0)) + uniform_filter(data, width=(3.0, 3.0)) # type: ignore[arg-type] diff --git a/tests/utils/test_rotation.py b/tests/utils/test_rotation.py index a83794133..407365bfd 100644 --- a/tests/utils/test_rotation.py +++ b/tests/utils/test_rotation.py @@ -253,7 +253,7 @@ def test_from_2d_single_rotvec(): def test_from_generic_rotvec(): - rotvec = [[1, 2, 2], [1, -1, 0.5], [0, 0, 0]] + rotvec = [[1.0, 2.0, 2.0], [1.0, -1.0, 0.5], [0.0, 0.0, 0.0]] expected_quat = torch.tensor( [[0.3324983, 0.6649967, 0.6649967, 0.0707372], [0.4544258, -0.4544258, 0.2272129, 0.7316889], [0, 0, 0, 1]] ) @@ -527,23 +527,24 @@ def test_from_euler_extrinsic_rotation_202(): ) +def _test_stats(error: torch.Tensor, mean_max: float, rms_max: float) -> None: + # helper function for mean error tests + mean = torch.mean(error, dim=0) + std = torch.std(error, dim=0) + rms = torch.hypot(mean, std) + assert torch.all(torch.abs(mean) < mean_max) + assert torch.all(rms < rms_max) + + @pytest.mark.parametrize('seq_tuple', permutations('xyz')) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): - # helper function for mean error tests - def test_stats(error, mean_max, rms_max): - mean = torch.mean(error, axis=0) - std = torch.std(error, axis=0) - rms = torch.hypot(mean, std) - assert torch.all(torch.abs(mean) < mean_max) - assert torch.all(rms < rms_max) - - rnd = np.random.RandomState(0) + rnd = RandomGenerator(0) n = 1000 - angles = np.empty((n, 3)) - angles[:, 0] = rnd.uniform(low=-torch.pi, high=torch.pi, size=(n,)) - angles[:, 1] = rnd.uniform(low=-torch.pi / 2, high=torch.pi / 2, size=(n,)) - angles[:, 2] = rnd.uniform(low=-torch.pi, high=torch.pi, size=(n,)) + angles = torch.empty((n, 3), dtype=torch.float64) + angles[:, 0] = rnd.float64_tensor(low=-torch.pi, high=torch.pi, size=(n,)) + angles[:, 1] = rnd.float64_tensor(low=-torch.pi / 2, high=torch.pi / 2, size=(n,)) + angles[:, 2] = rnd.float64_tensor(low=-torch.pi, high=torch.pi, size=(n,)) seq = ''.join(seq_tuple) if intrinsic: # Extrinsic rotation (wrt to global world) at lower case @@ -551,27 +552,19 @@ def test_stats(error, mean_max, rms_max): seq = seq.upper() rotation = Rotation.from_euler(seq, angles) angles_quat = rotation.as_euler(seq) - torch.testing.assert_close(torch.as_tensor(angles), angles_quat, atol=0, rtol=1e-12) - test_stats(angles_quat - angles, 1e-15, 1e-14) + torch.testing.assert_close(angles, angles_quat, atol=0, rtol=1e-11) + _test_stats(angles_quat - angles, 1e-15, 1e-14) @pytest.mark.parametrize('seq_tuple', permutations('xyz')) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_symmetric_axes(seq_tuple, intrinsic): - # helper function for mean error tests - def test_stats(error, mean_max, rms_max): - mean = torch.mean(error, axis=0) - std = torch.std(error, axis=0) - rms = torch.hypot(mean, std) - assert torch.all(torch.abs(mean) < mean_max) - assert torch.all(rms < rms_max) - - rnd = np.random.RandomState(0) + rnd = RandomGenerator(0) n = 1000 - angles = np.empty((n, 3)) - angles[:, 0] = rnd.uniform(low=-torch.pi, high=torch.pi, size=(n,)) - angles[:, 1] = rnd.uniform(low=0, high=torch.pi, size=(n,)) - angles[:, 2] = rnd.uniform(low=-torch.pi, high=torch.pi, size=(n,)) + angles = torch.empty((n, 3), dtype=torch.float64) + angles[:, 0] = rnd.float64_tensor(low=-torch.pi, high=torch.pi, size=(n,)) + angles[:, 1] = rnd.float64_tensor(low=0, high=torch.pi, size=(n,)) + angles[:, 2] = rnd.float64_tensor(low=-torch.pi, high=torch.pi, size=(n,)) # Rotation of the form A/B/A are rotation around symmetric axes seq = ''.join([seq_tuple[0], seq_tuple[1], seq_tuple[0]]) @@ -580,8 +573,8 @@ def test_stats(error, mean_max, rms_max): rotation = Rotation.from_euler(seq, angles) angles_quat = rotation.as_euler(seq) - torch.testing.assert_close(torch.as_tensor(angles), angles_quat, atol=0, rtol=1e-13) - test_stats(angles_quat - angles, 1e-16, 1e-14) + torch.testing.assert_close(angles, angles_quat, atol=0, rtol=1e-11) + _test_stats(angles_quat - angles, 1e-16, 1e-14) @pytest.mark.parametrize('seq_tuple', permutations('xyz')) @@ -922,7 +915,7 @@ def test_align_vectors_no_rotation(): def test_align_vectors_no_noise(): rnd = np.random.RandomState(0) c = Rotation.random(random_state=rnd) - b = rnd.normal(size=(5, 3)) + b = torch.tensor(rnd.normal(size=(5, 3))) a = c(b) est, rssd = Rotation.align_vectors(a, b) @@ -944,7 +937,7 @@ def test_align_vectors_rssd_sensitivity(): rssd_expected = 0.141421356237308 sens_expected = torch.tensor([[0.2, 0.0, 0.0], [0.0, 1.5, 1.0], [0.0, 1.0, 1.0]]) a = [[0, 1, 0], [0, 1, 1], [0, 1, 1]] - b = [[1, 0, 0], [1, 1.1, 0], [1, 0.9, 0]] + b = [[1.0, 0.0, 0.0], [1.0, 1.1, 0.0], [1.0, 0.9, 0.0]] rot, rssd, sens = Rotation.align_vectors(a, b, return_sensitivity=True) assert math.isclose(rssd, rssd_expected, abs_tol=1e-6, rel_tol=1e-4) assert torch.allclose(sens, sens_expected, atol=1e-6, rtol=1e-4) @@ -956,8 +949,8 @@ def test_align_vectors_scaled_weights(): b = Rotation.random(n, random_state=1)([1, 0, 0]) scale = 2 - est1, rssd1, cov1 = Rotation.align_vectors(a, b, torch.ones(n), True) - est2, rssd2, cov2 = Rotation.align_vectors(a, b, scale * torch.ones(n), True) + est1, rssd1, cov1 = Rotation.align_vectors(a, b, torch.ones(n), return_sensitivity=True) + est2, rssd2, cov2 = Rotation.align_vectors(a, b, scale * torch.ones(n), return_sensitivity=True) torch.testing.assert_close(est1.as_matrix(), est2.as_matrix()) torch.testing.assert_close(sqrt(scale) * rssd1, rssd2, atol=1e-6, rtol=1e-4) @@ -968,13 +961,13 @@ def test_align_vectors_noise(): rnd = np.random.RandomState(0) n_vectors = 100 rot = Rotation.random(random_state=rnd) - vectors = rnd.normal(size=(n_vectors, 3)).astype(np.float32) + vectors = torch.tensor(rnd.normal(size=(n_vectors, 3)), dtype=torch.float32) result = rot(vectors) # The paper adds noise as independently distributed angular errors sigma = np.deg2rad(1) tolerance = 1.5 * sigma - noise = Rotation.from_rotvec(rnd.normal(size=(n_vectors, 3), scale=sigma).astype(np.float32)) + noise = Rotation.from_rotvec(torch.tensor(rnd.normal(size=(n_vectors, 3), scale=sigma), dtype=torch.float32)) # Attitude errors must preserve norm. Hence apply individual random # rotations to each vector. @@ -1253,7 +1246,7 @@ def test_concatenate(): def test_concatenate_wrong_type(): """Test concatenation with non-Rotation objects""" with pytest.raises(TypeError, match='Rotation objects only'): - Rotation.concatenate([Rotation.identity(), 1, None]) + Rotation.concatenate([Rotation.identity(), 1, None]) # type: ignore[list-item] def test_len_and_bool(): @@ -1335,9 +1328,9 @@ def test_weighted_mean_dims(shape, keepdim, dim, expected_shape): def test_mean_invalid_weights(): """Test mean with invalid weights""" - r = Rotation.from_quat(np.eye(4)) + r = Rotation.from_quat(torch.eye(4)) with pytest.raises(ValueError, match='non-negative'): - r.mean(weights=-np.ones(4)) + r.mean(weights=-torch.ones(4)) def test_repr(): @@ -1355,10 +1348,10 @@ def test_quaternion_properties_single(): assert r.quaternion_y == quat[AXIS_ORDER.index('y')] assert r.quaternion_z == quat[AXIS_ORDER.index('z')] assert r.quaternion_w == quat[-1] - r.quaternion_x = 1.0 + r.quaternion_x = 1.0 # type: ignore[assignment] r.quaternion_y = torch.tensor(2.0) - r.quaternion_z = 3 - r.quaternion_w = 4.0 + r.quaternion_z = 3 # type: ignore[assignment] + r.quaternion_w = 4.0 # type: ignore[assignment] torch.testing.assert_close(r.quaternion_x, torch.tensor(1.0)) torch.testing.assert_close(r.quaternion_y, torch.tensor(2.0)) torch.testing.assert_close(r.quaternion_z, torch.tensor(3.0)) From cc26c960aed97d2b25ae24a34dab9e35d9fd6237 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sun, 12 May 2024 12:51:00 +0000 Subject: [PATCH 08/19] Update [ghstack-poisoned] --- .pre-commit-config.yaml | 19 +++++++++++++------ pyproject.toml | 3 ++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a6a561a0..08dbac3a1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ default_language_version: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-docstring-first - id: check-merge-conflict @@ -14,25 +14,32 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.4 + rev: v0.4.4 hooks: - id: ruff # linter args: [--fix] - id: ruff-format # formatter - repo: https://github.com/crate-ci/typos - rev: v1.19.0 + rev: v1.21.0 hooks: - id: typos - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.9.0 + rev: v1.10.0 hooks: - id: mypy + pass_filenames: false + always_run: true + args: [src,tests,examples] additional_dependencies: - numpy - - torch>=2.2.0 + - torch>=2.3.0 - types-requests + - einops + - pydicom + - matplotlib + - pytest + - xsdata - "--index-url=https://download.pytorch.org/whl/cpu" - "--extra-index-url=https://pypi.python.org/simple" - exclude: docs/.*|tests/.* diff --git a/pyproject.toml b/pyproject.toml index 8db6ac680..5ea8f4f32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,8 +78,9 @@ exclude = ["docs"] enable_error_code = ["ignore-without-code"] warn_unused_ignores = true + [[tool.mypy.overrides]] -module = ["ismrmrd.*", "h5py", "scipy.*"] +module = ["ismrmrd.*", "h5py", "scipy.*", "torchkbnufft", "pypulseq", "zenodo_get"] ignore_missing_imports = true [tool.ruff] From 76c28d7d6a40da4d290c038732ea7d3e6bc9b2cf Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 13 May 2024 14:29:17 +0000 Subject: [PATCH 09/19] Update [ghstack-poisoned] --- src/mrpro/data/_AcqInfo.py | 3 +- src/mrpro/data/_IHeader.py | 3 +- src/mrpro/data/_KHeader.py | 3 +- src/mrpro/data/_KNoise.py | 73 +---------------- src/mrpro/data/_KTrajectory.py | 48 ----------- src/mrpro/data/_KTrajectoryRawShape.py | 3 +- src/mrpro/data/_MoveDataMixin.py | 103 +++++++++++++++++++++-- src/mrpro/data/_QHeader.py | 3 +- src/mrpro/data/_SpatialDimension.py | 4 +- src/mrpro/data/_kdata/_KData.py | 109 ------------------------- 10 files changed, 110 insertions(+), 242 deletions(-) diff --git a/src/mrpro/data/_AcqInfo.py b/src/mrpro/data/_AcqInfo.py index 82e241e56..a29e0c727 100644 --- a/src/mrpro/data/_AcqInfo.py +++ b/src/mrpro/data/_AcqInfo.py @@ -23,11 +23,12 @@ import numpy as np import torch +from mrpro.data._MoveDataMixin import MoveDataMixin from mrpro.data._SpatialDimension import SpatialDimension @dataclass(slots=True) -class AcqIdx: +class AcqIdx(MoveDataMixin): """Acquisition index for each readout.""" k1: torch.Tensor diff --git a/src/mrpro/data/_IHeader.py b/src/mrpro/data/_IHeader.py index 9a3861a8c..2dc07a518 100644 --- a/src/mrpro/data/_IHeader.py +++ b/src/mrpro/data/_IHeader.py @@ -26,13 +26,14 @@ from pydicom.tag import TagType from mrpro.data._KHeader import KHeader +from mrpro.data._MoveDataMixin import MoveDataMixin from mrpro.data._SpatialDimension import SpatialDimension MISC_TAGS = {'TimeAfterStart': 0x00191016} @dataclass(slots=True) -class IHeader: +class IHeader(MoveDataMixin): """MR image data header.""" # ToDo: decide which attributes to store in the header diff --git a/src/mrpro/data/_KHeader.py b/src/mrpro/data/_KHeader.py index 903612b8d..34e921abf 100644 --- a/src/mrpro/data/_KHeader.py +++ b/src/mrpro/data/_KHeader.py @@ -28,6 +28,7 @@ from mrpro.data import enums from mrpro.data._AcqInfo import AcqInfo from mrpro.data._EncodingLimits import EncodingLimits +from mrpro.data._MoveDataMixin import MoveDataMixin from mrpro.data._SpatialDimension import SpatialDimension from mrpro.data._TrajectoryDescription import TrajectoryDescription @@ -39,7 +40,7 @@ @dataclass(slots=True) -class KHeader: +class KHeader(MoveDataMixin): """MR raw data header. All information that is not covered by the dataclass is stored in diff --git a/src/mrpro/data/_KNoise.py b/src/mrpro/data/_KNoise.py index 0dad8d2c6..85844a8c6 100644 --- a/src/mrpro/data/_KNoise.py +++ b/src/mrpro/data/_KNoise.py @@ -18,17 +18,17 @@ import dataclasses from pathlib import Path -from typing import Self import ismrmrd import torch from einops import rearrange +from mrpro.data._MoveDataMixin import MoveDataMixin from mrpro.data.enums import AcqFlags @dataclasses.dataclass(slots=True, frozen=True) -class KNoise: +class KNoise(MoveDataMixin): """MR raw data / k-space data class for noise measurements. Attributes @@ -65,72 +65,3 @@ def from_file(cls, filename: str | Path, dataset_idx: int = -1) -> KNoise: noise_data = rearrange(noise_data, 'other coils (k2 k1 k0)->other coils k2 k1 k0', k1=1, k2=1) return cls(noise_data) - - def to( - self, - device: torch.device | str | None = None, - dtype: None | torch.dtype = None, - non_blocking: bool = False, - copy: bool = False, - memory_format: torch.memory_format = torch.preserve_format, - ) -> Self: - """Perform dtype and/or device conversion of data. - - Parameters - ---------- - device - The destination device. Defaults to the current device. - dtype - Dtype of the k-space data, can only be torch.complex64 or torch.complex128. - The dtype of the trajectory (torch.float32 or torch.float64) is then inferred from this. - non_blocking - If True and the source is in pinned memory, the copy will be asynchronous with respect to the host. - Otherwise, the argument has no effect. - copy - If True a new Tensor is created even when the Tensor already matches the desired conversion. - memory_format - The desired memory format of returned Tensor. - """ - return type(self)( - data=self.data.to( - device=device, - dtype=dtype, - non_blocking=non_blocking, - copy=copy, - memory_format=memory_format, - ) - ) - - def cuda( - self, - device: torch.device | None = None, - non_blocking: bool = False, - memory_format: torch.memory_format = torch.preserve_format, - ) -> Self: - """Create copy of object with data in CUDA memory. - - Parameters - ---------- - device - The destination GPU device. Defaults to the current CUDA device. - non_blocking - If True and the source is in pinned memory, the copy will be asynchronous with respect to the host. - Otherwise, the argument has no effect. - memory_format - The desired memory format of returned Tensor. - """ - return type(self)( - data=self.data.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format), - ) - - def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Self: - """Create copy of object in CPU memory. - - Parameters - ---------- - memory_format - The desired memory format of returned Tensor. - """ - return type(self)( - data=self.data.cpu(memory_format=memory_format), - ) diff --git a/src/mrpro/data/_KTrajectory.py b/src/mrpro/data/_KTrajectory.py index d353d325c..08be3c3e5 100644 --- a/src/mrpro/data/_KTrajectory.py +++ b/src/mrpro/data/_KTrajectory.py @@ -17,8 +17,6 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Self -from typing import overload import numpy as np import torch @@ -184,49 +182,3 @@ def as_tensor(self, stack_dim: int = 0) -> torch.Tensor: """ shape = self.broadcasted_shape return torch.stack([traj.expand(*shape) for traj in (self.kz, self.ky, self.kx)], dim=stack_dim) - - @overload - def to( - self, dtype: torch.dtype, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None - ) -> Self: ... - - @overload - def to( - self, - device: str | torch.device | int | None = None, - dtype: torch.dtype | None = None, - non_blocking: bool = False, - *, - memory_format: torch.memory_format | None = None, - ) -> Self: ... - - @overload - def to( - self, other: torch.Tensor, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None - ) -> Self: ... - - def to(self, *args, **kwargs) -> Self: - """Perform dtype and/or device conversion of trajectory. - - This will always return a new KTrajectory object. - - A torch.dtype and torch.device are inferred from the arguments - of self.to(*args, **kwargs). Please have a look at the - documentation of torch.Tensor.to() for more details. - """ - kwargs['copy'] = True - return type(self)( - kz=self.kz.to(*args, **kwargs), - ky=self.ky.to(*args, **kwargs), - kx=self.kx.to(*args, **kwargs), - ) - - @property - def device(self) -> torch.device: - """Return the device of the trajectory.""" - device_x = self.kx.device - device_y = self.ky.device - device_z = self.kz.device - if device_x != device_y or device_x != device_z: - raise ValueError('Trajectory is on different devices.') - return device_x diff --git a/src/mrpro/data/_KTrajectoryRawShape.py b/src/mrpro/data/_KTrajectoryRawShape.py index c6fe574fa..0566c580f 100644 --- a/src/mrpro/data/_KTrajectoryRawShape.py +++ b/src/mrpro/data/_KTrajectoryRawShape.py @@ -23,10 +23,11 @@ from einops import rearrange from mrpro.data._KTrajectory import KTrajectory +from mrpro.data._MoveDataMixin import MoveDataMixin @dataclass(slots=True, frozen=True) -class KTrajectoryRawShape: +class KTrajectoryRawShape(MoveDataMixin): """K-space trajectory shaped ((other*k2*k1),k0). Order of directions is always kz, ky, kx diff --git a/src/mrpro/data/_MoveDataMixin.py b/src/mrpro/data/_MoveDataMixin.py index 6ac2bd998..22a71ab06 100644 --- a/src/mrpro/data/_MoveDataMixin.py +++ b/src/mrpro/data/_MoveDataMixin.py @@ -2,6 +2,7 @@ import dataclasses from abc import ABC +from collections.abc import Sequence from copy import deepcopy from typing import Any from typing import ClassVar @@ -12,6 +13,11 @@ import torch +class InconsistentDeviceError(ValueError): + def __init__(self, *devices): + super().__init__(f'Inconsistent devices found, found at least {", ".join(devices)}') + + class DataclassInstance(Protocol): """An instance of a dataclass.""" @@ -21,8 +27,6 @@ class DataclassInstance(Protocol): class MoveDataMixin(ABC, DataclassInstance): """Move dataclass fields to cpu/gpu and convert dtypes.""" - data: torch.Tensor - @overload def to( self, dtype: torch.dtype, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None @@ -46,19 +50,73 @@ def to( def to(self, *args, **kwargs) -> Self: """Perform dtype and/or device conversion of data. - This will always return a new Data object. + This will always return a new Data object with + all tensors copied, even if no conversion is necessary. A torch.dtype and torch.device are inferred from the arguments of self.to(*args, **kwargs). Please have a look at the documentation of torch.Tensor.to() for more details. + + The conversion will be applied to all Tensor fields of the dataclass, + and to all fields that implement the MoveDataMixin. + + The dtype-type, i.e. float/complex/int will always be preserved. + Example: + If called with dtype=torch.float32 OR dtype=torch.complex64: + - A complex128 tensor will be converted to complex64 + - A float64 tensor will be converted to float32 + - A bool tensor will remain bool + If other conversions are desired, please use the torch.Tensor.to() method of + the fields directly. """ - kwargs_tensors = {**kwargs, 'copy': True} + _args: Sequence[Any] = () + _kwargs: dict[str, Any] = {} + dtype = None + device = None + + # match dtype and device from args and kwargs + match args, kwargs: + case ((dtype, *_args), {**_kwargs}) if isinstance(dtype, torch.dtype): + # overload 1 + ... + case (_args, {'dtype': dtype, **_kwargs}) if isinstance(dtype, torch.dtype): + # dtype as kwarg + ... + case ((other, *_args), {**_kwargs}) | (_args, {'other': other, **_kwargs}) if isinstance( + other, torch.Tensor + ): + # overload 3: use dtype and device from other + dtype = other.dtype + device = other.device + match args, kwargs: + case ((device, dtype, *_args), {**_kwargs}) if isinstance(device, torch.device | str) and isinstance( + dtype, torch.dtype + ): + # overload 2 with device and dtype + ... + case ((device, *_args), {**_kwargs}) if isinstance(device, torch.device | str): + # overload 2, only device + ... + case (_args, {'device': device, **_kwargs}) if isinstance(device, torch.device | str): + # device as kwarg + ... + + _kwargs['copy'] = True new_data: dict[str, Any] = {} for field in dataclasses.fields(self): name = field.name data = getattr(self, name) if isinstance(data, torch.Tensor): - new_data[name] = data.to(*args, **kwargs_tensors) + new_device = data.device if device is None else device + if dtype is None: + new_dtype = data.dtype + elif data.dtype.is_floating_point: + new_dtype = dtype.to_real() + elif data.dtype.is_complex: + new_dtype = dtype.to_complex() + else: + new_dtype = dtype + new_data[name] = data.to(new_device, new_dtype, *_args, **_kwargs) elif isinstance(data, MoveDataMixin): new_data[name] = data.to(*args, **kwargs) else: @@ -104,6 +162,35 @@ def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Sel return self.to(device='cpu', memory_format=memory_format) @property - def device(self) -> torch.device: - """Return the device of the data tensor.""" - return self.data.device + def device(self) -> torch.device | None: + """Return the device of the tensors. + + Looks at each field of a dataclass and returns fields implementing a device attribute, + such as torch.Tensors or MoveDataMixin instances. + + Raises + ------ + InconsistentDeviceError: + If the devices of different fields differ. + + Returns + ------- + The device of the fields or None if no field implements a device attribute. + """ + device: None | torch.device = None + for field in dataclasses.fields(self): + data = getattr(self, field.name) + if not hasattr(data, 'device'): + continue + current_device = data.getattr('device', None) + if current_device is None: + continue + if device is None: + device = current_device + elif device != current_device: + raise InconsistentDeviceError(current_device, device) + return device + + def clone(self: Self) -> Self: + """Return a deep copy of the object.""" + return deepcopy(self) diff --git a/src/mrpro/data/_QHeader.py b/src/mrpro/data/_QHeader.py index d93ab6453..637cffa9f 100644 --- a/src/mrpro/data/_QHeader.py +++ b/src/mrpro/data/_QHeader.py @@ -23,11 +23,12 @@ from mrpro.data._IHeader import IHeader from mrpro.data._KHeader import KHeader +from mrpro.data._MoveDataMixin import MoveDataMixin from mrpro.data._SpatialDimension import SpatialDimension @dataclass(slots=True) -class QHeader: +class QHeader(MoveDataMixin): """MR quantitative data header.""" # ToDo: decide which attributes to store in the header diff --git a/src/mrpro/data/_SpatialDimension.py b/src/mrpro/data/_SpatialDimension.py index 94d130b04..739e2e7ca 100644 --- a/src/mrpro/data/_SpatialDimension.py +++ b/src/mrpro/data/_SpatialDimension.py @@ -26,6 +26,8 @@ import torch from numpy.typing import ArrayLike +from mrpro.data._MoveDataMixin import MoveDataMixin + T = TypeVar('T', int, float, torch.Tensor) @@ -38,7 +40,7 @@ class XYZ(Protocol[T]): @dataclass(slots=True) -class SpatialDimension(Generic[T]): +class SpatialDimension(MoveDataMixin, Generic[T]): """Spatial dataclass of float/int/tensors (z, y, x).""" z: T diff --git a/src/mrpro/data/_kdata/_KData.py b/src/mrpro/data/_kdata/_KData.py index 590e6e146..364390afa 100644 --- a/src/mrpro/data/_kdata/_KData.py +++ b/src/mrpro/data/_kdata/_KData.py @@ -19,13 +19,7 @@ import dataclasses import datetime import warnings -from collections.abc import Sequence -from copy import deepcopy from pathlib import Path -from typing import Any -from typing import Protocol -from typing import Self -from typing import overload import h5py import ismrmrd @@ -242,106 +236,3 @@ def sort_and_reshape_tensor_fields(dataclass: AcqInfo | AcqIdx): ) from None return cls(kheader, kdata, ktrajectory_final) - - @overload - def to( - self, dtype: torch.dtype, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None - ) -> Self: ... - - @overload - def to( - self, - device: str | torch.device | int | None = None, - dtype: torch.dtype | None = None, - non_blocking: bool = False, - *, - memory_format: torch.memory_format | None = None, - ) -> Self: ... - - @overload - def to( - self, other: torch.Tensor, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None - ) -> Self: ... - - def to(self, *args, **kwargs) -> Self: # noqa: D417 - """Perform dtype and/or device conversion of trajectory and data. - - This will always return a copy. - - Parameters - ---------- - device - The destination device. Defaults to the current device. - dtype - Data type. - The trajectory dtype will always be converted to real, - the data dtype will always be converted to complex - non_blocking - If True and the source is in pinned memory, the copy will be asynchronous with respect to the host. - Otherwise, the argument has no effect. - memory_format - The desired memory format of returned Tensor. - """ - _args: Sequence[Any] = () - _kwargs: dict[str, Any] = {} - dtype = self.data.dtype - device = self.device - match args, kwargs: - case ((dtype, *_args), {**_kwargs}) if isinstance(dtype, torch.dtype): - # overload 1 - ... - case (_args, {'dtype': dtype, **_kwargs}) if isinstance(dtype, torch.dtype): - # dtype as kwarg - ... - case ((other, *_args), {**_kwargs}) | (_args, {'other': other, **_kwargs}) if isinstance( - other, torch.Tensor - ): - # overload 3: use dtype and device from other - dtype = other.dtype - device = other.device - match args, kwargs: - case ((device, dtype, *_args), {**_kwargs}) if isinstance(device, torch.device | str) and isinstance( - dtype, torch.dtype - ): - # overload 2 with device and dtype - ... - case ((device, *_args), {**_kwargs}) if isinstance(device, torch.device | str): - # overload 2, only device - ... - case (_args, {'device': device, **_kwargs}) if isinstance(device, torch.device | str): - # device as kwarg - ... - - # The trajectory dtype will always be real, the data always be complex. - data = self.data.to(*_args, **{**_kwargs, 'dtype': dtype.to_complex(), 'device': device, 'copy': True}) - traj = self.traj.to(*_args, **{**_kwargs, 'dtype': dtype.to_real(), 'device': device}) - header = deepcopy(self.header) # TODO: use header.to - return type(self)(header=header, data=data, traj=traj) - - -class _KDataProtocol(Protocol): - """Protocol for KData used for type hinting in KData mixins. - - Note that the actual KData class can have more properties and methods than those defined here. - - If you want to use a property or method of KData in a new KDataMixin class, - you must add it to this Protocol to make sure that the type hinting works. - - For more information about Protocols see: - https://typing.readthedocs.io/en/latest/spec/protocol.html#protocols - """ - - @property - def header(self) -> KHeader: ... - - @property - def data(self) -> torch.Tensor: ... - - @property - def traj(self) -> KTrajectory: ... - - def __init__(self, header: KHeader, data: torch.Tensor, traj: KTrajectory): ... - - def _split_k2_or_k1_into_other( - self, split_idx: torch.Tensor, other_label: str, split_dir: str - ) -> _KDataProtocol: ... From 3077800f0420a4b9714138de1460cf5855498afa Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 13 May 2024 14:39:27 +0000 Subject: [PATCH 10/19] Update [ghstack-poisoned] --- src/mrpro/data/_kdata/_KData.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/mrpro/data/_kdata/_KData.py b/src/mrpro/data/_kdata/_KData.py index 364390afa..51bc6d92e 100644 --- a/src/mrpro/data/_kdata/_KData.py +++ b/src/mrpro/data/_kdata/_KData.py @@ -20,6 +20,7 @@ import datetime import warnings from pathlib import Path +from typing import Protocol import h5py import ismrmrd @@ -236,3 +237,31 @@ def sort_and_reshape_tensor_fields(dataclass: AcqInfo | AcqIdx): ) from None return cls(kheader, kdata, ktrajectory_final) + + +class _KDataProtocol(Protocol): + """Protocol for KData used for type hinting in KData mixins. + + Note that the actual KData class can have more properties and methods than those defined here. + + If you want to use a property or method of KData in a new KDataMixin class, + you must add it to this Protocol to make sure that the type hinting works. + + For more information about Protocols see: + https://typing.readthedocs.io/en/latest/spec/protocol.html#protocols + """ + + @property + def header(self) -> KHeader: ... + + @property + def data(self) -> torch.Tensor: ... + + @property + def traj(self) -> KTrajectory: ... + + def __init__(self, header: KHeader, data: torch.Tensor, traj: KTrajectory): ... + + def _split_k2_or_k1_into_other( + self, split_idx: torch.Tensor, other_label: str, split_dir: str + ) -> _KDataProtocol: ... From 4b281206c35c94c08d59a49949881f7c3c4b7ed0 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 13 May 2024 16:22:27 +0000 Subject: [PATCH 11/19] Update [ghstack-poisoned] --- src/mrpro/data/_AcqInfo.py | 2 +- src/mrpro/data/_MoveDataMixin.py | 50 +++++++++++-------- tests/data/test_kdata.py | 82 +++++++++++++++++++++++++++++++- 3 files changed, 112 insertions(+), 22 deletions(-) diff --git a/src/mrpro/data/_AcqInfo.py b/src/mrpro/data/_AcqInfo.py index a29e0c727..7caee31d6 100644 --- a/src/mrpro/data/_AcqInfo.py +++ b/src/mrpro/data/_AcqInfo.py @@ -51,7 +51,7 @@ class AcqIdx(MoveDataMixin): @dataclass(slots=True) -class AcqInfo: +class AcqInfo(MoveDataMixin): """Acquisition information for each readout.""" idx: AcqIdx diff --git a/src/mrpro/data/_MoveDataMixin.py b/src/mrpro/data/_MoveDataMixin.py index 22a71ab06..0804ebfb0 100644 --- a/src/mrpro/data/_MoveDataMixin.py +++ b/src/mrpro/data/_MoveDataMixin.py @@ -69,19 +69,23 @@ def to(self, *args, **kwargs) -> Self: If other conversions are desired, please use the torch.Tensor.to() method of the fields directly. """ - _args: Sequence[Any] = () - _kwargs: dict[str, Any] = {} - dtype = None - device = None + other_args: Sequence[Any] = () + other_kwargs: dict[str, Any] = {} + dtype: torch.dtype | None = None + device: torch.device | str | None = None # match dtype and device from args and kwargs match args, kwargs: - case ((dtype, *_args), {**_kwargs}) if isinstance(dtype, torch.dtype): + case ((_dtype, *_args), {**_kwargs}) if isinstance(_dtype, torch.dtype): # overload 1 - ... - case (_args, {'dtype': dtype, **_kwargs}) if isinstance(dtype, torch.dtype): + dtype = _dtype + other_args = _args + other_kwargs = _kwargs + case (_args, {'dtype': _dtype, **_kwargs}) if isinstance(_dtype, torch.dtype): # dtype as kwarg - ... + dtype = _dtype + other_args = _args + other_kwargs = _kwargs case ((other, *_args), {**_kwargs}) | (_args, {'other': other, **_kwargs}) if isinstance( other, torch.Tensor ): @@ -89,19 +93,26 @@ def to(self, *args, **kwargs) -> Self: dtype = other.dtype device = other.device match args, kwargs: - case ((device, dtype, *_args), {**_kwargs}) if isinstance(device, torch.device | str) and isinstance( - dtype, torch.dtype + case ((_device, _dtype, *_args), {**_kwargs}) if isinstance(_device, torch.device | str) and isinstance( + _dtype, torch.dtype ): # overload 2 with device and dtype - ... - case ((device, *_args), {**_kwargs}) if isinstance(device, torch.device | str): + dtype = _dtype + device = _device + other_args = _args + other_kwargs = _kwargs + case ((_device, *_args), {**_kwargs}) if isinstance(_device, torch.device | str): # overload 2, only device - ... - case (_args, {'device': device, **_kwargs}) if isinstance(device, torch.device | str): + device = _device + other_args = _args + other_kwargs = _kwargs + case (_args, {'device': _device, **_kwargs}) if isinstance(_device, torch.device | str): # device as kwarg - ... + device = _device + other_args = _args + other_kwargs = _kwargs - _kwargs['copy'] = True + other_kwargs['copy'] = True new_data: dict[str, Any] = {} for field in dataclasses.fields(self): name = field.name @@ -115,8 +126,9 @@ def to(self, *args, **kwargs) -> Self: elif data.dtype.is_complex: new_dtype = dtype.to_complex() else: - new_dtype = dtype - new_data[name] = data.to(new_device, new_dtype, *_args, **_kwargs) + # bool or int: keep as is + new_dtype = data.dtype + new_data[name] = data.to(new_device, new_dtype, *other_args, **other_kwargs) elif isinstance(data, MoveDataMixin): new_data[name] = data.to(*args, **kwargs) else: @@ -182,7 +194,7 @@ def device(self) -> torch.device | None: data = getattr(self, field.name) if not hasattr(data, 'device'): continue - current_device = data.getattr('device', None) + current_device = getattr(data, 'device', None) if current_device is None: continue if device is None: diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index 43e732684..ea9fadf72 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -153,16 +153,61 @@ def test_KData_modify_header(ismrmrd_cart, field, value): assert getattr(kdata.header, field) == value -def test_KData_to_complex128(ismrmrd_cart): - """Change KData dtype complex128.""" +def test_KData_to_float64tensor(ismrmrd_cart): + """Change KData dtype to double using other-tensor overload.""" + kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + kdata_float64 = kdata.to(torch.ones(1, dtype=torch.float64)) + assert kdata is not kdata_float64 + assert kdata_float64.data.dtype == torch.complex128 + torch.testing.assert_close(kdata_float64.data.to(dtype=torch.complex64), kdata.data) + + +@pytest.mark.cuda() +def test_KData_to_cudatensor(ismrmrd_cart): + """Move KData to cuda using other-tensor overload.""" + kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + kdata_cuda = kdata.to(torch.ones(1, device=torch.device('cuda'))) + assert kdata is not kdata_cuda + assert kdata_cuda.data.dtype == torch.complex64 + assert kdata_cuda.data.is_cuda + + +def test_Kdata_to_same(ismrmrd_cart): + """Call .to with no change in dtype or device.""" + kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + kdata2 = kdata.to() + assert kdata is not kdata2 + assert torch.equal(kdata.data, kdata2.data) + assert kdata2.data.dtype == kdata.data.dtype + assert kdata2.data.device == kdata.data.device + + +def test_KData_to_complex128_data(ismrmrd_cart): + """Change KData dtype complex128: test data.""" kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) kdata_complex128 = kdata.to(dtype=torch.complex128) + assert kdata is not kdata_complex128 assert kdata_complex128.data.dtype == torch.complex128 + torch.testing.assert_close(kdata_complex128.data.to(dtype=torch.complex64), kdata.data) + + +def test_KData_to_complex128_traj(ismrmrd_cart): + """Change KData dtype complex128: test trajectory.""" + kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + kdata_complex128 = kdata.to(dtype=torch.complex128) assert kdata_complex128.traj.kx.dtype == torch.float64 assert kdata_complex128.traj.ky.dtype == torch.float64 assert kdata_complex128.traj.kz.dtype == torch.float64 +def test_KData_to_complex128_header(ismrmrd_cart): + """Change KData dtype complex128: test header""" + kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + kdata_complex128 = kdata.to(dtype=torch.complex128) + assert kdata_complex128.header.acq_info.user_float.dtype == torch.float64 + assert kdata_complex128.header.acq_info.user_int.dtype == torch.int32 + + @pytest.mark.cuda() def test_KData_to_cuda(ismrmrd_cart): """Test KData.to to move data to CUDA memory.""" @@ -184,6 +229,9 @@ def test_KData_cuda(ismrmrd_cart): assert kdata_cuda.traj.kz.is_cuda assert kdata_cuda.traj.ky.is_cuda assert kdata_cuda.traj.kx.is_cuda + assert kdata_cuda.header.acq_info.user_int.is_cuda + assert kdata_cuda.device == torch.cuda.current_device() + assert kdata_cuda.header.acq_info.device == torch.cuda.current_device() @pytest.mark.cuda() @@ -195,6 +243,36 @@ def test_KData_cpu(ismrmrd_cart): assert kdata_cpu.traj.kz.is_cpu assert kdata_cpu.traj.ky.is_cpu assert kdata_cpu.traj.kx.is_cpu + assert kdata_cpu.header.acq_info.user_int.is_cpu + assert kdata_cpu.device == torch.device('cpu') + assert kdata_cpu.header.acq_info.device == torch.device('cpu') + + +def test_Kdata_device_cpu(ismrmrd_cart): + """Default device is CPU.""" + kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + assert kdata.device == torch.device('cpu') + + +@pytest.mark.cuda() +def test_KData_inconsistentdevice(ismrmrd_cart): + """Inconsistent device raises exception.""" + kdata_cpu = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + kdata_cuda = kdata_cpu.to(device='cuda') + kdata_mix = KData(data=kdata_cuda.data, header=kdata_cpu.header, traj=kdata_cpu.traj) + with pytest.raises(ValueError): + _ = kdata_mix.device + + +def test_KData_clone(ismrmrd_cart): + """Test .clone method.""" + kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + kdata2 = kdata.clone() + assert kdata is not kdata2 + assert kdata.data is not kdata2.data + assert torch.equal(kdata.data, kdata2.data) + assert kdata.traj.kx is not kdata2.traj.kx + assert torch.equal(kdata.traj.kx, kdata2.traj.kx) def test_KData_rearrange_k2_k1_into_k1(consistently_shaped_kdata): From 8459bc667cd0636dca2bab7d001ec6591f878d9f Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 14 May 2024 11:48:34 +0000 Subject: [PATCH 12/19] Update [ghstack-poisoned] --- src/mrpro/data/_KTrajectory.py | 10 ++++++--- src/mrpro/operators/_CartesianSamplingOp.py | 6 +++--- tests/data/test_trajectory.py | 23 +++++++++++++++++++-- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/mrpro/data/_KTrajectory.py b/src/mrpro/data/_KTrajectory.py index 27f837ab8..4f71b1862 100644 --- a/src/mrpro/data/_KTrajectory.py +++ b/src/mrpro/data/_KTrajectory.py @@ -55,11 +55,15 @@ class KTrajectory: def __post_init__(self) -> None: """Reduce repeated dimensions to singletons.""" + + def as_any_float(tensor: torch.Tensor) -> torch.Tensor: + return tensor.float() if not tensor.is_floating_point() else tensor + if self.repeat_detection_tolerance is not None: kz, ky, kx = ( - remove_repeat(tensor, self.repeat_detection_tolerance) for tensor in (self.kz, self.ky, self.kx) + as_any_float(remove_repeat(tensor, self.repeat_detection_tolerance)) + for tensor in (self.kz, self.ky, self.kx) ) - # use of setattr due to frozen dataclass object.__setattr__(self, 'kz', kz) object.__setattr__(self, 'ky', ky) @@ -156,7 +160,7 @@ def _traj_types( # We use the value of the enum-type to make it easier to do array operations. traj_type_matrix = torch.zeros(3, 3, dtype=torch.int) for ind, ks in enumerate((self.kz, self.ky, self.kx)): - values_on_grid = not ks.is_floating_point() or torch.all(ks.frac() <= tolerance) + values_on_grid = not ks.is_floating_point() or torch.all((ks - ks.round()).abs() <= tolerance) for dim in (-3, -2, -1): if ks.shape[dim] == 1: traj_type_matrix[ind, dim] |= TrajType.SINGLEVALUE.value | TrajType.ONGRID.value diff --git a/src/mrpro/operators/_CartesianSamplingOp.py b/src/mrpro/operators/_CartesianSamplingOp.py index 5e82a9c59..e993f4c20 100644 --- a/src/mrpro/operators/_CartesianSamplingOp.py +++ b/src/mrpro/operators/_CartesianSamplingOp.py @@ -58,19 +58,19 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> # in it and the shape of data will remain. # only dimensions on a cartesian grid will be reordered. if traj_type_kzyx[-1] == TrajType.ONGRID: # kx - kx_idx = ktraj_tensor[-1, ...].to(dtype=torch.int64) + sorted_grid_shape.x // 2 + kx_idx = ktraj_tensor[-1, ...].round().to(dtype=torch.int64) + sorted_grid_shape.x // 2 else: sorted_grid_shape.x = ktraj_tensor.shape[-1] kx_idx = rearrange(torch.arange(ktraj_tensor.shape[-1]), 'kx->1 1 1 kx') if traj_type_kzyx[-2] == TrajType.ONGRID: # ky - ky_idx = ktraj_tensor[-2, ...].to(dtype=torch.int64) + sorted_grid_shape.y // 2 + ky_idx = ktraj_tensor[-2, ...].round().to(dtype=torch.int64) + sorted_grid_shape.y // 2 else: sorted_grid_shape.y = ktraj_tensor.shape[-2] ky_idx = rearrange(torch.arange(ktraj_tensor.shape[-2]), 'ky->1 1 ky 1') if traj_type_kzyx[-3] == TrajType.ONGRID: # kz - kz_idx = ktraj_tensor[-3, ...].to(dtype=torch.int64) + sorted_grid_shape.z // 2 + kz_idx = ktraj_tensor[-3, ...].round().to(dtype=torch.int64) + sorted_grid_shape.z // 2 else: sorted_grid_shape.z = ktraj_tensor.shape[-3] kz_idx = rearrange(torch.arange(ktraj_tensor.shape[-3]), 'kz->1 kz 1 1') diff --git a/tests/data/test_trajectory.py b/tests/data/test_trajectory.py index 2eb5f954b..dc15480d8 100644 --- a/tests/data/test_trajectory.py +++ b/tests/data/test_trajectory.py @@ -88,7 +88,7 @@ def test_trajectory_tensor_conversion(cartesian_grid): n_k2 = 30 kz_full, ky_full, kx_full = cartesian_grid(n_k2, n_k1, n_k0, jitter=0.0) trajectory = KTrajectory(kz_full, ky_full, kx_full) - tensor = torch.stack((kz_full, ky_full, kx_full), dim=0) + tensor = torch.stack((kz_full, ky_full, kx_full), dim=0).to(torch.float32) tensor_from_traj = trajectory.as_tensor() # stack_dim=0 tensor_from_traj_dim2 = trajectory.as_tensor(stack_dim=2).moveaxis(2, 0) @@ -123,11 +123,30 @@ def test_trajectory_to_float64(cartesian_grid): n_k2 = 30 kz_full, ky_full, kx_full = cartesian_grid(n_k2, n_k1, n_k0, jitter=0.0) trajectory = KTrajectory(kz_full, ky_full, kx_full) - trajectory_float64 = trajectory.to(dtype=torch.float64) assert trajectory_float64.kz.dtype == torch.float64 assert trajectory_float64.ky.dtype == torch.float64 assert trajectory_float64.kx.dtype == torch.float64 + assert trajectory.kz.dtype == torch.float32 + assert trajectory.ky.dtype == torch.float32 + assert trajectory.kx.dtype == torch.float32 + + +@pytest.mark.parametrize('dtype', [torch.float32, torch.float64, torch.int32, torch.int64]) +def test_trajectory_floating_dtype(dtype): + """Test if the trajectory will always be converted to float""" + ks = torch.ones(3, 1, 1, 1, 1, dtype=dtype) + traj = KTrajectory(*ks) + if dtype.is_floating_point: + # keep as as + assert traj.kz.dtype == dtype + assert traj.ky.dtype == dtype + assert traj.kx.dtype == dtype + else: + # convert to float32 + assert traj.kz.dtype == torch.float32 + assert traj.ky.dtype == torch.float32 + assert traj.kx.dtype == torch.float32 @pytest.mark.cuda() From 59634e2e15afaeb7bee8850d1fd6324dde879000 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 14 May 2024 19:42:54 +0000 Subject: [PATCH 13/19] Update [ghstack-poisoned] --- src/mrpro/data/_MoveDataMixin.py | 37 ++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/mrpro/data/_MoveDataMixin.py b/src/mrpro/data/_MoveDataMixin.py index 53ee592ba..a8bdb7a14 100644 --- a/src/mrpro/data/_MoveDataMixin.py +++ b/src/mrpro/data/_MoveDataMixin.py @@ -229,3 +229,40 @@ def device(self) -> torch.device | None: def clone(self: Self) -> Self: """Return a deep copy of the object.""" return deepcopy(self) + + @property + def is_cuda(self) -> bool: + """Return True if all tensors are on a single CUDA device. + + Checks all tensor attributes of the dataclass for their device, + (recursively if an attribute is a MoveDataMixin) + + + Returns False if not all tensors are on the same CUDA devices, or if the device is inconsistent, + returns True if the data class has no tensors as attributes. + """ + try: + device = self.device + except InconsistentDeviceError: + return False + if device is None: + return True + return device.type == 'cuda' + + @property + def is_cpu(self) -> bool: + """Return True if all tensors are on the CPU. + + Checks all tensor attributes of the dataclass for their device, + (recursively if an attribute is a MoveDataMixin) + + Returns False if not all tensors are on cpu or if the device is inconsistent, + returns True if the data class has no tensors as attributes. + """ + try: + device = self.device + except InconsistentDeviceError: + return False + if device is None: + return True + return device.type == 'cpu' From 1c8e3985ddbb26543f413faab63ee23cf47ec8a4 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 14 May 2024 19:46:24 +0000 Subject: [PATCH 14/19] Update [ghstack-poisoned] --- tests/data/test_kdata.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index dd4872a3a..d0282cc57 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -232,6 +232,8 @@ def test_KData_cuda(ismrmrd_cart): assert kdata_cuda.header.acq_info.user_int.is_cuda assert kdata_cuda.device == torch.device(torch.cuda.current_device()) assert kdata_cuda.header.acq_info.device == torch.device(torch.cuda.current_device()) + assert kdata_cuda.is_cuda + assert not kdata_cuda.is_cpu @pytest.mark.cuda() @@ -252,6 +254,8 @@ def test_Kdata_device_cpu(ismrmrd_cart): """Default device is CPU.""" kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) assert kdata.device == torch.device('cpu') + assert not kdata.is_cuda + assert kdata.is_cpu @pytest.mark.cuda() @@ -260,6 +264,8 @@ def test_KData_inconsistentdevice(ismrmrd_cart): kdata_cpu = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) kdata_cuda = kdata_cpu.to(device='cuda') kdata_mix = KData(data=kdata_cuda.data, header=kdata_cpu.header, traj=kdata_cpu.traj) + assert not kdata_mix.is_cuda + assert not kdata_mix.is_cpu with pytest.raises(ValueError): _ = kdata_mix.device From d2e7576e2d029b2f9ba844967c48e6a8f16320d2 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Wed, 22 May 2024 22:05:05 +0000 Subject: [PATCH 15/19] Update [ghstack-poisoned] --- src/mrpro/data/_MoveDataMixin.py | 298 +++++++++++++++++++-------- tests/data/test_kdata.py | 12 +- tests/data/test_movedatamixin.py | 115 +++++++++++ tests/data/test_spatial_dimension.py | 34 +++ tests/data/test_trajectory.py | 10 + 5 files changed, 381 insertions(+), 88 deletions(-) create mode 100644 tests/data/test_movedatamixin.py diff --git a/src/mrpro/data/_MoveDataMixin.py b/src/mrpro/data/_MoveDataMixin.py index a8bdb7a14..53aead69a 100644 --- a/src/mrpro/data/_MoveDataMixin.py +++ b/src/mrpro/data/_MoveDataMixin.py @@ -17,14 +17,16 @@ from __future__ import annotations import dataclasses -from abc import ABC -from collections.abc import Sequence +from collections.abc import Iterator +from copy import copy as shallowcopy from copy import deepcopy from typing import Any from typing import ClassVar from typing import Protocol from typing import Self +from typing import TypeAlias from typing import overload +from typing import runtime_checkable import torch @@ -34,47 +36,59 @@ def __init__(self, *devices): super().__init__(f'Inconsistent devices found, found at least {", ".join(str(d) for d in devices)}') +@runtime_checkable class DataclassInstance(Protocol): """An instance of a dataclass.""" __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] -class MoveDataMixin(ABC, DataclassInstance): +class MoveDataMixin: """Move dataclass fields to cpu/gpu and convert dtypes.""" @overload def to( - self, dtype: torch.dtype, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None + self, + device: str | torch.device | int | None = None, + dtype: torch.dtype | None = None, + non_blocking: bool = False, + *, + copy: bool = False, + memory_format: torch.memory_format | None = None, ) -> Self: ... @overload def to( self, - device: str | torch.device | int | None = None, - dtype: torch.dtype | None = None, + dtype: torch.dtype, non_blocking: bool = False, *, + copy: bool = False, memory_format: torch.memory_format | None = None, ) -> Self: ... @overload def to( - self, other: torch.Tensor, non_blocking: bool = False, *, memory_format: torch.memory_format | None = None + self, + tensor: torch.Tensor, + non_blocking: bool = False, + *, + copy: bool = False, + memory_format: torch.memory_format | None = None, ) -> Self: ... def to(self, *args, **kwargs) -> Self: """Perform dtype and/or device conversion of data. - This will always return a new Data object with - all tensors copied, even if no conversion is necessary. - A torch.dtype and torch.device are inferred from the arguments of self.to(*args, **kwargs). Please have a look at the documentation of torch.Tensor.to() for more details. - The conversion will be applied to all Tensor fields of the dataclass, - and to all fields that implement the MoveDataMixin. + A new instance of the dataclass will be returned. + + The conversion will be applied to all Tensor- or Module + fields of the dataclass, and to all fields that implement + the MoveDataMixin. The dtype-type, i.e. float/complex will always be preserved, but the precision of floating point dtypes might be changed. @@ -88,83 +102,145 @@ def to(self, *args, **kwargs) -> Self: If other conversions are desired, please use the torch.Tensor.to() method of the fields directly. + + If the copy argument is set to True (default), a deep copy will be returned + even if no conversion is necessary. + If two fields are views of the same data before, in the result they will be independent + copies if copy is set to True or a conversion is necessary. + If set to False, some Tensors might be shared between the original and the new object. """ - other_args: Sequence[Any] = () - other_kwargs: dict[str, Any] = {} - dtype: torch.dtype | None = None - device: torch.device | str | int | None = None - - # match dtype and device from args and kwargs - match args, kwargs: - case ((_dtype, *_args), {**_kwargs}) if isinstance(_dtype, torch.dtype): - # overload 1 - dtype = _dtype - other_args = _args - other_kwargs = _kwargs - case (_args, {'dtype': _dtype, **_kwargs}) if isinstance(_dtype, torch.dtype): - # dtype as kwarg - dtype = _dtype - other_args = _args - other_kwargs = _kwargs - case ((other, *_args), {**_kwargs}) | (_args, {'other': other, **_kwargs}) if isinstance( - other, torch.Tensor - ): - # overload 3: use dtype and device from other - dtype = other.dtype - device = other.device - match args, kwargs: - case ((_device, _dtype, *_args), {**_kwargs}) if isinstance( - _device, torch.device | str | int | None - ) and isinstance(_dtype, torch.dtype): - # overload 2 with device and dtype - dtype = _dtype - device = _device - other_args = _args - other_kwargs = _kwargs - case ((_device, *_args), {**_kwargs}) if isinstance(_device, torch.device | str | int | None): - # overload 2, only device - device = _device - other_args = _args - other_kwargs = _kwargs - case (_args, {'device': _device, **_kwargs}) if isinstance(_device, torch.device | str | int | None): - # device as kwarg - device = _device - other_args = _args - other_kwargs = _kwargs - - other_kwargs['copy'] = True - new_data: dict[str, Any] = {} - for field in dataclasses.fields(self): - name = field.name - data = getattr(self, name) + # Parse the arguments of the three overloads and call _to with the parsed arguments + parsedType: TypeAlias = tuple[ + str | torch.device | int | None, torch.dtype | None, bool, bool, torch.memory_format + ] + + def parse3( + other: torch.Tensor, + non_blocking: bool = False, + copy: bool = False, + ) -> parsedType: + return other.device, other.dtype, non_blocking, copy, torch.preserve_format + + def parse1( + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: torch.memory_format = torch.preserve_format, + ) -> parsedType: + return None, dtype, non_blocking, copy, memory_format + + def parse2( + device: str | torch.device | int | None = None, + dtype: None | torch.dtype = None, + non_blocking: bool = False, + copy: bool = False, + memory_format: torch.memory_format = torch.preserve_format, + ) -> parsedType: + return device, dtype, non_blocking, copy, memory_format + + if args and isinstance(args[0], torch.Tensor) or 'other' in kwargs: + # overload 3 + device, dtype, non_blocking, copy, memory_format = parse3(*args, **kwargs) + elif args and isinstance(args[0], torch.dtype): + # overload 2 + device, dtype, non_blocking, copy, memory_format = parse1(*args, **kwargs) + else: + # overload 1 + device, dtype, non_blocking, copy, memory_format = parse2(*args, **kwargs) + return self._to(device=device, dtype=dtype, non_blocking=non_blocking, memory_format=memory_format, copy=copy) + + def _items(self) -> Iterator[tuple[str, Any]]: + if isinstance(self, DataclassInstance): + for field in dataclasses.fields(self): + name = field.name + data = getattr(self, name) + yield name, data + if isinstance(self, torch.nn.Module): + yield from self._parameters.items() + yield from self._buffers.items() + yield from self._modules.items() + + def _to( + self, + device: torch.device | str | int | None = None, + dtype: torch.dtype | None = None, + non_blocking: bool = False, + memory_format: torch.memory_format = torch.preserve_format, + shared_memory: bool = False, + copy: bool = False, + memo: dict | None = None, + ) -> Self: + new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self + + if memo is None: + memo = {} + + def _tensor_to(data: torch.Tensor) -> torch.Tensor: + """Move tensor to device and convert dtype if necessary.""" + new_dtype: torch.dtype | None + if dtype is not None and data.dtype.is_floating_point: + new_dtype = dtype.to_real() + elif dtype is not None and data.dtype.is_complex: + new_dtype = dtype.to_complex() + else: + # bool or int: keep as is + new_dtype = None + data = data.to( + device, + new_dtype, + non_blocking=non_blocking, + memory_format=memory_format, + copy=copy, + ) + if shared_memory: + data.share_memory_() + return data + + def _module_to(data: torch.nn.Module) -> torch.nn.Module: + if copy: + data = deepcopy(data) + return data._apply(_tensor_to, recurse=True) + + def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin: + return obj._to( + device=device, + dtype=dtype, + non_blocking=non_blocking, + memory_format=memory_format, + shared_memory=shared_memory, + copy=copy, + memo=memo, + ) + + converted: Any + for name, data in new._items(): + if id(data) in memo: + object.__setattr__(new, name, memo[id(data)]) + continue if isinstance(data, torch.Tensor): - new_device = data.device if device is None else device - if dtype is None: - new_dtype = data.dtype - elif data.dtype.is_floating_point: - new_dtype = dtype.to_real() - elif data.dtype.is_complex: - new_dtype = dtype.to_complex() - else: - # bool or int: keep as is - new_dtype = data.dtype - new_data[name] = data.to(new_device, new_dtype, *other_args, **other_kwargs) + converted = _tensor_to(data) elif isinstance(data, MoveDataMixin): - new_data[name] = data.to(*args, **kwargs) + converted = _mixin_to(data) + elif isinstance(data, torch.nn.Module): + converted = _module_to(data) + elif copy: + converted = deepcopy(data) else: - new_data[name] = deepcopy(data) - return type(self)(**new_data) + converted = data + memo[id(data)] = converted + # this works even if new is frozen + object.__setattr__(new, name, converted) + return new def cuda( self, device: torch.device | str | int | None = None, + *, non_blocking: bool = False, memory_format: torch.memory_format = torch.preserve_format, + copy: bool = False, ) -> Self: - """Create copy of object with data in CUDA memory. - - This will always return a copy. - + """Put object in CUDA memory. Parameters ---------- @@ -175,23 +251,74 @@ def cuda( Otherwise, the argument has no effect. memory_format The desired memory format of returned tensor. + copy: + If True, the returned tensor will always be a copy, even if the input was already on the correct device. + This will also create new tensors for views """ if device is None: device = torch.device(torch.cuda.current_device()) - return self.to(device=device, memory_format=memory_format, non_blocking=non_blocking) + return self._to(device=device, dtype=None, memory_format=memory_format, non_blocking=non_blocking, copy=copy) + + def cpu(self, *, memory_format: torch.memory_format = torch.preserve_format, copy: bool = False) -> Self: + """Put in CPU memory. + + Parameters + ---------- + memory_format + The desired memory format of returned tensor. + copy + If True, the returned tensor will always be a copy, even if the input was already on the correct device. + This will also create new tensors for views + """ + return self._to(device='cpu', dtype=None, non_blocking=True, memory_format=memory_format, copy=copy) + + def double(self, *, memory_format: torch.memory_format = torch.preserve_format, copy: bool = False) -> Self: + """Convert all float tensors to double precision. + + converts float to float64 and complex to complex128 + + + Parameters + ---------- + memory_format + The desired memory format of returned tensor. + copy + If True, the returned tensor will always be a copy, even if the input was already on the correct device. + This will also create new tensors for views + """ + return self._to(dtype=torch.float64, memory_format=memory_format, copy=copy) + + def half(self, *, memory_format: torch.memory_format = torch.preserve_format, copy: bool = False) -> Self: + """Convert all float tensors to haf precision. + + converts float to float16 and complex to complex32 + + + Parameters + ---------- + memory_format + The desired memory format of returned tensor. + copy + If True, the returned tensor will always be a copy, even if the input was already on the correct device. + This will also create new tensors for views + """ + return self._to(dtype=torch.float16, memory_format=memory_format, copy=copy) - def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Self: - """Create copy of object in CPU memory. + def single(self, *, memory_format: torch.memory_format = torch.preserve_format, copy: bool = False) -> Self: + """Convert all float tensors to single precision. - This will always return a copy. + converts float to float32 and complex to complex64 Parameters ---------- memory_format The desired memory format of returned tensor. + copy + If True, the returned tensor will always be a copy, even if the input was already on the correct device. + This will also create new tensors for views """ - return self.to(device='cpu', memory_format=memory_format) + return self._to(dtype=torch.float32, memory_format=memory_format, copy=copy) @property def device(self) -> torch.device | None: @@ -213,8 +340,7 @@ def device(self) -> torch.device | None: The device of the fields or None if no field implements a device attribute. """ device: None | torch.device = None - for field in dataclasses.fields(self): - data = getattr(self, field.name) + for _, data in self._items(): if not hasattr(data, 'device'): continue current_device = getattr(data, 'device', None) @@ -228,7 +354,7 @@ def device(self) -> torch.device | None: def clone(self: Self) -> Self: """Return a deep copy of the object.""" - return deepcopy(self) + return self._to(device=None, dtype=None, non_blocking=False, memory_format=torch.preserve_format, copy=True) @property def is_cuda(self) -> bool: diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index d0282cc57..25c7b2c26 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -172,16 +172,24 @@ def test_KData_to_cudatensor(ismrmrd_cart): assert kdata_cuda.data.is_cuda -def test_Kdata_to_same(ismrmrd_cart): +def test_Kdata_to_same_copy(ismrmrd_cart): """Call .to with no change in dtype or device.""" kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) - kdata2 = kdata.to() + kdata2 = kdata.to(copy=True) assert kdata is not kdata2 assert torch.equal(kdata.data, kdata2.data) assert kdata2.data.dtype == kdata.data.dtype assert kdata2.data.device == kdata.data.device +def test_Kdata_to_same_nocopy(ismrmrd_cart): + """Call .to with no change in dtype or device.""" + kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + kdata2 = kdata.to(copy=False) + assert kdata is not kdata2 + assert kdata.data is kdata2.data + + def test_KData_to_complex128_data(ismrmrd_cart): """Change KData dtype complex128: test data.""" kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py new file mode 100644 index 000000000..8d24d06c4 --- /dev/null +++ b/tests/data/test_movedatamixin.py @@ -0,0 +1,115 @@ +"""Tests the MoveDataMixin class.""" + +# Copyright 2024 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from dataclasses import field +from typing import Any + +import pytest +import torch +from mrpro.data import MoveDataMixin + + +class SharedModule(torch.nn.Module): + """A module with two submodules that share the same parameters.""" + + def __init__(self): + super().__init__() + self.module1 = torch.nn.Linear(1, 1) + self.module2 = torch.nn.Linear(1, 1) + self.module2.weight = self.module1.weight + + +@dataclass(slots=True) +class A(MoveDataMixin): + """Test class A.""" + + floattensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0)) + complextensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.complex64)) + inttensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1, dtype=torch.int32)) + booltensor: torch.Tensor = field(default_factory=lambda: torch.tensor(True)) + module: torch.nn.Module = field(default_factory=lambda: torch.nn.Linear(1, 1)) + + +@dataclass(frozen=True) +class B(MoveDataMixin): + """Test class B.""" + + child: A = field(default_factory=A) + module: torch.nn.Module = field(default_factory=SharedModule) + floattensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0)) + complextensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.complex64)) + inttensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1, dtype=torch.int32)) + booltensor: torch.Tensor = field(default_factory=lambda: torch.tensor(True)) + + +def _test( + original: Any, + new: Any, + attribute: str, + copy: bool, + expected_dtype: torch.dtype, + expected_device: torch.device, +) -> None: + """Assertion used in the tests below. + + Compares the attribute of the original and new object. + Checks device, dtype and if the data is copied if required + on the new object. + """ + original_data = getattr(original, attribute) + new_data = getattr(new, attribute) + if copy: + assert new_data is not original_data, 'copy requested but not performed' + assert torch.equal(new_data, original_data.to(device=expected_device, dtype=expected_dtype)) + assert new_data.device == expected_device, 'device not set correctly' + assert new_data.dtype == expected_dtype, 'dtype not set correctly' + + +@pytest.mark.parametrize('dtype', [torch.float64, torch.complex128]) +@pytest.mark.parametrize('copy', [True, False]) +def test_movedatamixin_float64like(copy: bool, dtype: torch.dtype): + original = B() + new = original.to(dtype=dtype, copy=copy) + + # Tensor attributes + def test(attribute, expected_dtype): + return _test(original, new, attribute, copy, expected_dtype, torch.device('cpu')) + + test('floattensor', torch.float64) + test('complextensor', torch.complex128) + test('inttensor', torch.int32) + test('booltensor', torch.bool) + + # Attributes of child + def testchild(attribute, expected_dtype): + return _test(original.child, new.child, attribute, copy, expected_dtype, torch.device('cpu')) + + testchild('floattensor', torch.float64) + testchild('complextensor', torch.complex128) + testchild('inttensor', torch.int32) + testchild('booltensor', torch.bool) + + # Module attribute + _test(original.child.module, new.child.module, 'weight', copy, torch.float64, torch.device('cpu')) + + # No-copy required for these + if not copy: + assert original.inttensor is new.inttensor, 'no copy of inttensor required' + assert original.booltensor is new.booltensor, 'no copy of booltensor required' + assert original.child.inttensor is new.child.inttensor, 'no copy of inttensor required' + assert original.child.booltensor is new.child.booltensor, 'no copy of booltensor required' + assert original is not new, 'original and new should not be the same object' + + assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared' diff --git a/tests/data/test_spatial_dimension.py b/tests/data/test_spatial_dimension.py index d5de1833c..12c716565 100644 --- a/tests/data/test_spatial_dimension.py +++ b/tests/data/test_spatial_dimension.py @@ -97,3 +97,37 @@ def test_spatial_dimension_zyx(): spatial_dimension = SpatialDimension(z=z, y=y, x=x) assert isinstance(spatial_dimension.zyx, tuple) assert spatial_dimension.zyx == (z, y, x) + + +@pytest.mark.cuda() +def test_spatial_dimension_cuda_tensor(): + """Test moving to CUDA""" + spatial_dimension = SpatialDimension(z=torch.ones(1), y=torch.ones(1), x=torch.ones(1)) + spatial_dimension_cuda = spatial_dimension.cuda() + assert spatial_dimension_cuda.z.is_cuda + assert spatial_dimension_cuda.y.is_cuda + assert spatial_dimension_cuda.x.is_cuda + assert spatial_dimension.z.is_cpu + assert spatial_dimension.y.is_cpu + assert spatial_dimension.x.is_cpu + assert spatial_dimension_cuda.is_cuda + assert spatial_dimension.is_cpu + assert not spatial_dimension_cuda.is_cpu + assert not spatial_dimension.is_cuda + + +def test_spatial_dimension_cuda_float(): + """Test moving to CUDA without tensors -> copy only""" + spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) + # the device number should not matter, has there is no + # data to move to the device + spatial_dimension_cuda = spatial_dimension.cuda(42) + # if a dataclass has no tensors, it is both on CPU and CUDA + # and the device is None + assert spatial_dimension_cuda.is_cuda + assert spatial_dimension.is_cpu + assert spatial_dimension_cuda.is_cpu + assert spatial_dimension.is_cuda + assert spatial_dimension.device is None + assert spatial_dimension_cuda.device is None + assert spatial_dimension_cuda is not spatial_dimension diff --git a/tests/data/test_trajectory.py b/tests/data/test_trajectory.py index dc15480d8..5dd2c79e5 100644 --- a/tests/data/test_trajectory.py +++ b/tests/data/test_trajectory.py @@ -163,6 +163,16 @@ def test_trajectory_cuda(cartesian_grid): assert trajectory_cuda.ky.is_cuda assert trajectory_cuda.kx.is_cuda + assert trajectory.kz.is_cpu + assert trajectory.ky.is_cpu + assert trajectory.kx.is_cpu + + assert trajectory_cuda.is_cuda + assert trajectory.is_cpu + + assert not trajectory_cuda.is_cpu + assert not trajectory.is_cuda + @pytest.mark.cuda() def test_trajectory_cpu(cartesian_grid): From 1041ca7405c91a346b5fc441058dfc38ea807496 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Wed, 22 May 2024 22:12:12 +0000 Subject: [PATCH 16/19] Update [ghstack-poisoned] --- pyproject.toml | 1 + src/mrpro/phantoms/_EllipsePhantom.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8db6ac680..3f8bee29c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,7 @@ locale = "en-us" [tool.typos.default.extend-words] Reson = "Reson" # required for Proc. Intl. Soc. Mag. Reson. Med. +iy = "iy" [tool.typos.files] extend-exclude = ["examples/*.ipynb"] # don't check notebooks because py files have already been checked diff --git a/src/mrpro/phantoms/_EllipsePhantom.py b/src/mrpro/phantoms/_EllipsePhantom.py index 167e5142b..bb9004091 100644 --- a/src/mrpro/phantoms/_EllipsePhantom.py +++ b/src/mrpro/phantoms/_EllipsePhantom.py @@ -91,7 +91,7 @@ def image_space(self, image_dimensions: SpatialDimension[int]) -> torch.Tensor: """ # Calculate image representation of phantom ny, nx = image_dimensions.y, image_dimensions.x - ix, it = torch.meshgrid( + ix, iy = torch.meshgrid( torch.linspace(-nx // 2, nx // 2 - 1, nx), torch.linspace(-ny // 2, ny // 2 - 1, ny), indexing='xy', @@ -101,7 +101,7 @@ def image_space(self, image_dimensions: SpatialDimension[int]) -> torch.Tensor: for ellipse in self.ellipses: in_ellipse = ( (ix / nx - ellipse.center_x) ** 2 / ellipse.radius_x**2 - + (it / ny - ellipse.center_y) ** 2 / ellipse.radius_y**2 + + (iy / ny - ellipse.center_y) ** 2 / ellipse.radius_y**2 ) <= 1 idata += ellipse.intensity * in_ellipse From 4bfcad93cb0e05620600297a32332ab2dcaec914 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 25 May 2024 01:02:20 +0000 Subject: [PATCH 17/19] Update [ghstack-poisoned] --- src/mrpro/data/_MoveDataMixin.py | 45 +++++++++++++++++++++++++------- tests/data/test_movedatamixin.py | 13 ++++----- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/src/mrpro/data/_MoveDataMixin.py b/src/mrpro/data/_MoveDataMixin.py index 53aead69a..a4fac21ae 100644 --- a/src/mrpro/data/_MoveDataMixin.py +++ b/src/mrpro/data/_MoveDataMixin.py @@ -121,7 +121,7 @@ def parse3( ) -> parsedType: return other.device, other.dtype, non_blocking, copy, torch.preserve_format - def parse1( + def parse2( dtype: torch.dtype, non_blocking: bool = False, copy: bool = False, @@ -129,7 +129,7 @@ def parse1( ) -> parsedType: return None, dtype, non_blocking, copy, memory_format - def parse2( + def parse1( device: str | torch.device | int | None = None, dtype: None | torch.dtype = None, non_blocking: bool = False, @@ -138,18 +138,19 @@ def parse2( ) -> parsedType: return device, dtype, non_blocking, copy, memory_format - if args and isinstance(args[0], torch.Tensor) or 'other' in kwargs: - # overload 3 + if args and isinstance(args[0], torch.Tensor) or 'tensor' in kwargs: + # overload 3 ("tensor" specifies the dtype and device) device, dtype, non_blocking, copy, memory_format = parse3(*args, **kwargs) elif args and isinstance(args[0], torch.dtype): - # overload 2 - device, dtype, non_blocking, copy, memory_format = parse1(*args, **kwargs) - else: - # overload 1 + # overload 2 (no device specified, only dtype) device, dtype, non_blocking, copy, memory_format = parse2(*args, **kwargs) + else: + # overload 1 (device and dtype specified) + device, dtype, non_blocking, copy, memory_format = parse1(*args, **kwargs) return self._to(device=device, dtype=dtype, non_blocking=non_blocking, memory_format=memory_format, copy=copy) def _items(self) -> Iterator[tuple[str, Any]]: + """Return an iterator over fields, parameters, buffers, and modules of the object.""" if isinstance(self, DataclassInstance): for field in dataclasses.fields(self): name = field.name @@ -171,7 +172,33 @@ def _to( memo: dict | None = None, ) -> Self: new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self + """Move data to device and convert dtype if necessary. + + This method is called by .to(), .cuda(), .cpu(), .double(), and so on. + It should not be called directly. + See .to() for more details. + + Parameters + ---------- + device + The destination device. + dtype + The destination dtype. + non_blocking + If True and the source is in pinned memory, the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. + memory_format + The desired memory format of returned tensor. + shared_memory + If True and the target device is CPU, the tensors will reside in shared memory. + Otherwise, the argument has no effect. + copy + If True, the returned tensor will always be a copy, even if the input was already on the correct device. + This will also create new tensors for views + memo + A dictionary to keep track of already converted objects to avoid multiple conversions. + """ if memo is None: memo = {} @@ -289,7 +316,7 @@ def double(self, *, memory_format: torch.memory_format = torch.preserve_format, return self._to(dtype=torch.float64, memory_format=memory_format, copy=copy) def half(self, *, memory_format: torch.memory_format = torch.preserve_format, copy: bool = False) -> Self: - """Convert all float tensors to haf precision. + """Convert all float tensors to half precision. converts float to float16 and complex to complex32 diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py index 8d24d06c4..733f38f88 100644 --- a/tests/data/test_movedatamixin.py +++ b/tests/data/test_movedatamixin.py @@ -77,9 +77,10 @@ def _test( assert new_data.dtype == expected_dtype, 'dtype not set correctly' -@pytest.mark.parametrize('dtype', [torch.float64, torch.complex128]) +@pytest.mark.parametrize('dtype', [torch.float32, torch.complex64, torch.float64, torch.complex128]) @pytest.mark.parametrize('copy', [True, False]) -def test_movedatamixin_float64like(copy: bool, dtype: torch.dtype): +def test_movedatamixin_to(copy: bool, dtype: torch.dtype): + """Test MoveDataMixin.to using a nested object.""" original = B() new = original.to(dtype=dtype, copy=copy) @@ -87,8 +88,8 @@ def test_movedatamixin_float64like(copy: bool, dtype: torch.dtype): def test(attribute, expected_dtype): return _test(original, new, attribute, copy, expected_dtype, torch.device('cpu')) - test('floattensor', torch.float64) - test('complextensor', torch.complex128) + test('floattensor', dtype.to_real) + test('complextensor', dtype.to_complex) test('inttensor', torch.int32) test('booltensor', torch.bool) @@ -96,8 +97,8 @@ def test(attribute, expected_dtype): def testchild(attribute, expected_dtype): return _test(original.child, new.child, attribute, copy, expected_dtype, torch.device('cpu')) - testchild('floattensor', torch.float64) - testchild('complextensor', torch.complex128) + testchild('floattensor', dtype.to_real) + testchild('complextensor', dtype.to_complex) testchild('inttensor', torch.int32) testchild('booltensor', torch.bool) From 1374cb2d4f26807dec78dd0de3ce83a7417aca42 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 27 May 2024 17:22:07 +0000 Subject: [PATCH 18/19] Update [ghstack-poisoned] --- tests/data/test_movedatamixin.py | 110 +++++++++++++++++++++++++++++-- 1 file changed, 104 insertions(+), 6 deletions(-) diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py index 733f38f88..128e3a9cf 100644 --- a/tests/data/test_movedatamixin.py +++ b/tests/data/test_movedatamixin.py @@ -52,6 +52,7 @@ class B(MoveDataMixin): complextensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.complex64)) inttensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1, dtype=torch.int32)) booltensor: torch.Tensor = field(default_factory=lambda: torch.tensor(True)) + doubletensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.float64)) def _test( @@ -72,7 +73,14 @@ def _test( new_data = getattr(new, attribute) if copy: assert new_data is not original_data, 'copy requested but not performed' - assert torch.equal(new_data, original_data.to(device=expected_device, dtype=expected_dtype)) + if torch.is_complex(original_data): + # torch.equal not yet implemented for complex half tensors. + assert torch.equal( + torch.view_as_real(new_data), + torch.view_as_real(original_data.to(device=expected_device, dtype=expected_dtype)), + ) + else: + assert torch.equal(new_data, original_data.to(device=expected_device, dtype=expected_dtype)) assert new_data.device == expected_device, 'device not set correctly' assert new_data.dtype == expected_dtype, 'dtype not set correctly' @@ -88,8 +96,9 @@ def test_movedatamixin_to(copy: bool, dtype: torch.dtype): def test(attribute, expected_dtype): return _test(original, new, attribute, copy, expected_dtype, torch.device('cpu')) - test('floattensor', dtype.to_real) - test('complextensor', dtype.to_complex) + test('floattensor', dtype.to_real()) + test('doubletensor', dtype.to_real()) + test('complextensor', dtype.to_complex()) test('inttensor', torch.int32) test('booltensor', torch.bool) @@ -97,13 +106,102 @@ def test(attribute, expected_dtype): def testchild(attribute, expected_dtype): return _test(original.child, new.child, attribute, copy, expected_dtype, torch.device('cpu')) - testchild('floattensor', dtype.to_real) - testchild('complextensor', dtype.to_complex) + testchild('floattensor', dtype.to_real()) + testchild('complextensor', dtype.to_complex()) testchild('inttensor', torch.int32) testchild('booltensor', torch.bool) # Module attribute - _test(original.child.module, new.child.module, 'weight', copy, torch.float64, torch.device('cpu')) + _test(original.child.module, new.child.module, 'weight', copy, dtype.to_real(), torch.device('cpu')) + + # No-copy required for these + if not copy: + assert original.inttensor is new.inttensor, 'no copy of inttensor required' + assert original.booltensor is new.booltensor, 'no copy of booltensor required' + assert original.child.inttensor is new.child.inttensor, 'no copy of inttensor required' + assert original.child.booltensor is new.child.booltensor, 'no copy of booltensor required' + assert original is not new, 'original and new should not be the same object' + + assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared' + + +@pytest.mark.filterwarnings('ignore:ComplexHalf:UserWarning') +@pytest.mark.parametrize( + ('dtype', 'attribute'), [(torch.float16, 'half'), (torch.float32, 'single'), (torch.float64, 'double')] +) +@pytest.mark.parametrize('copy', [True, False]) +def test_movedatamixin_convert(copy: bool, dtype: torch.dtype, attribute: str): + """Test MoveDataMixin.half/double/single using a nested object.""" + original = B() + new = getattr(original, attribute)(copy=copy) + + # Tensor attributes + def test(attribute, expected_dtype): + return _test(original, new, attribute, copy, expected_dtype, torch.device('cpu')) + + test('floattensor', dtype.to_real()) + test('doubletensor', dtype.to_real()) + test('complextensor', dtype.to_complex()) + test('inttensor', torch.int32) + test('booltensor', torch.bool) + + # Attributes of child + def testchild(attribute, expected_dtype): + return _test(original.child, new.child, attribute, copy, expected_dtype, torch.device('cpu')) + + testchild('floattensor', dtype.to_real()) + testchild('complextensor', dtype.to_complex()) + testchild('inttensor', torch.int32) + testchild('booltensor', torch.bool) + + # Module attribute + _test(original.child.module, new.child.module, 'weight', copy, dtype.to_real(), torch.device('cpu')) + + # No-copy required for these + if not copy: + assert original.inttensor is new.inttensor, 'no copy of inttensor required' + assert original.booltensor is new.booltensor, 'no copy of booltensor required' + assert original.child.inttensor is new.child.inttensor, 'no copy of inttensor required' + assert original.child.booltensor is new.child.booltensor, 'no copy of booltensor required' + assert original is not new, 'original and new should not be the same object' + + assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared' + + +@pytest.mark.cuda() +@pytest.mark.parametrize('already_moved', [True, False]) +@pytest.mark.parametrize('copy', [True, False]) +def test_movedatamixin_cuda(already_moved: bool, copy: bool): + """Test MoveDataMixin.cuda using a nested object.""" + original = B() + if already_moved: + original = original.cuda(torch.cuda.current_device()) + new = original.cuda(copy=copy) + expected_device = torch.device(torch.cuda.current_device()) + + # Tensor attributes + def test(attribute, expected_dtype): + return _test(original, new, attribute, copy, expected_dtype, expected_device) + + # all tensors should be of the same dtype as before + test('floattensor', torch.float32) + test('doubletensor', torch.float64) + + test('complextensor', torch.complex64) + test('inttensor', torch.int32) + test('booltensor', torch.bool) + + # Attributes of child + def testchild(attribute, expected_dtype): + return _test(original.child, new.child, attribute, copy, expected_dtype, expected_device) + + testchild('floattensor', torch.float32) + testchild('complextensor', torch.complex64) + testchild('inttensor', torch.int32) + testchild('booltensor', torch.bool) + + # Module attribute + _test(original.child.module, new.child.module, 'weight', copy, torch.float32, torch.device('cpu')) # No-copy required for these if not copy: From 4d85384ec7ec104f0f25f5022c2ce758f44f090e Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 27 May 2024 18:35:36 +0000 Subject: [PATCH 19/19] Update [ghstack-poisoned] --- tests/data/test_movedatamixin.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py index 128e3a9cf..118f08576 100644 --- a/tests/data/test_movedatamixin.py +++ b/tests/data/test_movedatamixin.py @@ -73,6 +73,9 @@ def _test( new_data = getattr(new, attribute) if copy: assert new_data is not original_data, 'copy requested but not performed' + + assert new_data.device == expected_device, 'device not set correctly' + assert new_data.dtype == expected_dtype, 'dtype not set correctly' if torch.is_complex(original_data): # torch.equal not yet implemented for complex half tensors. assert torch.equal( @@ -81,8 +84,6 @@ def _test( ) else: assert torch.equal(new_data, original_data.to(device=expected_device, dtype=expected_dtype)) - assert new_data.device == expected_device, 'device not set correctly' - assert new_data.dtype == expected_dtype, 'dtype not set correctly' @pytest.mark.parametrize('dtype', [torch.float32, torch.complex64, torch.float64, torch.complex128]) @@ -178,6 +179,7 @@ def test_movedatamixin_cuda(already_moved: bool, copy: bool): original = original.cuda(torch.cuda.current_device()) new = original.cuda(copy=copy) expected_device = torch.device(torch.cuda.current_device()) + assert new.device == expected_device # Tensor attributes def test(attribute, expected_dtype): @@ -201,12 +203,15 @@ def testchild(attribute, expected_dtype): testchild('booltensor', torch.bool) # Module attribute - _test(original.child.module, new.child.module, 'weight', copy, torch.float32, torch.device('cpu')) + _test(original.child.module, new.child.module, 'weight', copy, torch.float32, expected_device) # No-copy required for these - if not copy: + if not copy and already_moved: assert original.inttensor is new.inttensor, 'no copy of inttensor required' assert original.booltensor is new.booltensor, 'no copy of booltensor required' + assert original.floattensor is new.floattensor, 'no copy of floattensor required' + assert original.doubletensor is new.doubletensor, 'no copy of doubletensor required' + assert original.child.complextensor is new.child.complextensor, 'no copy of complextensor required' assert original.child.inttensor is new.child.inttensor, 'no copy of inttensor required' assert original.child.booltensor is new.child.booltensor, 'no copy of booltensor required' assert original is not new, 'original and new should not be the same object'