From 12cdabb65b27a676502bbb797d3976d9877a8715 Mon Sep 17 00:00:00 2001 From: Mohamad Zamini <32536264+mzamini92@users.noreply.github.com> Date: Fri, 29 Sep 2023 18:06:57 -0600 Subject: [PATCH] Add interval to compose.py Following to [PR5733](https://github.com/pyg-team/pytorch_geometric/pull/7533) I've Added the interval argument. @rusty1s With these modifications, both the Compose and ComposeFilters classes accept an interval argument, and it will be passed to the transformed data when composing transforms or applying filters. --- torch_geometric/transforms/compose.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/torch_geometric/transforms/compose.py b/torch_geometric/transforms/compose.py index 128dfc773493..19adf01ad236 100644 --- a/torch_geometric/transforms/compose.py +++ b/torch_geometric/transforms/compose.py @@ -1,22 +1,27 @@ -from typing import Callable, List, Union +from typing import Callable, List, Union, Tuple from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import BaseTransform - class Compose(BaseTransform): r"""Composes several transforms together. Args: transforms (List[Callable]): List of transforms to compose. + interval (Tuple[float, float], optional): A tuple representing the + interval for the transformation. Defaults to (0.0, 1.0). """ - def __init__(self, transforms: List[Callable]): + def __init__(self, transforms: List[Callable], interval: Tuple[float, float] = (0.0, 1.0)): self.transforms = transforms + self.interval = interval def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: + # Pass the interval argument to the transformed data + data.interval = self.interval + for transform in self.transforms: if isinstance(data, (list, tuple)): data = [transform(d) for d in data] @@ -34,14 +39,20 @@ class ComposeFilters: Args: filters (List[Callable]): List of filters to compose. + interval (Tuple[float, float], optional): A tuple representing the + interval for the transformation. Defaults to (0.0, 1.0). """ - def __init__(self, filters: List[Callable]): + def __init__(self, filters: List[Callable], interval: Tuple[float, float] = (0.0, 1.0)): self.filters = filters + self.interval = interval def __call__( self, data: Union[Data, HeteroData], ) -> bool: + # Pass the interval argument to the transformed data + data.interval = self.interval + for filter_fn in self.filters: if isinstance(data, (list, tuple)): if not all([filter_fn(d) for d in data]):