Skip to content

Commit

Permalink
Update pytests
Browse files Browse the repository at this point in the history
  • Loading branch information
mrdkucher committed Aug 25, 2023
1 parent 51d4c3f commit 2ad1090
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 97 deletions.
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from typing import Any, Dict, Optional, Union, Tuple
from typing import Any
from typing import Dict
from typing import Optional

import numpy as np
import SimpleITK as sitk
import torch

from .random_affine import Affine, RandomAffine
from .random_elastic_deformation import ElasticDeformation, RandomElasticDeformation
from .random_affine import Affine
from .random_affine import RandomAffine
from .random_elastic_deformation import ElasticDeformation
from .random_elastic_deformation import RandomElasticDeformation
from .. import RandomTransform
from ... import SpatialTransform
from ....constants import INTENSITY
from ....constants import TYPE
from ....data.io import nib_to_sitk
from ....data.subject import Subject
from ....typing import TypeRangeFloat
from ....typing import TypeSextetFloat
from ....typing import TypeTripletFloat

TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]


class RandomCombinedAffineElasticDeformation(RandomTransform, SpatialTransform):
Expand Down Expand Up @@ -90,12 +89,29 @@ def get_params(self):
def apply_transform(self, subject: Subject):
affine_params, elastic_params = self.get_params()

scaling_params, rotation_params, translation_params = affine_params
affine_params = {
'scales': scaling_params.tolist(),
'degrees': rotation_params.tolist(),
'translation': translation_params.tolist(),
'center': self.random_affine.center,
'default_pad_value': self.random_affine.default_pad_value,
'image_interpolation': self.random_affine.image_interpolation,
'label_interpolation': self.random_affine.label_interpolation,
'check_shape': self.random_affine.check_shape,
}

elastic_params = {
'control_points': elastic_params,
'max_displacement': self.random_elastic.max_displacement,
'image_interpolation': self.random_elastic.image_interpolation,
'label_interpolation': self.random_elastic.label_interpolation,
}

arguments = {
'affine_first': self.affine_first,
'affine_params': affine_params,
'elastic_params': elastic_params,
'affine_kwargs': self.affine_kwargs,
'elastic_kwargs': self.elastic_kwargs,
}

transform = CombinedAffineElasticDeformation(
Expand Down Expand Up @@ -125,32 +141,25 @@ class CombinedAffineElasticDeformation(SpatialTransform):
def __init__(
self,
affine_first: bool,
affine_params: Tuple[TypeTripletFloat],
elastic_params: np.ndarray,
affine_kwargs: Dict[str, Any],
elastic_kwargs: Dict[str, Any],
affine_params: Dict[str, Any],
elastic_params: Dict[str, Any],
**kwargs,
) -> None:
super().__init__(**kwargs)
self.affine_first = affine_first

self.affine_params = affine_params
self._affine = Affine(
*affine_params**affine_kwargs,
**self.affine_params,
**kwargs,
)
self.elastic_params = elastic_params
self._elastic = ElasticDeformation(
*elastic_params,
**elastic_kwargs,
**self.elastic_params,
**kwargs,
)

self.args_names = [
'affine_params',
'elastic_params',
'affine_first',
'affine_kwargs',
'elastic_kwargs',
]
self.args_names = ['affine_first', 'affine_params', 'elastic_params']

def apply_transform(self, subject: Subject) -> Subject:
if self._affine.check_shape:
Expand Down
Loading

0 comments on commit 2ad1090

Please sign in to comment.