diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..ead577ba --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,45 @@ +# Pull Request Template + +## Description + +Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. + +Fixes # (issue) + +## Type of change + +Please delete options that are not relevant. + +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] This change requires a documentation update + +## How Has This Been Tested? + +Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration. + +- [ ] Test A +- [ ] Test B + +**Test Configuration**: +* Python version: +* Operating System: + + +## Reviewers + +@mention individuals who you specifically want to involve in the discussion for this pull request and mention why they are needed in the discussion/why they are needed to review the pull request. 
+ + +## Checklist: + +- [ ] My code follows the style guidelines of this project +- [ ] I have performed a self-review of my own code +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have made corresponding changes to the documentation +- [ ] My changes generate no new warnings +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] New and existing unit tests pass locally with my changes +- [ ] Any dependent changes have been merged and published in downstream modules +- [ ] I have checked my code and corrected any misspellings \ No newline at end of file diff --git a/README.md b/README.md index 4a2a4ddc..ff871816 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,14 @@ [![Documentation Status](https://readthedocs.org/projects/smact/badge/?version=latest)](http://smact.readthedocs.org/en/latest/?badge=latest) [![made-with-python](https://img.shields.io/badge/Made%20with-Python-1f425f.svg)](https://www.python.org/) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) -[![Build Status](https://travis-ci.org/WMD-group/SMACT.svg?branch=master)](https://travis-ci.org/WMD-group/SMACT) [![DOI](http://joss.theoj.org/papers/10.21105/joss.01361/status.svg)](https://doi.org/10.21105/joss.01361) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![PyPi](https://img.shields.io/pypi/v/smact)](https://pypi.org/project/SMACT/) +[![GitHub issues](https://img.shields.io/github/issues-raw/WMD-Group/SMACT)](https://github.com/WMD-group/SMACT/issues) +![dependencies](https://img.shields.io/librariesio/release/pypi/smact) +[![CI Status](https://github.com/WMD-group/SMACT/actions/workflows/ci.yml/badge.svg)](https://github.com/WMD-group/SMACT/actions/workflows/ci.yml) +![python version](https://img.shields.io/pypi/pyversions/smact) + SMACT ===== diff --git a/docs/conf.py b/docs/conf.py index 
37918f99..aa8a35e2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -381,5 +381,6 @@ def __getattr__(cls, name): "pymatgen.util", "pymatgen.util.plotting", "pymatgen.analysis.structure_prediction", + "pymatgen.transformations.standard_transformations", ] sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) diff --git a/docs/examples.rst b/docs/examples.rst index 79942a12..86fe445a 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -108,6 +108,7 @@ We can look for neutral combos. .. code:: python import smact.screening + import pprint elements = ['Ti', 'Al', 'O'] space = smact.element_dictionary(elements) @@ -115,32 +116,50 @@ We can look for neutral combos. eles = [e[1] for e in space.items()] # We set a threshold for the stoichiometry of 4 allowed_combinations = smact.screening.smact_filter(eles, threshold=4) - print(allowed_combinations) - - [(('Ti', 'Al', 'O'), (1, 3, 3)), - (('Ti', 'Al', 'O'), (2, 3, 4)), - (('Ti', 'Al', 'O'), (3, 1, 4)), - (('Ti', 'Al', 'O'), (1, 4, 4)), - (('Ti', 'Al', 'O'), (3, 1, 2)), - (('Ti', 'Al', 'O'), (3, 2, 4)), - (('Ti', 'Al', 'O'), (1, 2, 3)), - (('Ti', 'Al', 'O'), (1, 3, 4)), - (('Ti', 'Al', 'O'), (2, 4, 3)), - (('Ti', 'Al', 'O'), (2, 1, 3)), - (('Ti', 'Al', 'O'), (4, 2, 3)), - (('Ti', 'Al', 'O'), (1, 3, 2)), - (('Ti', 'Al', 'O'), (1, 2, 4)), - (('Ti', 'Al', 'O'), (1, 1, 2)), - (('Ti', 'Al', 'O'), (1, 2, 2)), - (('Ti', 'Al', 'O'), (1, 1, 4)), - (('Ti', 'Al', 'O'), (3, 1, 3)), - (('Ti', 'Al', 'O'), (2, 1, 4)), - (('Ti', 'Al', 'O'), (1, 1, 1)), - (('Ti', 'Al', 'O'), (2, 2, 3)), - (('Ti', 'Al', 'O'), (4, 1, 3)), - (('Ti', 'Al', 'O'), (1, 1, 3)), - (('Ti', 'Al', 'O'), (1, 4, 3)), - (('Ti', 'Al', 'O'), (2, 1, 2))] + pprint.pprint(allowed_combinations) + + [Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 1, -2), stoichiometries=(1, 1, 1)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 1, -2), stoichiometries=(1, 3, 2)), + Composition(element_symbols=('Ti', 'Al', 'O'), 
oxidation_states=(1, 1, -2), stoichiometries=(2, 4, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 1, -2), stoichiometries=(3, 1, 2)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 1, -2), stoichiometries=(4, 2, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 1, -1), stoichiometries=(1, 1, 2)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 1, -1), stoichiometries=(1, 2, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 1, -1), stoichiometries=(1, 3, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 1, -1), stoichiometries=(2, 1, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 1, -1), stoichiometries=(3, 1, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 2, -2), stoichiometries=(2, 1, 2)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 2, -2), stoichiometries=(2, 2, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 2, -2), stoichiometries=(2, 3, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 2, -2), stoichiometries=(4, 1, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 2, -1), stoichiometries=(1, 1, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 2, -1), stoichiometries=(2, 1, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 3, -2), stoichiometries=(1, 1, 2)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 3, -2), stoichiometries=(3, 1, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(1, 3, -1), stoichiometries=(1, 1, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 1, -2), stoichiometries=(1, 2, 2)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 1, -2), stoichiometries=(1, 4, 3)), + 
Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 1, -2), stoichiometries=(2, 2, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 1, -2), stoichiometries=(3, 2, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 1, -1), stoichiometries=(1, 1, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 1, -1), stoichiometries=(1, 2, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 2, -2), stoichiometries=(1, 1, 2)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 2, -2), stoichiometries=(1, 2, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 2, -2), stoichiometries=(1, 3, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 2, -2), stoichiometries=(2, 1, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 2, -2), stoichiometries=(3, 1, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 2, -1), stoichiometries=(1, 1, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(2, 3, -2), stoichiometries=(1, 2, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(3, 1, -2), stoichiometries=(1, 1, 2)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(3, 1, -2), stoichiometries=(1, 3, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(3, 1, -1), stoichiometries=(1, 1, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(3, 2, -2), stoichiometries=(2, 1, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(3, 3, -2), stoichiometries=(1, 1, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(4, 1, -2), stoichiometries=(1, 2, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(4, 1, -2), stoichiometries=(1, 4, 4)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(4, 2, -2), 
stoichiometries=(1, 1, 3)), + Composition(element_symbols=('Ti', 'Al', 'O'), oxidation_states=(4, 2, -2), stoichiometries=(1, 2, 4))] + There is `an example `_ of how this function can be combined with multiprocessing to rapidly explore large subsets of chemical space. diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 7cf7c972..a584b499 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -6,7 +6,7 @@ Getting Started Requirements ============ -The main language is Python 3 and has been tested using Python 3.6+. Basic requirements are Numpy and Scipy. +The main language is Python 3 and has been tested using Python 3.8+. Basic requirements are Numpy and Scipy. The `Atomic Simulation Environment `_ (ASE), `spglib `_, and `pymatgen `_ are also required for many components. diff --git a/examples/Cation_mutation/cation_mutation.ipynb b/examples/Cation_mutation/cation_mutation.ipynb index 25c73c4b..c6091c8a 100644 --- a/examples/Cation_mutation/cation_mutation.ipynb +++ b/examples/Cation_mutation/cation_mutation.ipynb @@ -161,7 +161,9 @@ "def pretty_print_atoms(atoms, linewrap=15):\n", " entries = [\n", " \"{0:5.3f} {1:5.3f} {2:5.3f} {symbol}\".format(*position, symbol=symbol)\n", - " for position, symbol in zip(atoms.get_positions(), atoms.get_chemical_symbols())\n", + " for position, symbol in zip(\n", + " atoms.get_positions(), atoms.get_chemical_symbols()\n", + " )\n", " ]\n", "\n", " for output_i in range(linewrap):\n", @@ -196,13 +198,17 @@ "sub_lattice = distort.build_sub_lattice(single_substitution, \"Ba\")\n", "\n", "# Enumerate the inequivalent sites\n", - "inequivalent_sites = distort.get_inequivalent_sites(sub_lattice, single_substitution)\n", + "inequivalent_sites = distort.get_inequivalent_sites(\n", + " sub_lattice, single_substitution\n", + ")\n", "\n", "# Replace Ba at inequivalent sites with Sr\n", "for i, inequivalent_site in enumerate(inequivalent_sites):\n", " print(\"-\" * hlinewidth)\n", " print(f\"Substituted 
coordinates {i}\")\n", - " distorted = distort.make_substitution(single_substitution, inequivalent_site, \"Sr\")\n", + " distorted = distort.make_substitution(\n", + " single_substitution, inequivalent_site, \"Sr\"\n", + " )\n", " pretty_print_atoms(distorted)\n", "\n", "print(\"=\" * hlinewidth)" diff --git a/examples/Counting/ElementCombinationsParallel.py b/examples/Counting/ElementCombinationsParallel.py index 8acf02f7..57856c95 100644 --- a/examples/Counting/ElementCombinationsParallel.py +++ b/examples/Counting/ElementCombinationsParallel.py @@ -143,7 +143,9 @@ def main(): # neutral_stoichiometries.update({oxidation_states: count}) def n_neutral_ratios(oxidation_states, threshold=8): - return len(smact.neutral_ratios(oxidation_states, threshold=threshold)[1]) + return len( + smact.neutral_ratios(oxidation_states, threshold=threshold)[1] + ) neutral_stoichiometries = { oxidation_states: n_neutral_ratios( @@ -158,10 +160,14 @@ def n_neutral_ratios(oxidation_states, threshold=8): # progress indicator. combination_count = sum( - count_iter(itertools.combinations(element_list, i)) for i in range(2, n + 1) + count_iter(itertools.combinations(element_list, i)) + for i in range(2, n + 1) ) - print("Counting ({} element combinations)" "...".format(combination_count)) + print( + "Counting ({} element combinations)" + "...".format(combination_count) + ) # Combinations are counted in chunks set by count_progress_interval. # In Python 2.7 the // symbol is "integer division" which rounds @@ -215,7 +221,9 @@ def n_neutral_ratios(oxidation_states, threshold=8): # Serial code path -- iteration over element combinations is # done using the itertools.imap() function. - count = count + sum(map(count_element_combination, imap_arg_generator)) + count = count + sum( + map(count_element_combination, imap_arg_generator) + ) # After each chunk, report the % progress, elapsed time and an # estimate of the remaining time. 
The smact.pauling_test() calls @@ -229,7 +237,8 @@ def n_neutral_ratios(oxidation_states, threshold=8): time_elapsed = time.time() - start_time time_remaining = ( - combination_count * (time_elapsed / data_pointer) - time_elapsed + combination_count * (time_elapsed / data_pointer) + - time_elapsed ) print_status( @@ -253,7 +262,10 @@ def n_neutral_ratios(oxidation_states, threshold=8): "Number of charge-neutral stoichiometries for combinations " "of {} elements".format(n) ) - print("(using known oxidation states, not including zero): " "{}".format(count)) + print( + "(using known oxidation states, not including zero): " + "{}".format(count) + ) print("") print(f"Total time for counting: {total_time:.3f} sec") diff --git a/examples/Counting/Generate_compositions_lists.ipynb b/examples/Counting/Generate_compositions_lists.ipynb index bfe19235..51d6f255 100644 --- a/examples/Counting/Generate_compositions_lists.ipynb +++ b/examples/Counting/Generate_compositions_lists.ipynb @@ -183,7 +183,9 @@ "# Flatten the list of lists\n", "flat_list = [item for sublist in result for item in sublist]\n", "print(f\"Number of compositions: --> {len(flat_list)} <--\")\n", - "print(\"Each list entry looks like this:\\n elements, oxidation states, stoichiometries\")\n", + "print(\n", + " \"Each list entry looks like this:\\n elements, oxidation states, stoichiometries\"\n", + ")\n", "for i in flat_list[:5]:\n", " print(i)" ] diff --git a/examples/Counting/Raw_combinations.ipynb b/examples/Counting/Raw_combinations.ipynb index 3de6023a..f7b6c465 100644 --- a/examples/Counting/Raw_combinations.ipynb +++ b/examples/Counting/Raw_combinations.ipynb @@ -126,7 +126,12 @@ "\n", "def main():\n", "\n", - " compound_names = {2: \"binary\", 3: \"ternary\", 4: \"quaternary\", 5: \"quinternary\"}\n", + " compound_names = {\n", + " 2: \"binary\",\n", + " 3: \"ternary\",\n", + " 4: \"quaternary\",\n", + " 5: \"quinternary\",\n", + " }\n", "\n", " print(f\"In a search space of {search_space} 
elements:\")\n", " print(\"\")\n", @@ -145,9 +150,7 @@ " # (1,2,2), (2,1,1), (2,1,2), (2,1,2)\n", " def ineq_ratios_with_coeff(n_species, max_coeff):\n", " print(\n", - " \"Number of inequivalent {} ratios with max coefficient {}: \".format(\n", - " compound_names[n_species], max_coeff\n", - " )\n", + " f\"Number of inequivalent {compound_names[n_species]} ratios with max coefficient {max_coeff}: \"\n", " ),\n", " n_ratios = sum_iter(\n", " filter(\n", @@ -188,7 +191,8 @@ " return n_compounds\n", "\n", " compounds = {\n", - " n: {x: unique_compounds(n, x) for x in range(4, 9, 2)} for n in range(2, 5)\n", + " n: {x: unique_compounds(n, x) for x in range(4, 9, 2)}\n", + " for n in range(2, 5)\n", " }\n", "\n", "\n", diff --git a/examples/Inverse_perovskites/Inverse_formate_perovskites.ipynb b/examples/Inverse_perovskites/Inverse_formate_perovskites.ipynb index 37c65b6f..465a5f44 100644 --- a/examples/Inverse_perovskites/Inverse_formate_perovskites.ipynb +++ b/examples/Inverse_perovskites/Inverse_formate_perovskites.ipynb @@ -178,7 +178,9 @@ " if electroneg_makes_sense:\n", " pauling_perov.append([A[0], B[0], C[0]])\n", " # We calculate the Goldschmidt tolerance factor\n", - " tol = (float(A[2]) + C[2]) / (np.sqrt(2) * (float(B[2]) + C[2]))\n", + " tol = (float(A[2]) + C[2]) / (\n", + " np.sqrt(2) * (float(B[2]) + C[2])\n", + " )\n", " if tol > 1.0:\n", " a_too_large.append([A[0], B[0], C[0]])\n", " anion_hex = anion_hex + 1\n", @@ -354,7 +356,9 @@ "outputs": [], "source": [ "# Get list of Element objects\n", - "search = [el for el in smact.ordered_elements(3, 87) if Element(el).oxidation_states]\n", + "search = [\n", + " el for el in smact.ordered_elements(3, 87) if Element(el).oxidation_states\n", + "]\n", "\n", "# Covert to list of Species objects\n", "all_species = []\n", @@ -363,8 +367,12 @@ " all_species.append(Species(el, oxi_state, \"6_n\"))\n", "\n", "# Define lists of interest\n", - "A_list = [sp for sp in all_species if (sp.oxidation == -1) and 
(sp.ionic_radius)]\n", - "B_list = [sp for sp in all_species if (4 <= sp.oxidation <= 5) and (sp.ionic_radius)]\n", + "A_list = [\n", + " sp for sp in all_species if (sp.oxidation == -1) and (sp.ionic_radius)\n", + "]\n", + "B_list = [\n", + " sp for sp in all_species if (4 <= sp.oxidation <= 5) and (sp.ionic_radius)\n", + "]\n", "C_list = [Species(\"F\", -1, 4.47)]" ] }, @@ -386,9 +394,9 @@ "for combo in product(A_list, B_list, C_list):\n", " A, B, C = combo[0], combo[1], combo[2]\n", " # Check for charge neutrality in 1:1:3 ratio\n", - " if (1, 1, 3) in screening.neutral_ratios([A.oxidation, B.oxidation, C.oxidation])[\n", - " 1\n", - " ]:\n", + " if (1, 1, 3) in screening.neutral_ratios(\n", + " [A.oxidation, B.oxidation, C.oxidation]\n", + " )[1]:\n", " charge_balanced.append(combo)\n", " # Check for pauling test\n", " if screening.pauling_test(\n", diff --git a/examples/Oxidation_states/oxidation_states.ipynb b/examples/Oxidation_states/oxidation_states.ipynb index f66d6864..ca2bfdf6 100644 --- a/examples/Oxidation_states/oxidation_states.ipynb +++ b/examples/Oxidation_states/oxidation_states.ipynb @@ -174,7 +174,9 @@ "source": [ "# Get the cations that are in the probability table\n", "cations = [\n", - " species for species in ox_prob_finder.get_included_species() if \"-\" not in species\n", + " species\n", + " for species in ox_prob_finder.get_included_species()\n", + " if \"-\" not in species\n", "]\n", "\n", "# Get the symbols of the d-block elements\n", @@ -229,10 +231,17 @@ " cn_e, cn_r = smact.neutral_ratios(ox_states, threshold=8)\n", " if cn_e:\n", " # Electronegativity test\n", - " electroneg_OK = screening.pauling_test(ox_states, electronegativities)\n", + " electroneg_OK = screening.pauling_test(\n", + " ox_states, electronegativities\n", + " )\n", " if electroneg_OK:\n", " compound = tuple(\n", - " [elements, (ox_a, ox_b, -1), cn_r[0], list(els) + [halide_species]]\n", + " [\n", + " elements,\n", + " (ox_a, ox_b, -1),\n", + " cn_r[0],\n", + " 
list(els) + [halide_species],\n", + " ]\n", " )\n", " all_compounds.append(compound)\n", " return all_compounds" @@ -273,16 +282,20 @@ "# Here we grab the species string for each composition generated by smact\n", "list_of_species = [species[3] for species in flat_list]\n", "A_species = [\n", - " f\"{species[0].symbol}{species[0].oxidation}+\" for species in list_of_species\n", + " f\"{species[0].symbol}{species[0].oxidation}+\"\n", + " for species in list_of_species\n", "]\n", "B_species = [\n", - " f\"{species[1].symbol}{species[1].oxidation}+\" for species in list_of_species\n", + " f\"{species[1].symbol}{species[1].oxidation}+\"\n", + " for species in list_of_species\n", "]\n", "X_species = [f\"{species[2].symbol}1-\" for species in list_of_species]\n", "\n", "# Print out the first few entries from our search space\n", "print(f\"Number of compositions: {len(flat_list)}\")\n", - "print(\"Each list entry looks like this:\\n elements, oxidation states, stoichiometries\")\n", + "print(\n", + " \"Each list entry looks like this:\\n elements, oxidation states, stoichiometries\"\n", + ")\n", "for i in flat_list[:5]:\n", " print(i[0], i[1], i[2])" ] diff --git a/examples/Practical_tutorial/Combinations_practical.ipynb b/examples/Practical_tutorial/Combinations_practical.ipynb index 65fa9bfd..ad9d2a4d 100644 --- a/examples/Practical_tutorial/Combinations_practical.ipynb +++ b/examples/Practical_tutorial/Combinations_practical.ipynb @@ -221,7 +221,9 @@ " for ox_c in smact.Element(ele_c).oxidation_states:\n", " ion_count += 1\n", "\n", - " print(f\"{ele_a} {ox_a} \\t {ele_b} {ox_b} \\t {ele_c} {ox_c}\")\n", + " print(\n", + " f\"{ele_a} {ox_a} \\t {ele_b} {ox_b} \\t {ele_c} {ox_c}\"\n", + " )\n", "\n", "# Prints the total number of combinations found and the time taken to run.\n", "print(f\"Number of combinations = {ion_count}\")\n", @@ -275,7 +277,9 @@ " for ox_c in smact.Element(ele_c).oxidation_states:\n", "\n", " # Checks if the combination is charge neutral before 
printing it out! #\n", - " cn_e, cn_r = neutral_ratios([ox_a, ox_b, ox_c], threshold=1)\n", + " cn_e, cn_r = neutral_ratios(\n", + " [ox_a, ox_b, ox_c], threshold=1\n", + " )\n", " if cn_e:\n", " charge_neutral_count += 1\n", " print(f\"{ele_a} \\t {ele_b} \\t {ele_c}\")\n", diff --git a/examples/Practical_tutorial/Electronic/scan_energies.py b/examples/Practical_tutorial/Electronic/scan_energies.py index 2db6e475..ae7f8f62 100644 --- a/examples/Practical_tutorial/Electronic/scan_energies.py +++ b/examples/Practical_tutorial/Electronic/scan_energies.py @@ -58,7 +58,10 @@ EA = float(inp[2]) IP = float(inp[3]) if Eg > 2.0: - if EA >= options.EA - window * 0.5 and EA <= options.EA + window * 0.5: + if ( + EA >= options.EA - window * 0.5 + and EA <= options.EA + window * 0.5 + ): ETL.append(inp[0]) if Eg < options.gap: conducting_ETL.append(inp[0]) diff --git a/examples/Practical_tutorial/Lattice/LatticeMatch.py b/examples/Practical_tutorial/Lattice/LatticeMatch.py index 1115928b..90375c93 100644 --- a/examples/Practical_tutorial/Lattice/LatticeMatch.py +++ b/examples/Practical_tutorial/Lattice/LatticeMatch.py @@ -301,14 +301,22 @@ def surface_ratios(surface_a, surface_b, threshold=0.05, limit=5): if index_a != (0, 0, 0): vec1, vec2 = surface_vectors(xtalA, index_a) r_vec1, r_vec2 = reduce_vectors(vec1, vec2) - surface_vector_1 = (length(r_vec1), length(r_vec2), angle(r_vec1, r_vec2)) + surface_vector_1 = ( + length(r_vec1), + length(r_vec2), + angle(r_vec1, r_vec2), + ) # Set the values for material B indices_b = list(itertools.product([0, 1], repeat=3)) for index_b in indices_b: if index_b != (0, 0, 0): vec1, vec2 = surface_vectors(xtalB, index_b) r_vec1, r_vec2 = reduce_vectors(vec1, vec2) - surface_vector_2 = (length(r_vec1), length(r_vec2), angle(r_vec1, r_vec2)) + surface_vector_2 = ( + length(r_vec1), + length(r_vec2), + angle(r_vec1, r_vec2), + ) epitaxy, a, b, strains = surface_ratios( surface_vector_1, surface_vector_2, @@ -336,7 +344,9 @@ def 
surface_ratios(surface_a, surface_b, threshold=0.05, limit=5): print("Surface super-cell vector: ", surface_super_cell_b) print("------ ------ ------ ------ ------") else: - new = Pair(material1, material2, index_a, index_b, a, b, strains) + new = Pair( + material1, material2, index_a, index_b, a, b, strains + ) isnew = True if len(matched_pairs) == 0 and new.strains[2] == 0.0: matched_pairs.append(new) diff --git a/examples/Practical_tutorial/Site/LatticeSite.py b/examples/Practical_tutorial/Site/LatticeSite.py index 57d9d7f6..1536e436 100644 --- a/examples/Practical_tutorial/Site/LatticeSite.py +++ b/examples/Practical_tutorial/Site/LatticeSite.py @@ -276,16 +276,27 @@ def surface_ratios(surface_a, surface_b, threshold=0.05, limit=5): if index_a != (0, 0, 0): vec1, vec2 = surface_vectors(xtalA, index_a) r_vec1, r_vec2 = reduce_vectors(vec1, vec2) - surface_vector_1 = (length(r_vec1), length(r_vec2), angle(r_vec1, r_vec2)) + surface_vector_1 = ( + length(r_vec1), + length(r_vec2), + angle(r_vec1, r_vec2), + ) # Set the values for material B indices_b = list(itertools.product([0, 1], repeat=3)) for index_b in indices_b: if index_b != (0, 0, 0): vec1, vec2 = surface_vectors(xtalB, index_b) r_vec1, r_vec2 = reduce_vectors(vec1, vec2) - surface_vector_2 = (length(r_vec1), length(r_vec2), angle(r_vec1, r_vec2)) + surface_vector_2 = ( + length(r_vec1), + length(r_vec2), + angle(r_vec1, r_vec2), + ) epitaxy, a, b, strains = surface_ratios( - surface_vector_1, surface_vector_2, threshold=options.strain, limit=5 + surface_vector_1, + surface_vector_2, + threshold=options.strain, + limit=5, ) if epitaxy: if options.verbose: @@ -307,7 +318,9 @@ def surface_ratios(surface_a, surface_b, threshold=0.05, limit=5): print("Surface super-cell vector: ", surface_super_cell_b) print("------ ------ ------ ------ ------") else: - new = Pair(material1, material2, index_a, index_b, a, b, strains, 0.0) + new = Pair( + material1, material2, index_a, index_b, a, b, strains, 0.0 + ) isnew = 
True if len(matched_pairs) == 0 and new.strains[2] == 0.0: matched_pairs.append(new) diff --git a/examples/Practical_tutorial/Site/SiteMatch.py b/examples/Practical_tutorial/Site/SiteMatch.py index df67477a..91a2a3ce 100755 --- a/examples/Practical_tutorial/Site/SiteMatch.py +++ b/examples/Practical_tutorial/Site/SiteMatch.py @@ -47,7 +47,9 @@ def find_max_csl(surfs_1, surfs_2, multiplicity1, multiplicity2): for i in np.arange(0, 1, 0.1): for j in np.arange(0, 1, 0.1): t_surf = translate(surf_2_super, [i, j]) - csl_values.append(csl(surf_1_super, t_surf, multiplicity1)) + csl_values.append( + csl(surf_1_super, t_surf, multiplicity1) + ) return max(csl_values) diff --git a/examples/Practical_tutorial/Site/csl.py b/examples/Practical_tutorial/Site/csl.py index 737a15e1..38720db7 100755 --- a/examples/Practical_tutorial/Site/csl.py +++ b/examples/Practical_tutorial/Site/csl.py @@ -47,7 +47,9 @@ def find_max_csl(surfs_1, surfs_2, multiplicity1, multiplicity2): for i in np.arange(0, 1, 0.1): for j in np.arange(0, 1, 0.1): t_surf = translate(surf_2_super, [i, j]) - csl_values.append(csl(surf_1_super, t_surf, multiplicity1)) + csl_values.append( + csl(surf_1_super, t_surf, multiplicity1) + ) return max(csl_values) diff --git a/examples/Practical_tutorial/Site/surface_points.py b/examples/Practical_tutorial/Site/surface_points.py index 0a4f94e3..5f53d5dd 100644 --- a/examples/Practical_tutorial/Site/surface_points.py +++ b/examples/Practical_tutorial/Site/surface_points.py @@ -156,10 +156,22 @@ def WO3(miller): surfaces = {} surfaces["100"] = ( ([0.0, 0.25], [0, 0.75], [0.5, 0.25], [0.5, 0.75]), - ([0.25, 0.2], [0.5, 0.25], [0.5, 0.75], [0.7, 0.7], [0.5, 1.0], [1.0, 1.0]), + ( + [0.25, 0.2], + [0.5, 0.25], + [0.5, 0.75], + [0.7, 0.7], + [0.5, 1.0], + [1.0, 1.0], + ), (), ) - surfaces["110"] = ([0.5, 0.3], [0.75, 0.3], [1.0, 0.8], [0.75, 0.80]), () + surfaces["110"] = ( + [0.5, 0.3], + [0.75, 0.3], + [1.0, 0.8], + [0.75, 0.80], + ), () return surfaces[miller] else: print( 
@@ -174,7 +186,11 @@ def perovskite(miller): exists = ["100", "110", "112"] if miller in exists: surfaces = {} - surfaces["100"] = ([0, 0], [0.5, 0.5]), ([0.5, 0], [0, 0.5], [0.5, 0.5]), () + surfaces["100"] = ( + ([0, 0], [0.5, 0.5]), + ([0.5, 0], [0, 0.5], [0.5, 0.5]), + (), + ) surfaces["112"] = ([0.0, 0.0], [0.5, 0.0], [0.5, 0.5]), () surfaces["110"] = ([0.0, 0.0], [0.0, 0.5], [0.75, 0.5]), () return surfaces[miller] @@ -191,7 +207,11 @@ def CH3NH3PbI3(miller): exists = ["100", "110", "112"] if miller in exists: surfaces = {} - surfaces["100"] = ([0, 0], [0.5, 0.5]), ([0.5, 0], [0, 0.5], [0.5, 0.5]), () + surfaces["100"] = ( + ([0, 0], [0.5, 0.5]), + ([0.5, 0], [0, 0.5], [0.5, 0.5]), + (), + ) surfaces["112"] = ([0.0, 0.0], [0.5, 0.0], [0.5, 0.5]), () surfaces["110"] = ([0.0, 0.0], [0.0, 0.5], [0.75, 0.5]), () return surfaces[miller] @@ -208,7 +228,11 @@ def SrTiO3(miller): exists = ["100", "110", "112"] if miller in exists: surfaces = {} - surfaces["100"] = ([0, 0], [0.5, 0.5]), ([0.5, 0], [0, 0.5], [0.5, 0.5]), () + surfaces["100"] = ( + ([0, 0], [0.5, 0.5]), + ([0.5, 0], [0, 0.5], [0.5, 0.5]), + (), + ) surfaces["112"] = ([0.0, 0.0], [0.5, 0.0], [0.5, 0.5]), () surfaces["110"] = ([0.0, 0.0], [0.0, 0.5], [0.75, 0.5]), () return surfaces[miller] @@ -226,7 +250,12 @@ def zincblende(miller): if miller in exists: surfaces = {} surfaces["100"] = ([0.75, 0.25], [0.0, 0.0]), () - surfaces["110"] = ([0.25, 0.9], [0.25, 0.4], [0.5, 0.7], [0.5, 0.2]), () + surfaces["110"] = ( + [0.25, 0.9], + [0.25, 0.4], + [0.5, 0.7], + [0.5, 0.2], + ), () return surfaces[miller] else: print( @@ -243,8 +272,18 @@ def CuIz(miller): surfaces = {} surfaces["100"] = ([0.75, 0.25], [0.0, 0.0]), () surfaces["001"] = ([0.75, 0.25], [0.0, 0.0]), () - surfaces["110"] = ([0.25, 0.9], [0.25, 0.4], [0.5, 0.7], [0.5, 0.2]), () - surfaces["011"] = ([0.25, 0.9], [0.25, 0.4], [0.5, 0.7], [0.5, 0.2]), () + surfaces["110"] = ( + [0.25, 0.9], + [0.25, 0.4], + [0.5, 0.7], + [0.5, 0.2], + ), () + 
surfaces["011"] = ( + [0.25, 0.9], + [0.25, 0.4], + [0.5, 0.7], + [0.5, 0.2], + ), () return surfaces[miller] else: print( @@ -296,7 +335,14 @@ def bixybite(miller): if miller in exists: surfaces = {} surfaces["100"] = ( - ([0.2, 0.9], [0.6, 0.9], [0.9, 0.6], [0.4, 0.4], [0.9, 0.4], [0.7, 0.1]), + ( + [0.2, 0.9], + [0.6, 0.9], + [0.9, 0.6], + [0.4, 0.4], + [0.9, 0.4], + [0.7, 0.1], + ), ( [0.2, 0.2], [0.2, 0.7], @@ -379,7 +425,12 @@ def wurtzite(miller): surfaces = {} surfaces["100"] = ([0, 0], [0, 0.37]), () surfaces["010"] = ([0, 0], [0, 0.37]), () - surfaces["110"] = ([0, 0.8], [0.37, 0.8], [0.5, 0.17], [0.87, 0.17]), () + surfaces["110"] = ( + [0, 0.8], + [0.37, 0.8], + [0.5, 0.17], + [0.87, 0.17], + ), () return surfaces[miller] else: print( @@ -396,7 +447,12 @@ def GaN(miller): surfaces = {} surfaces["100"] = ([0, 0], [0, 0.37]), () surfaces["010"] = ([0, 0], [0, 0.37]), () - surfaces["110"] = ([0, 0.8], [0.37, 0.8], [0.5, 0.17], [0.87, 0.17]), () + surfaces["110"] = ( + [0, 0.8], + [0.37, 0.8], + [0.5, 0.17], + [0.87, 0.17], + ), () return surfaces[miller] else: print( @@ -413,7 +469,12 @@ def SiC(miller): surfaces = {} surfaces["100"] = ([0, 0], [0, 0.37]), () surfaces["010"] = ([0, 0], [0, 0.37]), () - surfaces["110"] = ([0, 0.8], [0.37, 0.8], [0.5, 0.17], [0.87, 0.17]), () + surfaces["110"] = ( + [0, 0.8], + [0.37, 0.8], + [0.5, 0.17], + [0.87, 0.17], + ), () return surfaces[miller] else: print( diff --git a/examples/Simple_wrappers/band_gap_simple.py b/examples/Simple_wrappers/band_gap_simple.py index d7f68bd8..28921865 100644 --- a/examples/Simple_wrappers/band_gap_simple.py +++ b/examples/Simple_wrappers/band_gap_simple.py @@ -24,9 +24,15 @@ parser = argparse.ArgumentParser( description="Compound band gap estimates from elemental data." 
) - parser.add_argument("-a", "--anion", type=str, help="Element symbol for anion.") - parser.add_argument("-c", "--cation", type=str, help="Element symbol for cation.") - parser.add_argument("-d", "--distance", type=float, help="Internuclear separation.") + parser.add_argument( + "-a", "--anion", type=str, help="Element symbol for anion." + ) + parser.add_argument( + "-c", "--cation", type=str, help="Element symbol for cation." + ) + parser.add_argument( + "-d", "--distance", type=float, help="Internuclear separation." + ) parser.add_argument( "-v", "--verbose", action="store_true", help="More Verbose output." ) diff --git a/examples/Solar_oxides/SolarOxides.ipynb b/examples/Solar_oxides/SolarOxides.ipynb index 6f305f3c..3f67b56b 100644 --- a/examples/Solar_oxides/SolarOxides.ipynb +++ b/examples/Solar_oxides/SolarOxides.ipynb @@ -92,14 +92,18 @@ "\n", " # For each set of species (in oxidation states) apply both SMACT tests\n", " for ox_a, ox_b, ox_c in product(\n", - " els[0].oxidation_states, els[1].oxidation_states, els[2].oxidation_states\n", + " els[0].oxidation_states,\n", + " els[1].oxidation_states,\n", + " els[2].oxidation_states,\n", " ):\n", " ox_states = [ox_a, ox_b, ox_c, -2]\n", " # Test for charge balance\n", " cn_e, cn_r = smact.neutral_ratios(ox_states, threshold=8)\n", " if cn_e:\n", " # Electronegativity test\n", - " electroneg_OK = screening.pauling_test(ox_states, electronegativities)\n", + " electroneg_OK = screening.pauling_test(\n", + " ox_states, electronegativities\n", + " )\n", " if electroneg_OK:\n", " compound = tuple([elements, cn_r[0]])\n", " all_compounds.append(compound)\n", @@ -236,7 +240,9 @@ } ], "source": [ - "new_data = pd.DataFrame(unique_pretty_formulas).rename(columns={0: \"pretty_formula\"})\n", + "new_data = pd.DataFrame(unique_pretty_formulas).rename(\n", + " columns={0: \"pretty_formula\"}\n", + ")\n", "new_data = new_data.drop_duplicates(subset=\"pretty_formula\")\n", "new_data.describe()" ] diff --git 
a/examples/Structure_Prediction/Li-Garnet_Generator.ipynb b/examples/Structure_Prediction/Li-Garnet_Generator.ipynb index 8af11433..0b64e0a2 100644 --- a/examples/Structure_Prediction/Li-Garnet_Generator.ipynb +++ b/examples/Structure_Prediction/Li-Garnet_Generator.ipynb @@ -12,7 +12,13 @@ "from datetime import datetime\n", "\n", "import smact\n", - "from smact import Element, Species, element_dictionary, neutral_ratios, ordered_elements\n", + "from smact import (\n", + " Element,\n", + " Species,\n", + " element_dictionary,\n", + " neutral_ratios,\n", + " ordered_elements,\n", + ")\n", "from smact.screening import pauling_test, smact_filter" ] }, @@ -108,7 +114,9 @@ "source": [ "# Generate the A-B-C-D combinations\n", "\n", - "ABCD_pairs = [(x, y, z, a) for x in A_els for y in B_els for z in C_els for a in D_els]\n", + "ABCD_pairs = [\n", + " (x, y, z, a) for x in A_els for y in B_els for z in C_els for a in D_els\n", + "]\n", "\n", "# Prove to ourselves that we have all unique chemical systems\n", "print(f\"We have generated {len(ABCD_pairs)} potential compounds\")\n", @@ -210,7 +218,9 @@ "# Flatten the list of lists\n", "flat_list = [item for sublist in result for item in sublist]\n", "print(f\"Number of compositions: --> {len(flat_list)} <--\")\n", - "print(\"Each list entry looks like this:\\n elements, oxidation states, stoichiometries\")\n", + "print(\n", + " \"Each list entry looks like this:\\n elements, oxidation states, stoichiometries\"\n", + ")\n", "for i in flat_list[:5]:\n", " print(i)" ] @@ -275,7 +285,8 @@ " sus_factor = 0\n", " for i in Composition(comp).elements:\n", " sus_factor += (\n", - " Composition(comp).get_wt_fraction(i) * smact.Element(i.symbol).HHI_r\n", + " Composition(comp).get_wt_fraction(i)\n", + " * smact.Element(i.symbol).HHI_r\n", " )\n", " return sus_factor\n", "\n", diff --git a/examples/Structure_Prediction/Li-Garnets_SP-Pym-new.ipynb b/examples/Structure_Prediction/Li-Garnets_SP-Pym-new.ipynb index 7d99f58f..bc74a81a 100644 
--- a/examples/Structure_Prediction/Li-Garnets_SP-Pym-new.ipynb +++ b/examples/Structure_Prediction/Li-Garnets_SP-Pym-new.ipynb @@ -520,7 +520,9 @@ "parents_list = []\n", "probs_list = []\n", "for test_specs in test_specs_list:\n", - " predictions = list(SP.predict_structs(test_specs, thresh=10e-4, include_same=False))\n", + " predictions = list(\n", + " SP.predict_structs(test_specs, thresh=10e-4, include_same=False)\n", + " )\n", " predictions.sort(key=itemgetter(1), reverse=True)\n", " parents = [x[2].composition() for x in predictions]\n", " probs = [x[1] for x in predictions]\n", @@ -958,7 +960,9 @@ "import seaborn as sns\n", "\n", "plt.figure(figsize=(8, 6))\n", - "ax1 = sns.histplot(data=results, x=\"probability\", hue=\"In DB?\", multiple=\"stack\")\n", + "ax1 = sns.histplot(\n", + " data=results, x=\"probability\", hue=\"In DB?\", multiple=\"stack\"\n", + ")\n", "# ax2=sns.histplot(new_series, label=\"New Garnets\")\n", "# plt.savefig(\"Prediction_Probability_Distribution_pym.png\")" ] diff --git a/pyproject.toml b/pyproject.toml index ef964667..062dedf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,11 @@ [build-system] # Minimum requirements for the build system to execute. requires = ["setuptools"] # PEP 508 specifications. +build-backend = "setuptools.build_meta" [tool.semantic_release] version_variable = "setup.py:__version__" -version_source = "tag" \ No newline at end of file +version_source = "tag" + +[tool.black] +line-length = 79 \ No newline at end of file diff --git a/setup.py b/setup.py index 30bb7cfe..f4fdbeb8 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,9 @@ #!/usr/bin/env python __author__ = "Daniel W. Davies" -__copyright__ = "Copyright Daniel W. Davies, Adam J. Jackson, Keith T. Butler (2019)" +__copyright__ = ( + "Copyright Daniel W. Davies, Adam J. Jackson, Keith T. Butler (2019)" +) __version__ = "2.4.2" __maintainer__ = "Anthony O. 
Onwuli" __email__ = "anthony.onwuli16@imperial.ac.uk" @@ -54,10 +56,16 @@ "pathos", ], classifiers=[ - "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Development Status :: 5 - Production/Stable", "Intended Audience :: Science/Research", "Operating System :: OS Independent", + "License :: OSI Approved :: MIT License", "Topic :: Scientific/Engineering", + "Topic :: Chemistry", ], + python_requires=">=3.8", ) diff --git a/smact/__init__.py b/smact/__init__.py index 0e71f0f5..b1a51020 100644 --- a/smact/__init__.py +++ b/smact/__init__.py @@ -9,6 +9,7 @@ from math import gcd from operator import mul as multiply from os import path +from typing import Iterable, List, Optional, Sequence, Tuple, Union import pandas as pd @@ -76,7 +77,7 @@ class Element: """ - def __init__(self, symbol): + def __init__(self, symbol: str): """Initialise Element class Args: @@ -130,7 +131,10 @@ def __init__(self, symbol): ("mass", dataset["Mass"]), ("name", dataset["Name"]), ("number", dataset["Z"]), - ("oxidation_states", data_loader.lookup_element_oxidation_states(symbol)), + ( + "oxidation_states", + data_loader.lookup_element_oxidation_states(symbol), + ), ( "oxidation_states_icsd", data_loader.lookup_element_oxidation_states_icsd(symbol), @@ -197,7 +201,13 @@ class Species(Element): """ - def __init__(self, symbol, oxidation, coordination=4, radii_source="shannon"): + def __init__( + self, + symbol: str, + oxidation: int, + coordination: int = 4, + radii_source: str = "shannon", + ): Element.__init__(self, symbol) self.oxidation = oxidation @@ -209,21 +219,28 @@ def __init__(self, symbol, oxidation, coordination=4, radii_source="shannon"): if radii_source == "shannon": - shannon_data = data_loader.lookup_element_shannon_radius_data(symbol) + shannon_data = data_loader.lookup_element_shannon_radius_data( + symbol + ) elif 
radii_source == "extended": - shannon_data = data_loader.lookup_element_shannon_radius_data_extendedML( - symbol + shannon_data = ( + data_loader.lookup_element_shannon_radius_data_extendedML( + symbol + ) ) else: - print("Data source not recognised. Please select 'shannon' or 'extended'. ") + print( + "Data source not recognised. Please select 'shannon' or 'extended'. " + ) if shannon_data: for dataset in shannon_data: if ( dataset["charge"] == oxidation - and str(coordination) == dataset["coordination"].split("_")[0] + and str(coordination) + == dataset["coordination"].split("_")[0] ): self.shannon_radius = dataset["crystal_radius"] @@ -234,7 +251,8 @@ def __init__(self, symbol, oxidation, coordination=4, radii_source="shannon"): for dataset in shannon_data: if ( dataset["charge"] == oxidation - and str(coordination) == dataset["coordination"].split("_")[0] + and str(coordination) + == dataset["coordination"].split("_")[0] ): self.ionic_radius = dataset["ionic_radius"] @@ -247,7 +265,9 @@ def __init__(self, symbol, oxidation, coordination=4, radii_source="shannon"): shannon_data_df = pd.DataFrame(shannon_data) # Get the rows corresponding to the oxidation state of the species - charge_rows = shannon_data_df.loc[shannon_data_df["charge"] == oxidation] + charge_rows = shannon_data_df.loc[ + shannon_data_df["charge"] == oxidation + ] # Get the mean self.average_shannon_radius = charge_rows["crystal_radius"].mean() @@ -266,7 +286,7 @@ def __init__(self, symbol, oxidation, coordination=4, radii_source="shannon"): self.SSE_2015 = None -def ordered_elements(x, y): +def ordered_elements(x: int, y: int) -> List[str]: """ Return a list of element symbols, ordered by proton number in the range x -> y Args: @@ -288,7 +308,7 @@ def ordered_elements(x, y): return ordered_elements -def element_dictionary(elements=None): +def element_dictionary(elements: Optional[Iterable[str]] = None): """ Create a dictionary of initialised smact.Element objects @@ -307,8 +327,9 @@ def 
element_dictionary(elements=None): return {symbol: Element(symbol) for symbol in elements} -def are_eq(A, B, tolerance=1e-4): +def are_eq(A: list, B: list, tolerance: float = 1e-4): """Check two arrays for tolerance [1,2,3]==[1,2,3]; but [1,3,2]!=[1,2,3] + Args: A, B (lists): 1-D list of values for approximate equality comparison tolerance: numerical precision for equality condition @@ -348,7 +369,7 @@ def lattices_are_same(lattice1, lattice2, tolerance=1e-4): return lattices_are_same -def _gcd_recursive(*args): +def _gcd_recursive(*args: Iterable[int]): """ Get the greatest common denominator among any number of ints """ @@ -358,7 +379,7 @@ def _gcd_recursive(*args): return gcd(args[0], _gcd_recursive(*args[1:])) -def _isneutral(oxidations, stoichs): +def _isneutral(oxidations: Tuple[int, ...], stoichs: Tuple[int, ...]): """ Check if set of oxidation states is neutral in given stoichiometry @@ -369,7 +390,11 @@ def _isneutral(oxidations, stoichs): return 0 == sum(map(multiply, oxidations, stoichs)) -def neutral_ratios_iter(oxidations, stoichs=False, threshold=5): +def neutral_ratios_iter( + oxidations: List[int], + stoichs: Union[bool, List[int]] = False, + threshold: Optional[int] = 5, +): """ Iterator for charge-neutral stoichiometries @@ -399,7 +424,9 @@ def neutral_ratios_iter(oxidations, stoichs=False, threshold=5): ) -def neutral_ratios(oxidations, stoichs=False, threshold=5): +def neutral_ratios( + oxidations: List[int], stoichs: Union[bool, List[int]] = False, threshold=5 +): """ Get a list of charge-neutral compounds @@ -431,7 +458,10 @@ def neutral_ratios(oxidations, stoichs=False, threshold=5): states which yield a charge-neutral structure """ allowed_ratios = [ - x for x in neutral_ratios_iter(oxidations, stoichs=stoichs, threshold=threshold) + x + for x in neutral_ratios_iter( + oxidations, stoichs=stoichs, threshold=threshold + ) ] return (len(allowed_ratios) > 0, allowed_ratios) diff --git a/smact/builder.py b/smact/builder.py index 
e7477cf8..d79695bc 100644 --- a/smact/builder.py +++ b/smact/builder.py @@ -12,7 +12,9 @@ from smact.lattice import Lattice, Site -def cubic_perovskite(species, cell_par=[6, 6, 6, 90, 90, 90], repetitions=[1, 1, 1]): +def cubic_perovskite( + species, cell_par=[6, 6, 6, 90, 90, 90], repetitions=[1, 1, 1] +): """ Build a perovskite cell using the crystal function in ASE. diff --git a/smact/data_loader.py b/smact/data_loader.py index 577e81bd..79631a8e 100644 --- a/smact/data_loader.py +++ b/smact/data_loader.py @@ -100,7 +100,8 @@ def lookup_element_oxidation_states(symbol, copy=True): else: if _print_warnings: print( - "WARNING: Oxidation states for element {} " "not found.".format(symbol) + "WARNING: Oxidation states for element {} " + "not found.".format(symbol) ) return None @@ -144,13 +145,16 @@ def lookup_element_oxidation_states_icsd(symbol, copy=True): # _el_ox_states_icsd stores lists -> if copy is set, make an implicit # deep copy. The elements of the lists are integers, which are # "value types" in Python. - return [oxidationState for oxidationState in _el_ox_states_icsd[symbol]] + return [ + oxidationState for oxidationState in _el_ox_states_icsd[symbol] + ] else: return _el_ox_states_icsd[symbol] else: if _print_warnings: print( - "WARNING: Oxidation states for element {}" "not found.".format(symbol) + "WARNING: Oxidation states for element {}" + "not found.".format(symbol) ) return None @@ -196,13 +200,16 @@ def lookup_element_oxidation_states_sp(symbol, copy=True): # deep copy. The elements of the lists are integers, which are # "value types" in Python. 
- return [oxidationState for oxidationState in _el_ox_states_sp[symbol]] + return [ + oxidationState for oxidationState in _el_ox_states_sp[symbol] + ] else: return _el_ox_states_sp[symbol] else: if _print_warnings: print( - "WARNING: Oxidation states for element {} " "not found.".format(symbol) + "WARNING: Oxidation states for element {} " + "not found.".format(symbol) ) return None @@ -248,13 +255,16 @@ def lookup_element_oxidation_states_wiki(symbol, copy=True): # deep copy. The elements of the lists are integers, which are # "value types" in Python. - return [oxidationState for oxidationState in _el_ox_states_wiki[symbol]] + return [ + oxidationState for oxidationState in _el_ox_states_wiki[symbol] + ] else: return _el_ox_states_wiki[symbol] else: if _print_warnings: print( - "WARNING: Oxidation states for element {} " "not found.".format(symbol) + "WARNING: Oxidation states for element {} " + "not found.".format(symbol) ) return None @@ -290,13 +300,18 @@ def lookup_element_hhis(symbol): if line[0] != "#": items = line.split() - _element_hhis[items[0]] = (float(items[1]), float(items[2])) + _element_hhis[items[0]] = ( + float(items[1]), + float(items[2]), + ) if symbol in _element_hhis: return _element_hhis[symbol] else: if _print_warnings: - print("WARNING: HHI data for element " "{} not found.".format(symbol)) + print( + "WARNING: HHI data for element " "{} not found.".format(symbol) + ) return None @@ -344,13 +359,17 @@ def lookup_element_data(symbol, copy=True): "ion_pot", "dipol", ) - for items in _get_data_rows(os.path.join(data_directory, "element_data.txt")): + for items in _get_data_rows( + os.path.join(data_directory, "element_data.txt") + ): # First two columns are strings and should be left intact # Everything else is numerical and should be cast to a float # or, if not clearly a number, to None clean_items = items[0:2] + list(map(float_or_None, items[2:])) - _element_data.update({items[0]: dict(list(zip(keys, clean_items)))}) + 
_element_data.update( + {items[0]: dict(list(zip(keys, clean_items)))} + ) if symbol in _element_data: if copy: @@ -365,7 +384,9 @@ def lookup_element_data(symbol, copy=True): return _element_data[symbol] else: if _print_warnings: - print("WARNING: Elemental data for {}" " not found.".format(symbol)) + print( + "WARNING: Elemental data for {}" " not found.".format(symbol) + ) print(_element_data) return None @@ -450,7 +471,9 @@ def lookup_element_shannon_radius_data(symbol, copy=True): # function on each element. # The dictionary values are all Python "value types", so # nothing further is required to make a deep copy. - return [item.copy() for item in _element_shannon_radii_data[symbol]] + return [ + item.copy() for item in _element_shannon_radii_data[symbol] + ] else: return _element_shannon_radii_data[symbol] else: @@ -552,7 +575,8 @@ def lookup_element_shannon_radius_data_extendedML(symbol, copy=True): # The dictionary values are all Python "value types", so # nothing further is required to make a deep copy. 
return [ - item.copy() for item in _element_shannon_radii_data_extendedML[symbol] + item.copy() + for item in _element_shannon_radii_data_extendedML[symbol] ] else: return _element_shannon_radii_data_extendedML[symbol] diff --git a/smact/dopant_prediction/__init__.py b/smact/dopant_prediction/__init__.py index e69de29b..9f21268f 100644 --- a/smact/dopant_prediction/__init__.py +++ b/smact/dopant_prediction/__init__.py @@ -0,0 +1,16 @@ +"""Minimalist dopant prediction tools for materials design.""" + +import logging + +__author__ = "Chloe (Jiwoo) Lee (이지우)" +__credits__ = { + "WMD Group", + "Imperial College London", + "Anthony Onwuli", + "Keith Butler", + "Aron Walsh", + "Chloe (Jiwoo) Lee (이지우)", +} +__status__ = "Development" + +logger = logging.getLogger(__name__) diff --git a/smact/dopant_prediction/doper.py b/smact/dopant_prediction/doper.py index 05a06306..5c7e867a 100644 --- a/smact/dopant_prediction/doper.py +++ b/smact/dopant_prediction/doper.py @@ -6,11 +6,11 @@ # Now 'Doper' can generate possible n-type p-type dopants for multicomponent materials (i.e. Ternary, Quaternary etc). # Can plot the result of doping search within a single step -"""ex) test= Doper(('Cu1+','Zn2+','Ge4+','S2-')) - test.get_dopants(num_dopants = 10, plot_heatmap = True)""" +# """ex) test= Doper(('Cu1+','Zn2+','Ge4+','S2-')) +# test.get_dopants(num_dopants = 10, plot_heatmap = True)""" -from typing import Tuple +from typing import List, Tuple from pymatgen.util import plotting @@ -22,13 +22,30 @@ class Doper: """ A class to search for n & p type dopants Methods: get_dopants, plot_dopants + + Attributes: + original_species: A tuple which describes the constituent species of a material. For example: + + >>> test= Doper(('Cu1+','Zn2+','Ge4+','S2-')) + >>> test.original_species + ('Cu1+','Zn2+','Ge4+','S2-') + """ def __init__(self, original_species: Tuple[str, ...]): + """ + Intialise the `Doper` class with a tuple of species + + Args: + original_species: See :class:`~.Doper`. 
+ + """ self.original_species = original_species - def _get_cation_dopants(self, element_objects, cations): + def _get_cation_dopants( + self, element_objects: List[smact.Element], cations: List[str] + ): poss_n_type_cat = [] poss_p_type_cat = [] @@ -50,7 +67,9 @@ def _get_cation_dopants(self, element_objects, cations): return poss_n_type_cat, poss_p_type_cat - def _get_anion_dopants(self, element_objects, anions): + def _get_anion_dopants( + self, element_objects: List[smact.Element], anions: List[str] + ): poss_n_type_an = [] poss_p_type_an = [] @@ -70,7 +89,7 @@ def _get_anion_dopants(self, element_objects, anions): ) return poss_n_type_an, poss_p_type_an - def _plot_dopants(self, results): + def _plot_dopants(self, results: dict): """ Uses pymatgen plotting utilities to plot the results of doping search """ @@ -85,20 +104,29 @@ def _plot_dopants(self, results): def get_dopants( self, - num_dopants=5, - plot_heatmap=False, + num_dopants: int = 5, + plot_heatmap: bool = False, ) -> dict: """ Args: - ex) get_dopants(('Ti4+','O2-')) - - original_species (tuple(str)) = ('Cd2+', 'O2-') - num_dopants (int) = The number of suggestions to return for n- and p-type dopants. - + num_dopants (int): The number of suggestions to return for n- and p-type dopants. + plot_heatmap (bool): If True, the results of the doping search are plotted as heatmaps Returns: (dict): Dopant suggestions, given as a dictionary with keys "n_type_cation", "p_type_cation", "n_type_anion", "p_type_anion". 
+ + Examples: + >>> test = Doper(('Ti4+','O2-')) + >>> print(test.get_dopants(num_dopants=2)) + {'n-type cation substitutions': [('Ta5+', 8.790371775858281e-05), + ('Nb5+', 7.830035204694342e-05)], + 'p-type cation substitutions': [('Na1+', 0.00010060400812977031), + ('Zn2+', 8.56373996146833e-05)], + 'n-type anion substitutions': [('F1-', 0.01508116810515677), + ('Cl1-', 0.004737202729901607)], + 'p-type anion substitutions': [('N3-', 0.0014663800608945628), + ('C4-', 9.31310255126729e-08)]} """ cations = [] diff --git a/smact/oxidation_states.py b/smact/oxidation_states.py index 52e78a03..cec08a61 100644 --- a/smact/oxidation_states.py +++ b/smact/oxidation_states.py @@ -8,6 +8,7 @@ import json from os import path +from typing import Dict, Optional, Tuple from numpy import mean from pymatgen.core import Structure @@ -22,7 +23,9 @@ class Oxidation_state_probability_finder: to compute the likelihood of metal species existing in solids in the presence of certain anions. """ - def __init__(self, probability_table=None): + def __init__( + self, probability_table: Optional[Dict[Tuple[str, str], float]] = None + ): """ Args: probability_table (dict): Lookup table to get probabilities for anion-cation pairs. @@ -31,7 +34,9 @@ def __init__(self, probability_table=None): """ if probability_table == None: with open( - path.join(data_directory, "oxidation_state_probability_table.json") + path.join( + data_directory, "oxidation_state_probability_table.json" + ) ) as f: probability_data = json.load(f) # Put data into the required format @@ -49,7 +54,7 @@ def __init__(self, probability_table=None): self._included_cations = included_cations self._included_anions = included_anions - def _generate_lookup_key(self, species1, species2): + def _generate_lookup_key(self, species1: Species, species2: Species): """ Internal function to generate keys to lookup table. 
@@ -76,17 +81,17 @@ def _generate_lookup_key(self, species1, species2): an_key = "".join([anion.symbol, str(int(anion.oxidation))]) # Check that both the species are included in the probability table - if not all(elem in self._included_species for elem in [an_key, cat_key]): + if not all( + elem in self._included_species for elem in [an_key, cat_key] + ): raise NameError( - "One or both of [{}, {}] are not in the probability table.".format( - cat_key, an_key - ) + f"One or both of [{cat_key}, {an_key}] are not in the probability table." ) table_key = (an_key, cat_key) return table_key - def pair_probability(self, species1, species2): + def pair_probability(self, species1: Species, species2: Species) -> float: """ Get the anion-cation oxidation state probability for a provided pair of smact Species. i.e. :math:`P_{SA}=\\frac{N_{SX}}{N_{MX}}` in the original paper (DOI:10.1039/C8FD00032H). @@ -110,7 +115,9 @@ def get_included_species(self): """ return self._included_species - def compound_probability(self, structure, ignore_stoichiometry=True): + def compound_probability( + self, structure: Structure, ignore_stoichiometry: bool = True + ) -> float: """ calculate overall probability for structure or composition. @@ -131,7 +138,9 @@ def compound_probability(self, structure, ignore_stoichiometry=True): elif all(isinstance(i, pmgSpecies) for i in structure): structure = [Species(i.symbol, i.oxi_state) for i in structure] else: - raise TypeError("Input requires a list of SMACT or Pymatgen species.") + raise TypeError( + "Input requires a list of SMACT or Pymatgen species." 
+ ) elif type(structure) == Structure: species = structure.species if not all(isinstance(i, pmgSpecies) for i in species): @@ -156,6 +165,8 @@ def compound_probability(self, structure, ignore_stoichiometry=True): species_pairs = list(set(species_pairs)) # Do the maths - pair_probs = [self.pair_probability(pair[0], pair[1]) for pair in species_pairs] + pair_probs = [ + self.pair_probability(pair[0], pair[1]) for pair in species_pairs + ] compound_prob = mean(pair_probs) return compound_prob diff --git a/smact/properties.py b/smact/properties.py index f9b8e86e..722a624f 100644 --- a/smact/properties.py +++ b/smact/properties.py @@ -1,9 +1,11 @@ +from typing import List, Optional, Union + from numpy import product, sqrt import smact -def eneg_mulliken(element): +def eneg_mulliken(element: Union[smact.Element, str]) -> float: """Get Mulliken electronegativity from the IE and EA. Arguments: @@ -23,7 +25,12 @@ def eneg_mulliken(element): return mulliken -def band_gap_Harrison(anion, cation, verbose=False, distance=None): +def band_gap_Harrison( + anion: str, + cation: str, + verbose: bool = False, + distance: Optional[Union[float, str]] = None, +) -> float: """ Estimates the band gap from elemental data. @@ -76,7 +83,12 @@ def band_gap_Harrison(anion, cation, verbose=False, distance=None): return Band_gap -def compound_electroneg(verbose=False, elements=None, stoichs=None, source="Mulliken"): +def compound_electroneg( + verbose: bool = False, + elements: List[Union[str, smact.Element]] = None, + stoichs: List[Union[int, float]] = None, + source: str = "Mulliken", +) -> float: """Estimate electronegativity of compound from elemental data. 
Uses Mulliken electronegativity by default, which uses elemental @@ -123,7 +135,9 @@ def compound_electroneg(verbose=False, elements=None, stoichs=None, source="Mull elif source == "Pauling": elementlist = [(2.86 * el.pauling_eneg) for el in elementlist] else: - raise Exception(f"Electronegativity type '{source}'", "is not recognised") + raise Exception( + f"Electronegativity type '{source}'", "is not recognised" + ) # Print optional list of element electronegativities. # This may be a useful sanity check in case of a suspicious result. diff --git a/smact/screening.py b/smact/screening.py index 1ea5cc7f..3117a4ae 100644 --- a/smact/screening.py +++ b/smact/screening.py @@ -1,17 +1,27 @@ import itertools import warnings +from collections import namedtuple from itertools import combinations +from typing import Iterable, List, Optional, Tuple, Union from smact import Element, neutral_ratios +# Use named tuple to improve readability of smact_filter outputs +_allowed_compositions = namedtuple( + "Composition", ["element_symbols", "oxidation_states", "stoichiometries"] +) +_allowed_compositions_nonunique = namedtuple( + "Composition", ["element_symbols", "stoichiometries"] +) + def pauling_test( - oxidation_states, - electronegativities, - symbols=[], - repeat_anions=True, - repeat_cations=True, - threshold=0.0, + oxidation_states: List[int], + electronegativities: List[float], + symbols: List[str] = [], + repeat_anions: bool = True, + repeat_cations: bool = True, + threshold: float = 0.0, ): """Check if a combination of ions makes chemical sense, (i.e. positive ions should be of lower electronegativity). @@ -58,7 +68,12 @@ def pauling_test( return False -def _no_repeats(oxidation_states, symbols, repeat_anions=False, repeat_cations=False): +def _no_repeats( + oxidation_states: List[int], + symbols: List[str], + repeat_anions: bool = False, + repeat_cations: bool = False, +): """ Check if any anion or cation appears twice. 
@@ -92,7 +107,12 @@ def _no_repeats(oxidation_states, symbols, repeat_anions=False, repeat_cations=F def pauling_test_old( - ox, paul, symbols, repeat_anions=True, repeat_cations=True, threshold=0.0 + ox: List[int], + paul: List[float], + symbols: List[str], + repeat_anions: bool = True, + repeat_cations: bool = True, + threshold: float = 0.0, ): """Check if a combination of ions makes chemical sense, (i.e. positive ions should be of lower Pauling electronegativity). @@ -153,7 +173,7 @@ def pauling_test_old( return True -def eneg_states_test(ox_states, enegs): +def eneg_states_test(ox_states: List[int], enegs: List[float]): """ Internal function for checking electronegativity criterion @@ -173,7 +193,9 @@ def eneg_states_test(ox_states, enegs): anions, otherwise False """ - for ((ox1, eneg1), (ox2, eneg2)) in combinations(list(zip(ox_states, enegs)), 2): + for ((ox1, eneg1), (ox2, eneg2)) in combinations( + list(zip(ox_states, enegs)), 2 + ): if (ox1 > 0) and (ox2 < 0) and (eneg1 >= eneg2): return False elif (ox1 < 0) and (ox2 > 0) and (eneg1 <= eneg2): @@ -184,7 +206,9 @@ def eneg_states_test(ox_states, enegs): return True -def eneg_states_test_threshold(ox_states, enegs, threshold=0): +def eneg_states_test_threshold( + ox_states: List[int], enegs: List[float], threshold: Optional[float] = 0 +): """Internal function for checking electronegativity criterion This implementation is fast as it 'short-circuits' as soon as it @@ -208,7 +232,9 @@ def eneg_states_test_threshold(ox_states, enegs, threshold=0): anions, otherwise False """ - for ((ox1, eneg1), (ox2, eneg2)) in combinations(list(zip(ox_states, enegs)), 2): + for ((ox1, eneg1), (ox2, eneg2)) in combinations( + list(zip(ox_states, enegs)), 2 + ): if (ox1 > 0) and (ox2 < 0) and ((eneg1 - eneg2) > threshold): return False elif (ox1 < 0) and (ox2 > 0) and (eneg2 - eneg1) > threshold: @@ -217,7 +243,7 @@ def eneg_states_test_threshold(ox_states, enegs, threshold=0): return True -def 
eneg_states_test_alternate(ox_states, enegs): +def eneg_states_test_alternate(ox_states: List[int], enegs: List[float]): """Internal function for checking electronegativity criterion This implementation appears to be slightly slower than @@ -243,7 +269,10 @@ def eneg_states_test_alternate(ox_states, enegs): return min_cation_eneg > max_anion_eneg -def ml_rep_generator(composition, stoichs=None): +def ml_rep_generator( + composition: Union[List[Element], List[str]], + stoichs: Optional[List[int]] = None, +): """Function to take a composition of Elements and return a list of values between 0 and 1 that describes the composition, useful for machine learning. @@ -279,7 +308,12 @@ def ml_rep_generator(composition, stoichs=None): return norm -def smact_filter(els, threshold=8, species_unique=True, oxidation_states_set="default"): +def smact_filter( + els: Union[Tuple[Element], List[Element]], + threshold: int = 8, + species_unique: bool = True, + oxidation_states_set: str = "default", +) -> Union[List[Tuple[str, int, int]], List[Tuple[str, int]]]: """Function that applies the charge neutrality and electronegativity tests in one go for simple application in external scripts that wish to apply the general 'smact test'. @@ -288,7 +322,7 @@ def smact_filter(els, threshold=8, species_unique=True, oxidation_states_set="de els (tuple/list): A list of smact.Element objects threshold (int): Threshold for stoichiometry limit, default = 8 species_unique (bool): Whether or not to consider elements in different oxidation states as unique in the results. - oxidation_states_set (string): A string to choose which set of oxidation states should be chosen. Options are 'default', 'icsd', 'pymatgen' and 'wiki' for the default, icsd, pymatgen structure predictor and Wikipedia (https://en.wikipedia.org/wiki/Template:List_of_oxidation_states_of_the_elements) oxidation states. + oxidation_states_set (string): A string to choose which set of oxidation states should be chosen. 
Options are 'default', 'icsd', 'pymatgen' and 'wiki' for the default, icsd, pymatgen structure predictor and Wikipedia (https://en.wikipedia.org/wiki/Template:List_of_oxidation_states_of_the_elements) oxidation states respectively. Returns: allowed_comps (list): Allowed compositions for that chemical system in the form [(elements), (oxidation states), (ratios)] if species_unique=True @@ -330,13 +364,17 @@ def smact_filter(els, threshold=8, species_unique=True, oxidation_states_set="de electroneg_OK = pauling_test(ox_states, electronegs) if electroneg_OK: for ratio in cn_r: - compositions.append(tuple([symbols, ox_states, ratio])) + compositions.append( + _allowed_compositions(symbols, ox_states, ratio) + ) # Return list depending on whether we are interested in unique species combinations # or just unique element combinations. if species_unique: return compositions else: - compositions = [(i[0], i[2]) for i in compositions] + compositions = [ + _allowed_compositions_nonunique(i[0], i[2]) for i in compositions + ] compositions = list(set(compositions)) return compositions diff --git a/smact/structure_prediction/database.py b/smact/structure_prediction/database.py index f9ac19eb..5f0004e4 100644 --- a/smact/structure_prediction/database.py +++ b/smact/structure_prediction/database.py @@ -91,7 +91,9 @@ def __exit__(self, exc_type, *args): def add_mp_icsd( self, table: str, - mp_data: Optional[List[Dict[str, Union[pymatgen.core.Structure, str]]]] = None, + mp_data: Optional[ + List[Dict[str, Union[pymatgen.core.Structure, str]]] + ] = None, mp_api_key: Optional[str] = None, ) -> int: """Add a table populated with Materials Project-hosted ICSD structures. @@ -194,7 +196,9 @@ def add_structs( return num - def get_structs(self, composition: str, table: str) -> List[SmactStructure]: + def get_structs( + self, composition: str, table: str + ) -> List[SmactStructure]: """Get SmactStructures for a given composition. 
Args: @@ -256,19 +260,27 @@ def get_with_species( def parse_mprest( data: Dict[str, Union[pymatgen.core.Structure, str]], + determine_oxi: str = "BV", ) -> SmactStructure: """Parse MPRester query data to generate structures. Args: data: A dictionary containing the keys 'structure' and 'material_id', with the associated values. + determine_oxi (str): The method used to assign oxidation states in the structure. + Options are 'BV', 'comp_ICSD' or 'both': bond valence, ICSD statistics, + or bond valence with a fallback to ICSD statistics, respectively. Returns: An oxidation-state-decorated :class:`SmactStructure`. """ try: - return SmactStructure.from_py_struct(data["structure"]) + return SmactStructure.from_py_struct( + data["structure"], determine_oxi=determine_oxi + ) except: # Couldn't decorate with oxidation states - logger.warn(f"Couldn't decorate {data['material_id']} with oxidation states.") + logger.warning( + f"Couldn't decorate {data['material_id']} with oxidation states." 
+ ) diff --git a/smact/structure_prediction/mutation.py b/smact/structure_prediction/mutation.py index 20db7fb5..454f3fef 100644 --- a/smact/structure_prediction/mutation.py +++ b/smact/structure_prediction/mutation.py @@ -237,8 +237,12 @@ def _mutate_structure( # Replace sites struct_buff.sites[final_species] = struct_buff.sites.pop(init_species) # And sort - species_strs = struct_buff._format_style("{ele}{charge}{sign}").split(" ") - struct_buff.sites = {spec: struct_buff.sites[spec] for spec in species_strs} + species_strs = struct_buff._format_style("{ele}{charge}{sign}").split( + " " + ) + struct_buff.sites = { + spec: struct_buff.sites[spec] for spec in species_strs + } return struct_buff @@ -264,7 +268,9 @@ def _nary_mutate_structure( struct_buff = deepcopy(structure) init_spec_tup_list = [parse_spec(i) for i in init_species] struct_spec_tups = list(map(itemgetter(0, 1), struct_buff.species)) - spec_loc = [struct_spec_tups.index(init_spec_tup_list[i]) for i in range(n)] + spec_loc = [ + struct_spec_tups.index(init_spec_tup_list[i]) for i in range(n) + ] final_spec_tup_list = [parse_spec(i) for i in final_species] @@ -285,11 +291,17 @@ def _nary_mutate_structure( # Replace sites for i in range(n): - struct_buff.sites[final_species[i]] = struct_buff.sites.pop(init_species[i]) + struct_buff.sites[final_species[i]] = struct_buff.sites.pop( + init_species[i] + ) # And sort - species_strs = struct_buff._format_style("{ele}{charge}{sign}").split(" ") - struct_buff.sites = {spec: struct_buff.sites[spec] for spec in species_strs} + species_strs = struct_buff._format_style("{ele}{charge}{sign}").split( + " " + ) + struct_buff.sites = { + spec: struct_buff.sites[spec] for spec in species_strs + } return struct_buff @@ -360,7 +372,10 @@ def pair_corr(self, s1: str, s2: str) -> float: def cond_sub_prob(self, s1: str, s2: str) -> float: """Calculate the probability of substitution of one species with another.""" - return np.exp(self.get_lambda(s1, s2)) / 
np.exp(self.get_lambdas(s2)).sum() + return ( + np.exp(self.get_lambda(s1, s2)) + / np.exp(self.get_lambdas(s2)).sum() + ) def cond_sub_probs(self, s1: str) -> pd.Series: """Calculate the probabilities of substitution of a given species. @@ -402,4 +417,7 @@ def unary_substitute( ] ): continue - yield (self._mutate_structure(structure, specie, new_spec), prob) + yield ( + self._mutate_structure(structure, specie, new_spec), + prob, + ) diff --git a/smact/structure_prediction/prediction.py b/smact/structure_prediction/prediction.py index aceac82a..85741ef5 100644 --- a/smact/structure_prediction/prediction.py +++ b/smact/structure_prediction/prediction.py @@ -35,7 +35,9 @@ class StructurePredictor: """ - def __init__(self, mutator: CationMutator, struct_db: StructureDB, table: str): + def __init__( + self, mutator: CationMutator, struct_db: StructureDB, table: str + ): """Initialize class. Args: @@ -142,7 +144,9 @@ def predict_structs( # Poorly decorated continue yield ( - self.cm._mutate_structure(parent, alt_spec, diff_spec_str), + self.cm._mutate_structure( + parent, alt_spec, diff_spec_str + ), p, parent, ) @@ -179,7 +183,8 @@ def nary_predict_structs( sub_species = list(map(list, sub_species)) potential_nary_parents: List[List[SmactStructure]] = list( - self.db.get_with_species(specs, self.table) for specs in sub_species + self.db.get_with_species(specs, self.table) + for specs in sub_species ) for spec_idx, parents in enumerate(potential_nary_parents): @@ -260,7 +265,9 @@ def nary_predict_structs( # Poorly decorated continue yield ( - self.cm._nary_mutate_structure(parent, alt_spec, diff_spec_str), + self.cm._nary_mutate_structure( + parent, alt_spec, diff_spec_str + ), p, parent, ) diff --git a/smact/structure_prediction/structure.py b/smact/structure_prediction/structure.py index cb701d30..b98e2a48 100644 --- a/smact/structure_prediction/structure.py +++ b/smact/structure_prediction/structure.py @@ -12,6 +12,9 @@ import pymatgen from 
pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.ext.matproj import MPRester +from pymatgen.transformations.standard_transformations import ( + OxidationStateDecorationTransformation, +) import smact @@ -66,11 +69,15 @@ def __init__( :meth:`~.from_mp`. """ - self.species = self._sanitise_species(species) if sanitise_species else species + self.species = ( + self._sanitise_species(species) if sanitise_species else species + ) self.lattice_mat = lattice_mat - self.sites = {spec: sites[spec] for spec in self.get_spec_strs()} # Sort sites + self.sites = { + spec: sites[spec] for spec in self.get_spec_strs() + } # Sort sites self.lattice_param = lattice_param @@ -152,7 +159,9 @@ def _sanitise_species( species[0][0], smact.Species ): # Species class variation of instantiation species.sort(key=lambda x: (x[0].symbol, -x[0].oxidation)) - sanit_species = [(x[0].symbol, x[0].oxidation, x[1]) for x in species] + sanit_species = [ + (x[0].symbol, x[0].oxidation, x[1]) for x in species + ] else: raise TypeError(species_error) @@ -179,7 +188,9 @@ def __parse_py_sites( """ if not isinstance(structure, pymatgen.core.Structure): - raise TypeError("structure must be a pymatgen.core.Structure instance.") + raise TypeError( + "structure must be a pymatgen.core.Structure instance." + ) sites = defaultdict(list) for site in structure.sites: @@ -219,21 +230,54 @@ def __parse_py_sites( return sites, species @staticmethod - def from_py_struct(structure: pymatgen.core.Structure): + def from_py_struct( + structure: pymatgen.core.Structure, determine_oxi: str = "BV" + ): """Create a SmactStructure from a pymatgen Structure object. Args: structure: A pymatgen Structure. + determine_oxi (str): The method to determine the assignments oxidation states in the structure. + Options are 'BV', 'comp_ICSD','both' for determining the oxidation states by bond valence, + ICSD statistics or trial both sequentially, respectively. 
Returns: :class:`~.SmactStructure` """ if not isinstance(structure, pymatgen.core.Structure): - raise TypeError("Structure must be a pymatgen.core.Structure instance.") + raise TypeError( + "Structure must be a pymatgen.core.Structure instance." + ) - bva = BVAnalyzer() - struct = bva.get_oxi_state_decorated_structure(structure) + if determine_oxi == "BV": + bva = BVAnalyzer() + struct = bva.get_oxi_state_decorated_structure(structure) + + elif determine_oxi == "comp_ICSD": + comp = structure.composition + oxi_transform = OxidationStateDecorationTransformation( + comp.oxi_state_guesses()[0] + ) + struct = oxi_transform.apply_transformation(structure) + print("Charge assigned based on ICSD statistics") + + elif determine_oxi == "both": + try: + bva = BVAnalyzer() + struct = bva.get_oxi_state_decorated_structure(structure) + print("Oxidation states assigned using bond valence") + except ValueError: + comp = structure.composition + oxi_transform = OxidationStateDecorationTransformation( + comp.oxi_state_guesses()[0] + ) + struct = oxi_transform.apply_transformation(structure) + print("Oxidation states assigned based on ICSD statistics") + else: + raise ValueError( + f"Argument for 'determine_oxi', <{determine_oxi}> is not valid. Choose either 'BV','comp_ICSD' or 'both'." + ) sites, species = SmactStructure.__parse_py_sites(struct) @@ -253,11 +297,15 @@ def from_py_struct(structure: pymatgen.core.Structure): def from_mp( species: List[Union[Tuple[str, int, int], Tuple[smact.Species, int]]], api_key: str, + determine_oxi: str = "BV", ): """Create a SmactStructure using the first Materials Project entry for a composition. Args: species: See :meth:`~.__init__`. + determine_oxi (str): The method to determine the assignments oxidation states in the structure. + Options are 'BV', 'comp_ICSD','both' for determining the oxidation states by bond valence, + ICSD statistics or trial both sequentially, respectively. api_key: A www.materialsproject.org API key. 
Returns: @@ -283,10 +331,37 @@ def from_mp( # Default to first found structure struct = structs[0]["structure"] - if 0 not in (spec[1] for spec in sanit_species): # If everything's charged - bva = BVAnalyzer() - struct = bva.get_oxi_state_decorated_structure(struct) - + if 0 not in ( + spec[1] for spec in sanit_species + ): # If everything's charged + if determine_oxi == "BV": + bva = BVAnalyzer() + struct = bva.get_oxi_state_decorated_structure(struct) + + elif determine_oxi == "comp_ICSD": + comp = struct.composition + oxi_transform = OxidationStateDecorationTransformation( + comp.oxi_state_guesses()[0] + ) + struct = oxi_transform.apply_transformation(struct) + print("Charge assigned based on ICSD statistics") + + elif determine_oxi == "both": + try: + bva = BVAnalyzer() + struct = bva.get_oxi_state_decorated_structure(struct) + print("Oxidation states assigned using bond valence") + except ValueError: + comp = struct.composition + oxi_transform = OxidationStateDecorationTransformation( + comp.oxi_state_guesses()[0] + ) + struct = oxi_transform.apply_transformation(struct) + print("Oxidation states assigned based on ICSD statistics") + else: + raise ValueError( + f"Argument for 'determine_oxi', <{determine_oxi}> is not valid. Choose either 'BV','comp_ICSD' or 'both'." 
+ ) lattice_mat = struct.lattice.matrix lattice_param = 1.0 # TODO Use actual lattice parameter @@ -352,7 +427,10 @@ def from_poscar(poscar: str): lattice_param = float(lines[1]) lattice = np.array( - [[float(point) for point in line.split(" ")] for line in lines[2:5]] + [ + [float(point) for point in line.split(" ")] + for line in lines[2:5] + ] ) sites = defaultdict(list) @@ -511,7 +589,9 @@ def as_poscar(self) -> str: poscar += f"{self.lattice_param}\n" poscar += ( - "\n".join(" ".join(map(str, vec)) for vec in self.lattice_mat.tolist()) + "\n".join( + " ".join(map(str, vec)) for vec in self.lattice_mat.tolist() + ) + "\n" ) @@ -520,7 +600,8 @@ def as_poscar(self) -> str: poscar += self._format_style("{ele}") + "\n" poscar += ( - " ".join(str(spec_count[spec]) for spec in self.get_spec_strs()) + "\n" + " ".join(str(spec_count[spec]) for spec in self.get_spec_strs()) + + "\n" ) poscar += "Cartesian\n" diff --git a/smact/tests/test_core.py b/smact/tests/test_core.py index 7e77b9d1..3fe57f32 100755 --- a/smact/tests/test_core.py +++ b/smact/tests/test_core.py @@ -28,7 +28,9 @@ def test_Element_class_Pt(self): self.assertEqual(Pt.dipol, 44.00) def test_ordered_elements(self): - self.assertEqual(smact.ordered_elements(65, 68), ["Tb", "Dy", "Ho", "Er"]) + self.assertEqual( + smact.ordered_elements(65, 68), ["Tb", "Dy", "Ho", "Er"] + ) self.assertEqual(smact.ordered_elements(52, 52), ["Te"]) def test_element_dictionary(self): @@ -41,9 +43,13 @@ def test_element_dictionary(self): def test_are_eq(self): self.assertTrue( - smact.are_eq([1.00, 2.00, 3.00], [1.001, 1.999, 3.00], tolerance=1e-2) + smact.are_eq( + [1.00, 2.00, 3.00], [1.001, 1.999, 3.00], tolerance=1e-2 + ) + ) + self.assertFalse( + smact.are_eq([1.00, 2.00, 3.00], [1.001, 1.999, 3.00]) ) - self.assertFalse(smact.are_eq([1.00, 2.00, 3.00], [1.001, 1.999, 3.00])) def test_gcd_recursive(self): self.assertEqual(smact._gcd_recursive(4, 12, 10, 32), 2) @@ -94,7 +100,9 @@ def test_pauling_test(self): ) ) 
self.assertFalse( - smact.screening.pauling_test((-2, +2), (Sn.pauling_eneg, S.pauling_eneg)) + smact.screening.pauling_test( + (-2, +2), (Sn.pauling_eneg, S.pauling_eneg) + ) ) self.assertFalse( smact.screening.pauling_test( @@ -133,12 +141,16 @@ def test_pauling_test_old(self): Sn, S = (smact.Element(label) for label in ("Sn", "S")) self.assertTrue( smact.screening.pauling_test_old( - (+2, -2), (Sn.pauling_eneg, S.pauling_eneg), symbols=("S", "S", "Sn") + (+2, -2), + (Sn.pauling_eneg, S.pauling_eneg), + symbols=("S", "S", "Sn"), ) ) self.assertFalse( smact.screening.pauling_test_old( - (-2, +2), (Sn.pauling_eneg, S.pauling_eneg), symbols=("S", "S", "Sn") + (-2, +2), + (Sn.pauling_eneg, S.pauling_eneg), + symbols=("S", "S", "Sn"), ) ) self.assertFalse( @@ -318,8 +330,12 @@ def test_ml_rep_generator(self): 0.0, 0.0, ] - self.assertEqual(smact.screening.ml_rep_generator(["Pb", "O"], [1, 2]), PbO2_ml) - self.assertEqual(smact.screening.ml_rep_generator([Pb, O], [1, 2]), PbO2_ml) + self.assertEqual( + smact.screening.ml_rep_generator(["Pb", "O"], [1, 2]), PbO2_ml + ) + self.assertEqual( + smact.screening.ml_rep_generator([Pb, O], [1, 2]), PbO2_ml + ) def test_smact_filter(self): Na, Fe, Cl = (smact.Element(label) for label in ("Na", "Fe", "Cl")) @@ -344,7 +360,9 @@ def test_Lattice_class(self): # ---------- Lattice parameters ----------- def test_lattice_parameters(self): - perovskite = smact.lattice_parameters.cubic_perovskite([1.81, 1.33, 1.82]) + perovskite = smact.lattice_parameters.cubic_perovskite( + [1.81, 1.33, 1.82] + ) wurtz = smact.lattice_parameters.wurtzite([1.81, 1.33]) self.assertAlmostEqual(perovskite[0], 6.3) self.assertAlmostEqual(perovskite[1], 6.3) @@ -356,9 +374,11 @@ def test_lattice_parameters(self): def test_oxidation_states(self): ox = smact.oxidation_states.Oxidation_state_probability_finder() self.assertAlmostEqual( - ox.compound_probability([Specie("Fe", +3), Specie("O", -2)]), 0.74280230326 + ox.compound_probability([Specie("Fe", +3), 
Specie("O", -2)]), + 0.74280230326, ) self.assertAlmostEqual( - ox.pair_probability(Species("Fe", +3), Species("O", -2)), 0.74280230326 + ox.pair_probability(Species("Fe", +3), Species("O", -2)), + 0.74280230326, ) self.assertEqual(len(ox.get_included_species()), 173) diff --git a/smact/tests/test_doper.py b/smact/tests/test_doper.py index 6d475245..e6eeb9e8 100644 --- a/smact/tests/test_doper.py +++ b/smact/tests/test_doper.py @@ -11,8 +11,12 @@ def test_dopant_prediction(self): test_specie = ("Cu+", "Ga3+", "S2-") test = doper.Doper(test_specie) - cation_max_charge = max(test_specie, key=lambda x: utilities.parse_spec(x)[1]) - anion_min_charge = min(test_specie, key=lambda x: utilities.parse_spec(x)[1]) + cation_max_charge = max( + test_specie, key=lambda x: utilities.parse_spec(x)[1] + ) + anion_min_charge = min( + test_specie, key=lambda x: utilities.parse_spec(x)[1] + ) _, cat_charge = utilities.parse_spec(cation_max_charge) _, an_charge = utilities.parse_spec(anion_min_charge) @@ -38,7 +42,9 @@ def test_dopant_prediction(self): if __name__ == "__main__": TestLoader = unittest.TestLoader() DoperTests = unittest.TestSuite() - DoperTests.addTests(TestLoader.loadTestsFromTestCase(dopant_prediction_test)) + DoperTests.addTests( + TestLoader.loadTestsFromTestCase(dopant_prediction_test) + ) runner = unittest.TextTestRunner() result = runner.run(DoperTests) diff --git a/smact/tests/test_structure.py b/smact/tests/test_structure.py index bedb97f7..597377e1 100644 --- a/smact/tests/test_structure.py +++ b/smact/tests/test_structure.py @@ -135,10 +135,27 @@ def test_from_py_struct(self): self.assertStructAlmostEqual(s1, s2) + def test_from_py_struct_icsd(self): + """Test generation of SmactStructure from a pymatgen Structure using ICSD statistics to determine oxidation states.""" + with open(TEST_PY_STRUCT) as f: + d = json.load(f) + py_structure = pymatgen.core.Structure.from_dict(d) + + with ignore_warnings(smact.structure_prediction.logger): + s1 = 
SmactStructure.from_py_struct( + py_structure, determine_oxi="comp_ICSD" + ) + + s2 = SmactStructure.from_file(os.path.join(files_dir, "CaTiO3.txt")) + + self.assertStructAlmostEqual(s1, s2) + def test_has_species(self): """Test determining whether a species is in a `SmactStructure`.""" s1 = SmactStructure( - *self._gen_empty_structure([("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)]) + *self._gen_empty_structure( + [("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)] + ) ) self.assertTrue(s1.has_species(("Ba", 2))) @@ -148,10 +165,14 @@ def test_has_species(self): def test_smactStruc_comp_key(self): """Test generation of a composition key for `SmactStructure`s.""" s1 = SmactStructure( - *self._gen_empty_structure([("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)]) + *self._gen_empty_structure( + [("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)] + ) ) s2 = SmactStructure( - *self._gen_empty_structure([("Fe", 2, 1), ("Fe", 3, 2), ("O", -2, 4)]) + *self._gen_empty_structure( + [("Fe", 2, 1), ("Fe", 3, 2), ("O", -2, 4)] + ) ) Ba = Species("Ba", 2) @@ -160,8 +181,12 @@ def test_smactStruc_comp_key(self): Fe2 = Species("Fe", 2) Fe3 = Species("Fe", 3) - s3 = SmactStructure(*self._gen_empty_structure([(Ba, 2), (O, 1), (F, 2)])) - s4 = SmactStructure(*self._gen_empty_structure([(Fe2, 1), (Fe3, 2), (O, 4)])) + s3 = SmactStructure( + *self._gen_empty_structure([(Ba, 2), (O, 1), (F, 2)]) + ) + s4 = SmactStructure( + *self._gen_empty_structure([(Fe2, 1), (Fe3, 2), (O, 4)]) + ) Ba_2OF_2 = "Ba_2_2+F_2_1-O_1_2-" Fe_3O_4 = "Fe_2_3+Fe_1_2+O_4_2-" @@ -181,7 +206,9 @@ def test_smactStruc_from_file(self): def test_equality(self): """Test equality determination of `SmactStructure`.""" - struct_files = [os.path.join(files_dir, f"{x}.txt") for x in ["CaTiO3", "NaCl"]] + struct_files = [ + os.path.join(files_dir, f"{x}.txt") for x in ["CaTiO3", "NaCl"] + ] CaTiO3 = SmactStructure.from_file(struct_files[0]) NaCl = SmactStructure.from_file(struct_files[1]) @@ -197,19 +224,27 @@ def test_equality(self): def 
test_ele_stoics(self): """Test acquiring element stoichiometries.""" s1 = SmactStructure( - *self._gen_empty_structure([("Fe", 2, 1), ("Fe", 3, 2), ("O", -2, 4)]) + *self._gen_empty_structure( + [("Fe", 2, 1), ("Fe", 3, 2), ("O", -2, 4)] + ) ) s1_stoics = {"Fe": 3, "O": 4} s2 = SmactStructure( - *self._gen_empty_structure([("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)]) + *self._gen_empty_structure( + [("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)] + ) ) s2_stoics = {"Ba": 2, "O": 1, "F": 2} for test, expected in [(s1, s1_stoics), (s2, s2_stoics)]: with self.subTest(species=test.species): - self.assertEqual(SmactStructure._get_ele_stoics(test.species), expected) + self.assertEqual( + SmactStructure._get_ele_stoics(test.species), expected + ) - @unittest.skipUnless(os.environ.get("MPI_KEY"), "requires MPI key to be set.") + @unittest.skipUnless( + os.environ.get("MPI_KEY"), "requires MPI key to be set." + ) def test_from_mp(self): """Test downloading structures from materialsproject.org.""" # TODO Needs ensuring that the structure query gets the same @@ -258,11 +293,15 @@ def test_db_interface(self): self.fail(e) with self.subTest(msg="Getting structure from table."): - struct_list = self.db.get_structs(struct.composition(), self.TEST_TABLE) + struct_list = self.db.get_structs( + struct.composition(), self.TEST_TABLE + ) self.assertEqual(len(struct_list), 1) self.assertEqual(struct_list[0], struct) - struct_files = [os.path.join(files_dir, f"{x}.txt") for x in ["NaCl", "Fe"]] + struct_files = [ + os.path.join(files_dir, f"{x}.txt") for x in ["NaCl", "Fe"] + ] structs = [SmactStructure.from_file(fname) for fname in struct_files] with self.subTest(msg="Adding multiple structures to table."): @@ -293,7 +332,9 @@ def test_db_interface(self): [struct], ] - for spec, expected in zip(test_with_species_args, test_with_species_exp): + for spec, expected in zip( + test_with_species_args, test_with_species_exp + ): with self.subTest(msg=f"Retrieving species with {spec}"): 
self.assertEqual( self.db.get_with_species(spec, self.TEST_TABLE), expected @@ -321,7 +362,9 @@ def setUpClass(cls): """Set up the test initial structure and mutator.""" cls.test_struct = SmactStructure.from_file(TEST_POSCAR) - cls.test_mutator = CationMutator.from_json(lambda_json=TEST_LAMBDA_JSON) + cls.test_mutator = CationMutator.from_json( + lambda_json=TEST_LAMBDA_JSON + ) cls.test_pymatgen_mutator = CationMutator.from_json( lambda_json=None, alpha=lambda x, y: -5 ) @@ -358,12 +401,15 @@ def test_partition_func_Z(self): def test_pymatgen_lambda_import(self): """Test importing pymatgen lambda table.""" - self.assertIsInstance(self.test_pymatgen_mutator.lambda_tab, pd.DataFrame) + self.assertIsInstance( + self.test_pymatgen_mutator.lambda_tab, pd.DataFrame + ) def test_lambda_interface(self): """Test getting lambda values.""" test_cases = [ - itertools.permutations(x) for x in [("A", "B"), ("A", "C"), ("B", "C")] + itertools.permutations(x) + for x in [("A", "B"), ("A", "C"), ("B", "C")] ] expected = [0.5, -5.0, 0.3] @@ -372,7 +418,9 @@ def test_lambda_interface(self): for spec_comb in test_case: s1, s2 = spec_comb with self.subTest(s1=s1, s2=s2): - self.assertEqual(self.test_mutator.get_lambda(s1, s2), expectation) + self.assertEqual( + self.test_mutator.get_lambda(s1, s2), expectation + ) def test_ion_mutation(self): """Test mutating an ion of a SmactStructure.""" @@ -383,7 +431,9 @@ def test_ion_mutation(self): BaTiO3 = SmactStructure.from_file(ba_file) with self.subTest(s1="CaTiO3", s2="BaTiO3"): - mutation = self.test_mutator._mutate_structure(CaTiO3, "Ca2+", "Ba2+") + mutation = self.test_mutator._mutate_structure( + CaTiO3, "Ca2+", "Ba2+" + ) self.assertEqual(mutation, BaTiO3) na_file = os.path.join(files_dir, "NaCl.txt") @@ -416,7 +466,9 @@ def test_cond_sub_probs(self): ] test_df = pd.DataFrame(vals) - test_df: pd.DataFrame = test_df.pivot(index=0, columns=1, values=2) + test_df: pd.DataFrame = test_df.pivot( + index=0, columns=1, values=2 + ) # Slice 
to convert to series assert_series_equal(cond_sub_probs_test, test_df.iloc[0]) @@ -534,7 +586,9 @@ def test_prediction(self): with self.subTest(msg="Acquiring predictions"): try: predictions = list( - sp.predict_structs(test_specs, thresh=0.02, include_same=False) + sp.predict_structs( + test_specs, thresh=0.02, include_same=False + ) ) except Exception as e: self.fail(e)