Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to reset child labels in Compound.remove() #1173

Merged
merged 9 commits into from
Apr 4, 2024
35 changes: 32 additions & 3 deletions mbuild/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ def add(

# Add new_part to labels. Does not currently support batch add.
if label is None:
label = "{0}[$]".format(new_child.__class__.__name__)
label = "{0}[$]".format(new_child.name)

if label.endswith("[$]"):
label = label[:-3]
Expand Down Expand Up @@ -1015,13 +1015,15 @@ def add(
"outside of the defined simulation box"
)

def remove(self, objs_to_remove):
def remove(self, objs_to_remove, reset_labels=False):
jaclark5 marked this conversation as resolved.
Show resolved Hide resolved
"""Remove children from the Compound cleanly.

Parameters
----------
objs_to_remove : mb.Compound or list of mb.Compound
The Compound(s) to be removed from self
reset_labels : bool
If True, the Compound labels will be reset
"""
# Preprocessing and validating input type
from mbuild.port import Port
Expand Down Expand Up @@ -1082,9 +1084,36 @@ def _check_if_empty(child):
if self.contains_rigid:
self.root._reorder_rigid_ids()

# Remove ghsot ports
# Remove ghost ports
self._prune_ghost_ports()

# Reorder labels
if reset_labels:
new_labels = OrderedDict()
for child in self.children:
if "Port" in child.name:
label = [
key
for key, x in self.labels.items()
if id(x) == id(child)
][0]
if "port" in label:
label = "{0}[$]".format("port")
jaclark5 marked this conversation as resolved.
Show resolved Hide resolved
else:
label = "{0}[$]".format(child.name)

if label.endswith("[$]"):
label = label[:-3]
if label not in new_labels:
new_labels[label] = []
label_pattern = label + "[{}]"

count = len(new_labels[label])
new_labels[label].append(child)
label = label_pattern.format(count)
new_labels[label] = child
self.labels = new_labels

def _prune_ghost_ports(self):
"""Worker for remove(). Remove all ports whose anchor has been deleted."""
all_ports_list = list(self.all_ports())
Expand Down
41 changes: 41 additions & 0 deletions mbuild/tests/test_compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,47 @@ def test_remove(self, ethane):
assert len(ethane5.children[0].children) == 6 # 3 hydrogens + 3 ports
assert len(ethane5.children) == 1

# Test to reset labels after hydrogens
ethane6 = mb.clone(ethane)
ethane6.flatten()
hydrogens = ethane6.particles_by_name("H")
ethane6.remove(hydrogens)

assert list(ethane6.labels.keys()) == [
"methyl1",
"methyl2",
"C",
"C[0]",
"H",
"C[1]",
"port",
"port[1]",
"port[3]",
"port[5]",
"port[7]",
"port[9]",
"port[11]",
]

ethane7 = mb.clone(ethane)
ethane7.flatten()
hydrogens = ethane7.particles_by_name("H")
ethane7.remove(hydrogens, reset_labels=True)

print(list(ethane7.labels.keys()))
jaclark5 marked this conversation as resolved.
Show resolved Hide resolved
assert list(ethane7.labels.keys()) == [
"C",
"C[0]",
"C[1]",
"port",
"port[0]",
"port[1]",
"port[2]",
"port[3]",
"port[4]",
"port[5]",
]

def test_remove_many(self, ethane):
ethane.remove([ethane.children[0], ethane.children[1]])

Expand Down
Loading