Skip to content

Commit

Permalink
Add mode: Literal["w", "a", "wt", "at"] = "w" keyword to `CifWriter…
Browse files Browse the repository at this point in the history
….write_file()` (#3399)

* add mode: Literal["w", "a", "wt", "at"] = "w" keyword to CifWriter.write_file()

* add test_cif_writer_write_file

* cif to CIF in comments
  • Loading branch information
janosh authored Oct 11, 2023
1 parent c818aa7 commit 63d7605
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 50 deletions.
2 changes: 1 addition & 1 deletion pymatgen/analysis/chemenv/utils/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def compute_environments(chemenv_configuration):
input_source = test
if source_type == "cif":
if not found:
input_source = input("Enter path to cif file : ")
input_source = input("Enter path to CIF file : ")
parser = CifParser(input_source)
structure = parser.get_structures()[0]
elif source_type == "mp":
Expand Down
19 changes: 6 additions & 13 deletions pymatgen/command_line/mcsqs_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def run_mcsqs(
instances = os.cpu_count()

original_directory = os.getcwd()
if not directory:
directory = tempfile.mkdtemp()
directory = directory or tempfile.mkdtemp()
os.chdir(directory)

if isinstance(scaling, (int, float)):
Expand All @@ -109,13 +108,7 @@ def run_mcsqs(
process.communicate()

# Generate SQS structures
add_ons = [
f"-T {temperature}",
f"-wr {wr}",
f"-wn {wn}",
f"-wd {wd}",
f"-tol {tol}",
]
add_ons = [f"-T {temperature}", f"-wr {wr}", f"-wn {wn}", f"-wd {wd}", f"-tol {tol}"]

mcsqs_find_sqs_processes = []
if instances and instances > 1:
Expand Down Expand Up @@ -181,7 +174,7 @@ def _parse_sqs_path(path) -> Sqs:
# detected instances will be 0 if mcsqs was run in series, or number of instances
detected_instances = len(list(path.glob("bestsqs*[0-9]*.out")))

# Convert best SQS structure to cif file and pymatgen Structure
# Convert best SQS structure to CIF file and pymatgen Structure
with Popen("str2cif < bestsqs.out > bestsqs.cif", shell=True, cwd=path) as p:
p.communicate()

Expand All @@ -198,7 +191,7 @@ def _parse_sqs_path(path) -> Sqs:
objective_function = float(objective_function_str) if objective_function_str != "Perfect_match" else "Perfect_match"

# Get all SQS structures and objective functions
allsqs = []
all_sqs = []

for i in range(detected_instances):
sqs_out = f"bestsqs{i + 1}.out"
Expand All @@ -213,14 +206,14 @@ def _parse_sqs_path(path) -> Sqs:
objective_function_str = lines[-1].split("=")[-1].strip()
obj: float | str
obj = float(objective_function_str) if objective_function_str != "Perfect_match" else "Perfect_match"
allsqs.append({"structure": sqs, "objective_function": obj})
all_sqs.append({"structure": sqs, "objective_function": obj})

clusters = _parse_clusters(path / "clusters.out")

return Sqs(
bestsqs=best_sqs,
objective_function=objective_function,
allsqs=allsqs,
allsqs=all_sqs,
directory=str(path.resolve()),
clusters=clusters,
)
Expand Down
57 changes: 31 additions & 26 deletions pymatgen/io/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, data, loops, header):
"""
self.loops = loops
self.data = data
# AJ says: CIF Block names cannot be more than 75 characters or you
# AJ (@computron) says: CIF Block names cannot be more than 75 characters or you
# get an Exception
self.header = header[:74]

Expand Down Expand Up @@ -246,26 +246,27 @@ def from_string(cls, *args, **kwargs):
return cls.from_str(*args, **kwargs)

@classmethod
def from_str(cls, string):
"""
Reads CifFile from a string.
def from_str(cls, string) -> CifFile:
"""Reads CifFile from a string.
:param string: String representation.
Returns:
CifFile
"""
dct = {}
for x in re.split(r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL)[1:]:

for block_str in re.split(r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL)[1:]:
# Skip over Cif block that contains powder diffraction data.
# Some elements in this block were missing from CIF files in
# Springer materials/Pauling file DBs.
# This block anyway does not contain any structure information, and
# This block does not contain any structure information anyway, and
# CifParser was also not parsing it.
if "powder_pattern" in re.split(r"\n", x, maxsplit=1)[0]:
if "powder_pattern" in re.split(r"\n", block_str, maxsplit=1)[0]:
continue
c = CifBlock.from_str("data_" + x)
dct[c.header] = c
block = CifBlock.from_str("data_" + block_str)
dct[block.header] = block

return cls(dct, string)

@classmethod
Expand Down Expand Up @@ -675,7 +676,7 @@ def get_symops(self, data):
operations are parsed. If the symops are not present, the space
group symbol is parsed, and symops are generated.
"""
symops = []
sym_ops = []
for symmetry_label in [
"_symmetry_equiv_pos_as_xyz",
"_symmetry_equiv_pos_as_xyz_",
Expand All @@ -690,11 +691,11 @@ def get_symops(self, data):
self.warnings.append(msg)
xyz = [xyz]
try:
symops = [SymmOp.from_xyz_str(s) for s in xyz]
sym_ops = [SymmOp.from_xyz_str(s) for s in xyz]
break
except ValueError:
continue
if not symops:
if not sym_ops:
# Try to parse symbol
for symmetry_label in [
"_symmetry_space_group_name_H-M",
Expand All @@ -718,7 +719,7 @@ def get_symops(self, data):
try:
spg = space_groups.get(sg)
if spg:
symops = SpaceGroup(spg).symmetry_ops
sym_ops = SpaceGroup(spg).symmetry_ops
msg = msg_template.format(symmetry_label)
warnings.warn(msg)
self.warnings.append(msg)
Expand All @@ -734,17 +735,17 @@ def get_symops(self, data):
for d in cod_data:
if sg == re.sub(r"\s+", "", d["hermann_mauguin"]):
xyz = d["symops"]
symops = [SymmOp.from_xyz_str(s) for s in xyz]
sym_ops = [SymmOp.from_xyz_str(s) for s in xyz]
msg = msg_template.format(symmetry_label)
warnings.warn(msg)
self.warnings.append(msg)
break
except Exception:
continue

if symops:
if sym_ops:
break
if not symops:
if not sym_ops:
# Try to parse International number
for symmetry_label in [
"_space_group_IT_number",
Expand All @@ -755,18 +756,18 @@ def get_symops(self, data):
if data.data.get(symmetry_label):
try:
i = int(str2float(data.data.get(symmetry_label)))
symops = SpaceGroup.from_int_number(i).symmetry_ops
sym_ops = SpaceGroup.from_int_number(i).symmetry_ops
break
except ValueError:
continue

if not symops:
if not sym_ops:
msg = "No _symmetry_equiv_pos_as_xyz type key found. Defaulting to P1."
warnings.warn(msg)
self.warnings.append(msg)
symops = [SymmOp.from_xyz_str(s) for s in ["x", "y", "z"]]
sym_ops = [SymmOp.from_xyz_str(s) for s in ["x", "y", "z"]]

return symops
return sym_ops

def get_magsymops(self, data):
"""
Expand Down Expand Up @@ -1165,6 +1166,8 @@ def get_structures(
Returns:
list[Structure]: All structures in CIF file.
"""
print(len(self._cif.data))

if not check_occu: # added in https://github.com/materialsproject/pymatgen/pull/2836
warnings.warn("Structures with unphysical site occupancies are not compatible with many pymatgen features.")
if primitive and symmetrized:
Expand Down Expand Up @@ -1478,18 +1481,20 @@ def __init__(
self._cf = CifFile(dct)

@property
def ciffile(self):
def cif_file(self):
"""Returns: CifFile associated with the CifWriter."""
return self._cf

def __str__(self):
"""Returns the cif as a string."""
"""Returns the CIF as a string."""
return str(self._cf)

def write_file(self, filename):
"""Write the cif file."""
with zopen(filename, "wt") as f:
f.write(str(self))
def write_file(self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w") -> None:
"""Write the CIF file."""
with zopen(filename, mode=mode) as file:
file.write(str(self))
if mode in ["a", "at"]:
file.write("\n\n")


def str2float(text):
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/io/feff/sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def write_input(self, output_dir=".", make_dir_if_not_present=True):
with open(f"{output_dir}/feff.inp", "w") as f:
f.write(feff_input)

# write the structure to cif file
# write the structure to CIF file
if "ATOMS" not in feff:
self.atoms.struct.to(fmt="cif", filename=os.path.join(output_dir, feff["PARAMETERS"]["CIF"]))

Expand Down
2 changes: 1 addition & 1 deletion tests/files/.pytest-split-durations
Original file line number Diff line number Diff line change
Expand Up @@ -1895,7 +1895,7 @@
"tests/io/test_cif.py::TestCifIO::test_replacing_finite_precision_frac_coords": 0.01487579196691513,
"tests/io/test_cif.py::TestCifIO::test_site_labels": 0.022716625011526048,
"tests/io/test_cif.py::TestCifIO::test_site_symbol_preference": 0.013636041956488043,
"tests/io/test_cif.py::TestCifIO::test_specie_cifwriter": 0.0027243339573033154,
"tests/io/test_cif.py::TestCifIO::test_specie_cif_writer": 0.0027243339573033154,
"tests/io/test_cif.py::TestCifIO::test_symmetrized": 0.11127212399151176,
"tests/io/test_cif.py::TestMagCif::test_bibtex": 0.0052501659956760705,
"tests/io/test_cif.py::TestMagCif::test_get_structures": 0.0418014990282245,
Expand Down
4 changes: 2 additions & 2 deletions tests/io/test_atat.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_mcsqs_export(self):
assert Mcsqs(struct).to_str() == ref_string

def test_mcsqs_cif_nacl(self):
# cif file from str2cif (utility distributed with atat)
# CIF file from str2cif (utility distributed with atat)
struc_from_cif = Structure.from_file(f"{test_dir}/bestsqs_nacl.cif")

# output file directly from mcsqs
Expand All @@ -100,7 +100,7 @@ def test_mcsqs_cif_nacl(self):
)

def test_mcsqs_cif_pzt(self):
# cif file from str2cif (utility distributed with atat)
# CIF file from str2cif (utility distributed with atat)
struc_from_cif = Structure.from_file(f"{test_dir}/bestsqs_pzt.cif")

# output file directly from mcsqs
Expand Down
28 changes: 22 additions & 6 deletions tests/io/test_cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def test_cif_parser_springer_pauling(self):
assert struct.formula == "Zn1.29 Fe0.69 As2 Pb1.02 O8"

def test_cif_parser_cod(self):
"""Parsing problematic cif files from the COD database."""
"""Parsing problematic CIF files from the COD database."""
# Symbol in capital letters
parser = CifParser(f"{TEST_FILES_DIR}/Cod_2100513.cif")
for struct in parser.get_structures():
Expand Down Expand Up @@ -491,11 +491,11 @@ def test_symmetrized(self):
# test angle tolerance.
struct = Structure.from_file(f"{TEST_FILES_DIR}/LiFePO4.cif")
writer = CifWriter(struct, symprec=0.1, angle_tolerance=0)
d = next(iter(writer.ciffile.data.values()))
d = next(iter(writer.cif_file.data.values()))
assert d["_symmetry_Int_Tables_number"] == 14
struct = Structure.from_file(f"{TEST_FILES_DIR}/LiFePO4.cif")
writer = CifWriter(struct, symprec=0.1, angle_tolerance=2)
d = next(iter(writer.ciffile.data.values()))
d = next(iter(writer.cif_file.data.values()))
assert d["_symmetry_Int_Tables_number"] == 62

def test_disordered(self):
Expand Down Expand Up @@ -554,7 +554,7 @@ def test_cif_writer_without_refinement(self):
same_si2 = CifParser.from_str(cif_str).get_structures()[0]
assert len(si2) == len(same_si2)

def test_specie_cifwriter(self):
def test_specie_cif_writer(self):
si4 = Species("Si", 4)
si3 = Species("Si", 3)
n = DummySpecies("X", -3)
Expand Down Expand Up @@ -620,9 +620,9 @@ def test_primes(self):

def test_missing_atom_site_type_with_oxi_states(self):
parser = CifParser(f"{TEST_FILES_DIR}/P24Ru4H252C296S24N16.cif")
c = Composition({"S0+": 24, "Ru0+": 4, "H0+": 252, "C0+": 296, "N0+": 16, "P0+": 24})
comp = Composition({"S0+": 24, "Ru0+": 4, "H0+": 252, "C0+": 296, "N0+": 16, "P0+": 24})
for struct in parser.get_structures(primitive=False):
assert struct.composition == c
assert struct.composition == comp

def test_no_coords_or_species(self):
string = """#generated using pymatgen
Expand Down Expand Up @@ -855,6 +855,22 @@ def test_no_check_occu(self):
structs = parser.get_structures(primitive=False, check_occu=False)[0]
assert structs[0].species.as_dict()["Te"] == 1.5

def test_cif_writer_write_file(self):
struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR")
out_path = f"{self.tmp_path}/test.cif"
CifWriter(struct1).write_file(out_path)
read_structs = CifParser(out_path).get_structures()
assert len(read_structs) == 1
assert struct1.matches(read_structs[0])

# test write_file append mode='a'
struct2 = Structure.from_file(f"{TEST_FILES_DIR}/Graphite.cif")
CifWriter(struct2).write_file(out_path, mode="a")

read_structs = CifParser(out_path).get_structures()
assert len(read_structs) == 2
assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"]


class TestMagCif(PymatgenTest):
def setUp(self):
Expand Down

0 comments on commit 63d7605

Please sign in to comment.