diff --git a/torch_geometric/loader/cache.py b/torch_geometric/loader/cache.py index 581d2882747b..e9e5b91a55a8 100644 --- a/torch_geometric/loader/cache.py +++ b/torch_geometric/loader/cache.py @@ -28,9 +28,9 @@ def __init__(self, loader: DataLoader, device: torch.device): # register default hooks self.register_attr_hooks([ - lambda b: b.n_id, - lambda b: b.adj_t if hasattr(b, 'adj_t') else b.edge_index, - lambda b: b.batch_size]) + lambda b: b.n_id, lambda b: b.adj_t + if hasattr(b, 'adj_t') else b.edge_index, lambda b: b.batch_size + ]) def _check_if_reg_is_open(self) -> None: if not self.hook_reg_open: