Skip to content

Commit

Permalink
Add interval to compose.py
Browse files Browse the repository at this point in the history
Following to [PR5733](pyg-team#7533)  I've Added the interval argument. @rusty1s With these modifications, both the Compose and ComposeFilters classes accept an interval argument, and it will be passed to the transformed data when composing transforms or applying filters.
  • Loading branch information
mzamini92 authored Sep 30, 2023
1 parent 1e12d41 commit 12cdabb
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions torch_geometric/transforms/compose.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from typing import Callable, List, Union
from typing import Callable, List, Union, Tuple

from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import BaseTransform


class Compose(BaseTransform):
r"""Composes several transforms together.
Args:
transforms (List[Callable]): List of transforms to compose.
interval (Tuple[float, float], optional): A tuple representing the
interval for the transformation. Defaults to (0.0, 1.0).
"""
def __init__(self, transforms: List[Callable]):
def __init__(self, transforms: List[Callable], interval: Tuple[float, float] = (0.0, 1.0)):
self.transforms = transforms
self.interval = interval

def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
# Pass the interval argument to the transformed data
data.interval = self.interval

for transform in self.transforms:
if isinstance(data, (list, tuple)):
data = [transform(d) for d in data]
Expand All @@ -34,14 +39,20 @@ class ComposeFilters:
Args:
filters (List[Callable]): List of filters to compose.
interval (Tuple[float, float], optional): A tuple representing the
interval for the transformation. Defaults to (0.0, 1.0).
"""
def __init__(self, filters: List[Callable]):
def __init__(self, filters: List[Callable], interval: Tuple[float, float] = (0.0, 1.0)):
self.filters = filters
self.interval = interval

def __call__(
self,
data: Union[Data, HeteroData],
) -> bool:
# Pass the interval argument to the transformed data
data.interval = self.interval

for filter_fn in self.filters:
if isinstance(data, (list, tuple)):
if not all([filter_fn(d) for d in data]):
Expand Down

0 comments on commit 12cdabb

Please sign in to comment.