Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Pen)ultimate edge logic fix #335

Merged
merged 4 commits into from
Dec 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 62 additions & 55 deletions matsciml/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,12 +607,11 @@ def element_types():


class Edge:
def __init__(
self, src: int, dst: int, image: np.ndarray, is_undirected: bool = False
):
def __init__(self, src: int, dst: int, image: np.ndarray):
"""
Initializes the Edge object with the source, destination,
and image, ensuring directionality and handling self-loops.

Parameters
----------
src : int
Expand All @@ -622,17 +621,15 @@ def __init__(
image : np.ndarray
1D vector of three elements as a ``np.ndarray``.
"""
if is_undirected:
if src > dst:
# Enforce directed order
src, dst, image = dst, src, -image
if src > dst:
# Enforce directed order
src, dst, image = dst, src, -image
if src == dst:
# For self-loops, enforce a canonical form of the image
image = self._canonicalize_image(image)
self.src = src
self.dst = dst
self.image = image
self.is_undirected = is_undirected

@staticmethod
def _canonicalize_image(image: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -686,10 +683,7 @@ def __eq__(self, other) -> bool:
"""
if not isinstance(other, Edge):
return False
if self.is_undirected:
node_eq = self.directed_index == other.directed_index
else:
node_eq = (self.src == other.src) and (self.dst == other.dst)
node_eq = self.directed_index == other.directed_index
return node_eq and np.array_equal(self.image, other.image)

def __str__(self) -> str:
Expand Down Expand Up @@ -854,32 +848,41 @@ def _all_sites_have_neighbors(neighbors):
raise ValueError(
f"No neighbors detected for structure with cutoff {cutoff}; {structure}"
)
keep = set()
# only keeps undirected edges that are unique through set
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
keep.add(
Edge(
src_idx,
site.index,
np.array(site.image),
is_undirected,
# if we assume undirected edges, apply a filter
if is_undirected:
keep = set()
# only keeps undirected edges that are unique through set
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
keep.add(
Edge(
src_idx,
site.index,
np.array(site.image),
)
)
)
# now only keep the edges after the first loop
all_src, all_dst, all_images = [], [], []
num_atoms = len(structure.atomic_numbers)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# stop adding edges if either src/dst have accumulated enough neighbors
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1
# now only keep the edges after the first loop
all_src, all_dst, all_images = [], [], []
num_atoms = len(structure.atomic_numbers)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# stop adding edges if either src/dst have accumulated enough neighbors
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1
# alternatively, just add the edges as is from pymatgen
else:
all_src, all_dst, all_images = [], [], []
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
all_src.append(src_idx)
all_dst.append(site.index)
all_images.append(site.image)
if any([len(obj) == 0 for obj in [all_src, all_dst, all_images]]):
raise ValueError(
f"No images or edges to work off for cutoff {cutoff}."
Expand Down Expand Up @@ -963,24 +966,28 @@ def calculate_ase_periodic_shifts(
)
# not really needed but good sanity check
assert np.all(distances <= cutoff_radius)
keep = set()
# only keeps undirected edges that are unique
for src, dst, image in zip(all_src, all_dst, all_images):
keep.add(Edge(src=src, dst=dst, image=image, is_undirected=is_undirected))

all_src, all_dst, all_images = [], [], []
num_atoms = len(atoms)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# obey max_neighbors by not adding any more edges
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1

# in the undirected case, we filter out half of the edges, because
# (src, dst, image) and (dst, src, -image) describe the same edge
if is_undirected:
keep = set()
# only keeps undirected edges that are unique
for src, dst, image in zip(all_src, all_dst, all_images):
keep.add(Edge(src=src, dst=dst, image=image))

all_src, all_dst, all_images = [], [], []
num_atoms = len(atoms)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# obey max_neighbors by not adding any more edges
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1

frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float()
coords = torch.from_numpy(atoms.positions).float()
Expand Down
Loading