Skip to content

Commit

Permalink
Add hash for BalancedReaction (#3676)
Browse files Browse the repository at this point in the history
* add hash for BalancedReaction

* add types to BalancedReaction.__init__

* remove pointless self._els = all_reactants.elements (self._els overwritten later)

* add TestBalancedReaction.test_hash

* convert reactants_coeffs, products_coeffs keys to Composition if needed

---------

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
DanielYang59 and janosh authored Mar 7, 2024
1 parent a4fbeeb commit 322c924
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pymatgen/analysis/pourbaix_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def __getattr__(self, attr):
# Attributes that are weighted averages of entry attributes
if attr in ["energy", "npH", "nH2O", "nPhi", "conc_term", "composition", "uncorrected_energy", "elements"]:
# TODO: Composition could be changed for compat with sum
start = Composition({}) if attr == "composition" else 0
start = Composition() if attr == "composition" else 0
weighted_values = (getattr(entry, attr) * weight for entry, weight in zip(self.entry_list, self.weights))
return sum(weighted_values, start)

Expand Down
41 changes: 28 additions & 13 deletions pymatgen/analysis/reaction_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import re
from itertools import chain, combinations
from typing import TYPE_CHECKING, no_type_check

import numpy as np
from monty.fractions import gcd_float
Expand All @@ -14,6 +15,11 @@
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import ComputedEntry

if TYPE_CHECKING:
from collections.abc import Mapping

from pymatgen.util.typing import CompositionLike

__author__ = "Shyue Ping Ong, Anubhav Jain"
__copyright__ = "Copyright 2011, The Materials Project"
__version__ = "2.0"
Expand All @@ -32,36 +38,44 @@ class BalancedReaction(MSONable):
# Tolerance for determining if a particular component fraction is > 0.
TOLERANCE = 1e-6

def __init__(self, reactants_coeffs, products_coeffs):
@no_type_check
def __init__(
self,
reactants_coeffs: Mapping[CompositionLike, int | float],
products_coeffs: Mapping[CompositionLike, int | float],
) -> None:
"""
Reactants and products to be specified as dict of {Composition: coeff}.
Args:
reactants_coeffs (dict[Composition, float]): Reactants as dict of {Composition: amt}.
products_coeffs (dict[Composition, float]): Products as dict of {Composition: amt}.
"""
# convert to Composition if necessary
reactants_coeffs = {Composition(comp): coeff for comp, coeff in reactants_coeffs.items()}
products_coeffs = {Composition(comp): coeff for comp, coeff in products_coeffs.items()}

# sum reactants and products
all_reactants = sum((k * v for k, v in reactants_coeffs.items()), Composition({}))
all_products = sum((k * v for k, v in products_coeffs.items()), Composition({}))
all_reactants = sum((comp * coeff for comp, coeff in reactants_coeffs.items()), Composition())

all_products = sum((comp * coeff for comp, coeff in products_coeffs.items()), Composition())

if not all_reactants.almost_equals(all_products, rtol=0, atol=self.TOLERANCE):
raise ReactionError("Reaction is unbalanced!")

self._els = all_reactants.elements

self.reactants_coeffs = reactants_coeffs
self.products_coeffs = products_coeffs
self.reactants_coeffs: dict = reactants_coeffs
self.products_coeffs: dict = products_coeffs

# calculate net reaction coefficients
self._coeffs = []
self._els = []
self._all_comp = []
self._coeffs: list[float] = []
self._els: list[str] = []
self._all_comp: list[Composition] = []
for key in {*reactants_coeffs, *products_coeffs}:
coeff = products_coeffs.get(key, 0) - reactants_coeffs.get(key, 0)

if abs(coeff) > self.TOLERANCE:
self._all_comp.append(key)
self._coeffs.append(coeff)
self._all_comp += [key]
self._coeffs += [coeff]

def calculate_energy(self, energies):
"""
Expand Down Expand Up @@ -171,7 +185,8 @@ def __eq__(self, other: object) -> bool:
return True

def __hash__(self) -> int:
return 7
# Necessity for hash method is unclear (see gh-3673)
return hash((frozenset(self.reactants_coeffs.items()), frozenset(self.products_coeffs.items())))

@classmethod
def _str_from_formulas(cls, coeffs, formulas):
Expand Down
22 changes: 11 additions & 11 deletions pymatgen/apps/battery/conversion_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,14 @@ def from_steps(cls, step1, step2, normalization_els, framework_formula):
* 1000
* working_ion_valence
)
licomp = Composition(working_ion)
li_comp = Composition(working_ion)
prev_rxn = step1["reaction"]
reactants = {comp: abs(prev_rxn.get_coeff(comp)) for comp in prev_rxn.products if comp != licomp}
reactants = {comp: abs(prev_rxn.get_coeff(comp)) for comp in prev_rxn.products if comp != li_comp}

curr_rxn = step2["reaction"]
products = {comp: abs(curr_rxn.get_coeff(comp)) for comp in curr_rxn.products if comp != licomp}
products = {comp: abs(curr_rxn.get_coeff(comp)) for comp in curr_rxn.products if comp != li_comp}

reactants[licomp] = step2["evolution"] - step1["evolution"]
reactants[li_comp] = step2["evolution"] - step1["evolution"]

rxn = BalancedReaction(reactants, products)

Expand All @@ -318,30 +318,30 @@ def from_steps(cls, step1, step2, normalization_els, framework_formula):
break

prev_mass_dischg = (
sum(prev_rxn.all_comp[i].weight * abs(prev_rxn.coeffs[i]) for i in range(len(prev_rxn.all_comp))) / 2
sum(prev_rxn.all_comp[idx].weight * abs(prev_rxn.coeffs[idx]) for idx in range(len(prev_rxn.all_comp))) / 2
)
vol_charge = sum(
abs(prev_rxn.get_coeff(e.composition)) * e.structure.volume
for e in step1["entries"]
if e.reduced_formula != working_ion
)
mass_discharge = (
sum(curr_rxn.all_comp[i].weight * abs(curr_rxn.coeffs[i]) for i in range(len(curr_rxn.all_comp))) / 2
sum(curr_rxn.all_comp[idx].weight * abs(curr_rxn.coeffs[idx]) for idx in range(len(curr_rxn.all_comp))) / 2
)
mass_charge = prev_mass_dischg
vol_discharge = sum(
abs(curr_rxn.get_coeff(e.composition)) * e.structure.volume
for e in step2["entries"]
if e.reduced_formula != working_ion
abs(curr_rxn.get_coeff(entry.composition)) * entry.structure.volume
for entry in step2["entries"]
if entry.reduced_formula != working_ion
)

total_comp = Composition({})
total_comp = Composition()
for comp in prev_rxn.products:
if comp.reduced_formula != working_ion:
total_comp += comp * abs(prev_rxn.get_coeff(comp))
frac_charge = total_comp.get_atomic_fraction(Element(working_ion))

total_comp = Composition({})
total_comp = Composition()
for comp in curr_rxn.products:
if comp.reduced_formula != working_ion:
total_comp += comp * abs(curr_rxn.get_coeff(comp))
Expand Down
37 changes: 19 additions & 18 deletions tests/analysis/test_reaction_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,49 +288,50 @@ def test_underdetermined_reactants(self):


class TestBalancedReaction(unittest.TestCase):
def setUp(self) -> None:
rct = {"K2SO4": 3, "Na2S": 1, "Li": 24}
prod = {"KNaS": 2, "K2S": 2, "Li2O": 12}
self.rxn = BalancedReaction(rct, prod)

def test_init(self):
rct = {Composition("K2SO4"): 3, Composition("Na2S"): 1, Composition("Li"): 24}
prod = {Composition("KNaS"): 2, Composition("K2S"): 2, Composition("Li2O"): 12}
rxn = BalancedReaction(rct, prod)
assert str(rxn) == "24 Li + Na2S + 3 K2SO4 -> 2 KNaS + 2 K2S + 12 Li2O"
assert str(self.rxn) == "24 Li + Na2S + 3 K2SO4 -> 2 KNaS + 2 K2S + 12 Li2O"

# Test unbalanced exception
rct = {Composition("K2SO4"): 1, Composition("Na2S"): 1, Composition("Li"): 24}
prod = {Composition("KNaS"): 2, Composition("K2S"): 2, Composition("Li2O"): 12}
rct = {"K2SO4": 1, "Na2S": 1, "Li": 24}
prod = {"KNaS": 2, "K2S": 2, "Li2O": 12}
with pytest.raises(ReactionError, match="Reaction is unbalanced"):
BalancedReaction(rct, prod)

def test_as_from_dict(self):
rct = {Composition("K2SO4"): 3, Composition("Na2S"): 1, Composition("Li"): 24}
prod = {Composition("KNaS"): 2, Composition("K2S"): 2, Composition("Li2O"): 12}
rct = {"K2SO4": 3, "Na2S": 1, "Li": 24}
prod = {"KNaS": 2, "K2S": 2, "Li2O": 12}
rxn = BalancedReaction(rct, prod)
dct = rxn.as_dict()
new_rxn = BalancedReaction.from_dict(dct)
for comp in new_rxn.all_comp:
assert new_rxn.get_coeff(comp) == rxn.get_coeff(comp)

def test_from_str(self):
rxn = BalancedReaction({Composition("Li"): 4, Composition("O2"): 1}, {Composition("Li2O"): 2})
rxn = BalancedReaction({"Li": 4, "O2": 1}, {"Li2O": 2})
assert rxn == BalancedReaction.from_str("4 Li + O2 -> 2Li2O")

rxn = BalancedReaction(
{Composition("Li(NiO2)3"): 1},
{
Composition("O2"): 0.5,
Composition("Li(NiO2)2"): 1,
Composition("NiO"): 1,
},
{"Li(NiO2)3": 1},
{"O2": 0.5, "Li(NiO2)2": 1, "NiO": 1},
)

assert rxn == BalancedReaction.from_str("1.000 Li(NiO2)3 -> 0.500 O2 + 1.000 Li(NiO2)2 + 1.000 NiO")

def test_remove_spectator_species(self):
rxn = BalancedReaction(
{Composition("Li"): 4, Composition("O2"): 1, Composition("Na"): 1},
{Composition("Li2O"): 2, Composition("Na"): 1},
{"Li": 4, "O2": 1, "Na": 1},
{"Li2O": 2, "Na": 1},
)

assert Composition("Na") not in rxn.all_comp
assert "Na" not in rxn.all_comp

def test_hash(self):
assert hash(self.rxn) == 4774511606373046513


class TestComputedReaction(unittest.TestCase):
Expand Down

0 comments on commit 322c924

Please sign in to comment.