Skip to content

Commit

Permalink
AseAtomsAdaptor: Retain tags property when interconverting `Atoms…
Browse files Browse the repository at this point in the history
…` and `Structure`/`Molecule` (#3151)

* add support for tags
  • Loading branch information
Andrew-S-Rosen authored Jul 22, 2023
1 parent 9122d21 commit 5d6f566
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
19 changes: 7 additions & 12 deletions pymatgen/core/tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,18 +1420,13 @@ def test_calculate_chgnet(self):
assert preds["magmoms"] == approx([0.00262399, 0.00262396], abs=1e-5)
assert np.linalg.norm(preds["forces"]) == approx(1.998941843e-5, abs=1e-3)
assert not hasattr(calculator, "dynamics"), "static calculation should not have dynamics"
assert {*calculator.__dict__} == {
"atoms",
"results",
"parameters",
"_directory",
"prefix",
"name",
"get_spin_polarized",
"device",
"model",
"stress_weight",
}
assert "atoms" in calculator.__dict__
assert "results" in calculator.__dict__
assert "parameters" in calculator.__dict__
assert "get_spin_polarized" in calculator.__dict__
assert "device" in calculator.__dict__
assert "model" in calculator.__dict__
assert "stress_weight" in calculator.__dict__
assert len(calculator.parameters) == 0
assert isinstance(calculator.atoms, Atoms)
assert len(calculator.atoms) == len(struct)
Expand Down
8 changes: 8 additions & 0 deletions pymatgen/io/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms:

atoms = Atoms(symbols=symbols, positions=positions, pbc=pbc, cell=cell, **kwargs)

if "tags" in structure.site_properties:
atoms.set_tags(structure.site_properties["tags"])

# Set the site magmoms in the ASE Atoms object
# Note: ASE distinguishes between initial and converged
# magnetic moment site properties, whereas pymatgen does not. Therefore, we
Expand Down Expand Up @@ -181,6 +184,9 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs)
positions = atoms.get_positions()
lattice = atoms.get_cell()

# Get the tags
tags = atoms.get_tags() if atoms.has("tags") else None

# Get the (final) site magmoms and charges from the ASE Atoms object.
if getattr(atoms, "calc", None) is not None and getattr(atoms.calc, "results", None) is not None:
charges = atoms.calc.results.get("charges")
Expand Down Expand Up @@ -247,6 +253,8 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs)
structure.add_site_property("magmom", initial_magmoms)
if sel_dyn is not None and ~np.all(sel_dyn):
structure.add_site_property("selective_dynamics", sel_dyn)
if tags is not None:
structure.add_site_property("tags", tags)

# Add oxidation states by site
if oxi_states is not None:
Expand Down
3 changes: 3 additions & 0 deletions pymatgen/io/tests/test_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_get_atoms_from_structure(self):
assert atoms.get_pbc().all()
assert atoms.get_chemical_symbols() == [s.species_string for s in structure]
assert not atoms.has("initial_magmoms")
assert not atoms.has("initial_charges")
assert atoms.calc is None

p = Poscar.from_file(os.path.join(PymatgenTest.TEST_FILES_DIR, "POSCAR"))
Expand Down Expand Up @@ -248,6 +249,7 @@ def test_back_forth(self):
# Atoms --> Structure --> Atoms --> Structure
atoms = read(os.path.join(PymatgenTest.TEST_FILES_DIR, "OUTCAR"))
atoms.info = {"test": "hi"}
atoms.set_tags([1] * len(atoms))
atoms.set_constraint(FixAtoms(mask=[True] * len(atoms)))
atoms.set_initial_charges([1.0] * len(atoms))
atoms.set_initial_magnetic_moments([2.0] * len(atoms))
Expand Down Expand Up @@ -281,6 +283,7 @@ def test_back_forth(self):
atoms.set_initial_charges([1.0] * len(atoms))
atoms.set_initial_magnetic_moments([2.0] * len(atoms))
atoms.set_array("prop", np.array([3.0] * len(atoms)))
atoms.set_tags([1] * len(atoms))
molecule = aio.AseAtomsAdaptor.get_molecule(atoms)
atoms_back = aio.AseAtomsAdaptor.get_atoms(molecule)
molecule_back = aio.AseAtomsAdaptor.get_molecule(atoms_back)
Expand Down

0 comments on commit 5d6f566

Please sign in to comment.