diff --git a/src/deepali/data/flow.py b/src/deepali/data/flow.py index 16f6d63..3372178 100644 --- a/src/deepali/data/flow.py +++ b/src/deepali/data/flow.py @@ -142,8 +142,6 @@ def _torch_function_result( @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if func == F.grid_sample: - raise ValueError("Argument of F.grid_sample() must be a batch, not a single image") if kwargs is None: kwargs = {} data = Tensor.__torch_function__(func, (Tensor,), args, kwargs)