diff --git a/hydragnn/models/EGCLStack.py b/hydragnn/models/EGCLStack.py index 15fd0c348..639196555 100644 --- a/hydragnn/models/EGCLStack.py +++ b/hydragnn/models/EGCLStack.py @@ -15,6 +15,8 @@ from torch_geometric.nn import Sequential from .Base import Base +from ..utils import unsorted_segment_mean + class EGCLStack(Base): def __init__( @@ -241,13 +243,3 @@ def unsorted_segment_sum(data, segment_ids, num_segments): segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) result.scatter_add_(0, segment_ids, data) return result - - -def unsorted_segment_mean(data, segment_ids, num_segments): - result_shape = (num_segments, data.size(1)) - segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) - result = data.new_full(result_shape, 0) # Init empty result tensor. - count = data.new_full(result_shape, 0) - result.scatter_add_(0, segment_ids, data) - count.scatter_add_(0, segment_ids, torch.ones_like(data)) - return result / count.clamp(min=1) diff --git a/hydragnn/models/SCFStack.py b/hydragnn/models/SCFStack.py index 6b953d031..7c67cffd5 100644 --- a/hydragnn/models/SCFStack.py +++ b/hydragnn/models/SCFStack.py @@ -26,6 +26,8 @@ from .Base import Base +from ..utils import unsorted_segment_mean + class SCFStack(Base): def __init__( @@ -219,13 +221,3 @@ def coord2radial(self, edge_index, coord): coord_diff = coord_diff / (norm) return radial, coord_diff - - -def unsorted_segment_mean(data, segment_ids, num_segments): - result_shape = (num_segments, data.size(1)) - segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) - result = data.new_full(result_shape, 0) # Init empty result tensor. - count = data.new_full(result_shape, 0) - result.scatter_add_(0, segment_ids, data) - count.scatter_add_(0, segment_ids, torch.ones_like(data)) - return result / count.clamp(min=1) diff --git a/hydragnn/utils/__init__.py b/hydragnn/utils/__init__.py index 6dcded374..2309af6fc 100644 --- a/hydragnn/utils/__init__.py +++ b/hydragnn/utils/__init__.py @@ -14,6 +14,7 @@ from .model import ( save_model, get_summary_writer, + unsorted_segment_mean, load_existing_model, load_existing_model_config, loss_function_selection, diff --git a/hydragnn/utils/model.py b/hydragnn/utils/model.py index b34815d78..bc01eca33 100644 --- a/hydragnn/utils/model.py +++ b/hydragnn/utils/model.py @@ -144,6 +144,16 @@ def calculate_PNA_degree_mpi(loader, max_neighbours): return torch.tensor(deg) +def unsorted_segment_mean(data, segment_ids, num_segments): + result_shape = (num_segments, data.size(1)) + segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) + result = data.new_full(result_shape, 0) # Init empty result tensor. + count = data.new_full(result_shape, 0) + result.scatter_add_(0, segment_ids, data) + count.scatter_add_(0, segment_ids, torch.ones_like(data)) + return result / count.clamp(min=1) + + def print_model(model): """print model's parameter size layer by layer""" num_params = 0