diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index a47b25c6..46e68758 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -607,12 +607,11 @@ def element_types(): class Edge: - def __init__( - self, src: int, dst: int, image: np.ndarray, is_undirected: bool = False - ): + def __init__(self, src: int, dst: int, image: np.ndarray): """ Initializes the Edge object with the source, destination, and image, ensuring directionality and handling self-loops. + Parameters ---------- src : int @@ -622,17 +621,15 @@ def __init__( image : np.ndarray 1D vector of three elements as a ``np.ndarray``. """ - if is_undirected: - if src > dst: - # Enforce directed order - src, dst, image = dst, src, -image + if src > dst: + # Enforce directed order + src, dst, image = dst, src, -image if src == dst: # For self-loops, enforce a canonical form of the image image = self._canonicalize_image(image) self.src = src self.dst = dst self.image = image - self.is_undirected = is_undirected @staticmethod def _canonicalize_image(image: np.ndarray) -> np.ndarray: @@ -686,10 +683,7 @@ def __eq__(self, other) -> bool: """ if not isinstance(other, Edge): return False - if self.is_undirected: - node_eq = self.directed_index == other.directed_index - else: - node_eq = (self.src == other.src) and (self.dst == other.dst) + node_eq = self.directed_index == other.directed_index return node_eq and np.array_equal(self.image, other.image) def __str__(self) -> str: @@ -854,32 +848,41 @@ def _all_sites_have_neighbors(neighbors): raise ValueError( f"No neighbors detected for structure with cutoff {cutoff}; {structure}" ) - keep = set() - # only keeps undirected edges that are unique through set - for src_idx, dst_sites in enumerate(neighbors): - for site in dst_sites: - keep.add( - Edge( - src_idx, - site.index, - np.array(site.image), - is_undirected, + # if we assume undirected edges, apply a filter + if is_undirected: + keep = set() + # only keeps undirected edges that are unique through set + for src_idx, dst_sites in enumerate(neighbors): + for site in dst_sites: + keep.add( + Edge( + src_idx, + site.index, + np.array(site.image), + ) ) - ) - # now only keep the edges after the first loop - all_src, all_dst, all_images = [], [], [] - num_atoms = len(structure.atomic_numbers) - counter = {index: 0 for index in range(num_atoms)} - for edge in keep: - # stop adding edges if either src/dst have accumulated enough neighbors - if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors: - pass - else: - all_src.append(edge.src) - all_dst.append(edge.dst) - all_images.append(edge.image) - counter[edge.src] += 1 - counter[edge.dst] += 1 + # now only keep the edges after the first loop + all_src, all_dst, all_images = [], [], [] + num_atoms = len(structure.atomic_numbers) + counter = {index: 0 for index in range(num_atoms)} + for edge in keep: + # stop adding edges if either src/dst have accumulated enough neighbors + if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors: + pass + else: + all_src.append(edge.src) + all_dst.append(edge.dst) + all_images.append(edge.image) + counter[edge.src] += 1 + counter[edge.dst] += 1 + # alternatively, just add the edges as is from pymatgen + else: + all_src, all_dst, all_images = [], [], [] + for src_idx, dst_sites in enumerate(neighbors): + for site in dst_sites: + all_src.append(src_idx) + all_dst.append(site.index) + all_images.append(site.image) if any([len(obj) == 0 for obj in [all_src, all_dst, all_images]]): raise ValueError( f"No images or edges to work off for cutoff {cutoff}." @@ -963,24 +966,28 @@ def calculate_ase_periodic_shifts( ) # not really needed but good sanity check assert np.all(distances <= cutoff_radius) - keep = set() - # only keeps undirected edges that are unique - for src, dst, image in zip(all_src, all_dst, all_images): - keep.add(Edge(src=src, dst=dst, image=image, is_undirected=is_undirected)) - - all_src, all_dst, all_images = [], [], [] - num_atoms = len(atoms) - counter = {index: 0 for index in range(num_atoms)} - for edge in keep: - # obey max_neighbors by not adding any more edges - if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors: - pass - else: - all_src.append(edge.src) - all_dst.append(edge.dst) - all_images.append(edge.image) - counter[edge.src] += 1 - counter[edge.dst] += 1 + + # in the undirected case, we will filter out + # half of the edges as src/dst == dst/src for a given image + if is_undirected: + keep = set() + # only keeps undirected edges that are unique + for src, dst, image in zip(all_src, all_dst, all_images): + keep.add(Edge(src=src, dst=dst, image=image)) + + all_src, all_dst, all_images = [], [], [] + num_atoms = len(atoms) + counter = {index: 0 for index in range(num_atoms)} + for edge in keep: + # obey max_neighbors by not adding any more edges + if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors: + pass + else: + all_src.append(edge.src) + all_dst.append(edge.dst) + all_images.append(edge.image) + counter[edge.src] += 1 + counter[edge.dst] += 1 frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float() coords = torch.from_numpy(atoms.positions).float()