Skip to content

Commit

Permalink
Remove KNNGraphNoPE, add distance setting for KNNGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
pweigel committed Jan 6, 2025
1 parent aa70dc0 commit 417ad3b
Showing 1 changed file with 7 additions and 63 deletions.
70 changes: 7 additions & 63 deletions src/graphnet/models/graphs/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
seed: Optional[Union[int, Generator]] = None,
nb_nearest_neighbours: int = 8,
columns: List[int] = [0, 1, 2],
distance_as_edge_feature: bool = False,
**kwargs: Any,
) -> None:
"""Construct k-nn graph representation.
Expand All @@ -50,12 +51,17 @@ def __init__(
Defaults to 8.
columns: node feature columns used for distance calculation.
Defaults to [0, 1, 2].
distance_as_edge_feature: Add edge distances as an edge feature.
Defaults to False.
"""
# Base class constructor
edge_definition = (
KNNDistanceEdges if distance_as_edge_feature else KNNEdges
)
super().__init__(
detector=detector,
node_definition=node_definition or NodesAsPulses(),
edge_definition=KNNEdges(
edge_definition=edge_definition(
nb_nearest_neighbours=nb_nearest_neighbours,
columns=columns,
),
Expand Down Expand Up @@ -261,65 +267,3 @@ def forward( # type: ignore
ksteps=ksteps, edge_index=graph.edge_index, edge_weight=None
)
return graph


class KNNGraphNoPE(GraphDefinition):
"""KNN Graph with edge distances and no positional encoding."""

def __init__(
self,
detector: Detector,
node_definition: Optional[NodeDefinition] = None,
edge_definition: Optional[EdgeDefinition] = None,
input_feature_names: Optional[List[str]] = None,
dtype: Optional[torch.dtype] = torch.float,
perturbation_dict: Optional[Dict[str, float]] = None,
seed: Optional[Union[int, Generator]] = None,
nb_nearest_neighbours: int = 8,
columns: List[int] = [0, 1, 2],
**kwargs: Any,
) -> None:
"""Construct k-nn graph representation.
Args:
detector: Detector that represents your data.
node_definition: Definition of nodes in the graph.
edge_definition: Definition of edges in the graph.
input_feature_names: Name of input feature columns.
dtype: data type for node features.
perturbation_dict: Dictionary mapping a feature name to a standard
deviation according to which the values for this
feature should be randomly perturbed. Defaults
to None.
seed: seed or Generator used to randomly sample perturbations.
Defaults to None.
nb_nearest_neighbours: Number of edges for each node.
Defaults to 8.
columns: node feature columns used for distance calculation.
Defaults to [0, 1, 2].
"""
# Base class constructor
super().__init__(
detector=detector,
node_definition=node_definition or NodesAsPulses(),
edge_definition=edge_definition
or KNNDistanceEdges(
nb_nearest_neighbours=nb_nearest_neighbours,
columns=columns,
),
dtype=dtype,
input_feature_names=input_feature_names,
perturbation_dict=perturbation_dict,
seed=seed,
**kwargs,
)

def forward( # type: ignore
self,
input_features: np.ndarray,
input_feature_names: List[str],
**kwargs,
) -> Data:
"""Forward pass."""
graph = super().forward(input_features, input_feature_names, **kwargs)
return graph

0 comments on commit 417ad3b

Please sign in to comment.