Skip to content

Commit

Permalink
Merge pull request #809 from jacksund/main
Browse files Browse the repository at this point in the history
add partial site stats fingerprint
  • Loading branch information
ardunn authored Aug 12, 2022
2 parents 886524a + b3e2954 commit ead30fe
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 16 deletions.
171 changes: 156 additions & 15 deletions matminer/featurizers/structure/sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
Structure featurizers based on aggregating site features.
"""

import itertools
import numpy as np
from pymatgen.analysis.local_env import VoronoiNN
from pymatgen.core.periodic_table import Element, Specie

from matminer.featurizers.base import BaseFeaturizer
from matminer.featurizers.site import (
Expand Down Expand Up @@ -147,8 +149,8 @@ def citations(self):
def implementors(self):
return ["Nils E. R. Zimmermann", "Alireza Faghaninia", "Anubhav Jain", "Logan Ward", "Alex Dunn"]

@staticmethod
def from_preset(preset, **kwargs):
@classmethod
def from_preset(cls, preset, **kwargs):
"""
Create a SiteStatsFingerprint class according to a preset
Expand All @@ -158,37 +160,37 @@ def from_preset(preset, **kwargs):
"""

if preset == "SOAP_formation_energy":
return SiteStatsFingerprint(SOAP.from_preset("formation_energy"), **kwargs)
return cls(SOAP.from_preset("formation_energy"), **kwargs)

elif preset == "CrystalNNFingerprint_cn":
return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("cn", cation_anion=False), **kwargs)
return cls(CrystalNNFingerprint.from_preset("cn", cation_anion=False), **kwargs)

elif preset == "CrystalNNFingerprint_cn_cation_anion":
return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("cn", cation_anion=True), **kwargs)
return cls(CrystalNNFingerprint.from_preset("cn", cation_anion=True), **kwargs)

elif preset == "CrystalNNFingerprint_ops":
return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("ops", cation_anion=False), **kwargs)
return cls(CrystalNNFingerprint.from_preset("ops", cation_anion=False), **kwargs)

elif preset == "CrystalNNFingerprint_ops_cation_anion":
return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("ops", cation_anion=True), **kwargs)
return cls(CrystalNNFingerprint.from_preset("ops", cation_anion=True), **kwargs)

elif preset == "OPSiteFingerprint":
return SiteStatsFingerprint(OPSiteFingerprint(), **kwargs)
return cls(OPSiteFingerprint(), **kwargs)

elif preset == "LocalPropertyDifference_ward-prb-2017":
return SiteStatsFingerprint(
return cls(
LocalPropertyDifference.from_preset("ward-prb-2017"),
stats=["minimum", "maximum", "range", "mean", "avg_dev"],
)

elif preset == "CoordinationNumber_ward-prb-2017":
return SiteStatsFingerprint(
return cls(
CoordinationNumber(nn=VoronoiNN(weight="area"), use_weights="effective"),
stats=["minimum", "maximum", "range", "mean", "avg_dev"],
)

elif preset == "Composition-dejong2016_AD":
return SiteStatsFingerprint(
return cls(
LocalPropertyDifference(
properties=[
"Number",
Expand All @@ -204,7 +206,7 @@ def from_preset(preset, **kwargs):
)

elif preset == "Composition-dejong2016_SD":
return SiteStatsFingerprint(
return cls(
LocalPropertyDifference(
properties=[
"Number",
Expand All @@ -220,13 +222,13 @@ def from_preset(preset, **kwargs):
)

elif preset == "BondLength-dejong2016":
return SiteStatsFingerprint(
return cls(
AverageBondLength(VoronoiNN()),
stats=["holder_mean::%d" % d for d in range(-4, 4 + 1)] + ["std_dev", "geom_std_dev"],
)

elif preset == "BondAngle-dejong2016":
return SiteStatsFingerprint(
return cls(
AverageBondAngle(VoronoiNN()),
stats=["holder_mean::%d" % d for d in range(-4, 4 + 1)] + ["std_dev", "geom_std_dev"],
)
Expand All @@ -236,8 +238,147 @@ def from_preset(preset, **kwargs):
# One of the various Coordination Number presets:
# MinimumVIRENN, MinimumDistanceNN, JmolNN, VoronoiNN, etc.
try:
return SiteStatsFingerprint(CoordinationNumber.from_preset(preset), **kwargs)
return cls(CoordinationNumber.from_preset(preset), **kwargs)
except Exception:
pass

raise ValueError("Unrecognized preset!")


class PartialsSiteStatsFingerprint(SiteStatsFingerprint):
"""
Computes statistics of properties across all sites in a structure, and
breaks these down by element. This featurizer first uses a site featurizer
class (see site.py for options) to compute features of each site of a
specific element in a structure, and then computes features of the entire
structure by measuring statistics of each attribute.
Features:
- Returns each statistic of each site feature, broken down by element
"""

def __init__(
self,
site_featurizer,
stats=("mean", "std_dev"),
min_oxi=None,
max_oxi=None,
covariance=False,
include_elems=(),
exclude_elems=(),
):
"""
Args:
site_featurizer (BaseFeaturizer): a site-based featurizer
stats ([str]): list of weighted statistics to compute for each feature.
If stats is None, a list is returned for each features
that contains the calculated feature for each site in the
structure.
*Note for nth mode, stat must be 'n*_mode'; e.g. stat='2nd_mode'
min_oxi (int): minimum site oxidation state for inclusion (e.g.,
zero means metals/cations only)
max_oxi (int): maximum site oxidation state for inclusion
covariance (bool): Whether to compute the covariance of site features
"""

self.include_elems = list(include_elems)
self.exclude_elems = list(exclude_elems)
super().__init__(site_featurizer, stats, min_oxi, max_oxi, covariance)

def fit(self, X, y=None):
"""Define the list of elements to be included in the PRDF. By default,
the PRDF will include all of the elements in `X`
Args:
X: (numpy array nx1) structures used in the training set. Each entry
must be Pymatgen Structure objects.
y: *Not used*
fit_kwargs: *not used*
"""

# This method largely copies code from the partial-RDF fingerprint

# Initialize list with included elements
elements = [Element(e) for e in self.include_elems]

# Get all of elements that appear
for structure in X:
for element in structure.composition.elements:
if isinstance(element, Specie):
element = element.element # converts from Specie to Element object
if element not in elements and element.name not in self.exclude_elems:
elements.append(element)

# Store the elements
self.elements_ = [e.symbol for e in sorted(elements)]

def featurize(self, s):
"""
Get PSSF of the input structure.
Args:
s: Pymatgen Structure object.
Returns:
pssf: 1D array of each element's ssf
"""

if not s.is_ordered:
raise ValueError("Disordered structure support not built yet")
if not hasattr(self, "elements_") or self.elements_ is None:
raise Exception("You must run 'fit' first!")

output = []
for e in self.elements_:
pssf_stats = self.compute_pssf(s, e)
output.append(pssf_stats)

return np.hstack(output)

def compute_pssf(self, s, e):

# This code is extremely similar to super().featurize(). The key
# difference is that only one specific element is analyzed.

# Get each feature for each site
vals = [[] for t in self._site_labels]
for i, site in enumerate(s.sites):
if site.specie.symbol == e:
opvalstmp = self.site_featurizer.featurize(s, i)
for j, opval in enumerate(opvalstmp):
if opval is None:
vals[j].append(0.0)
else:
vals[j].append(opval)

# If the user does not request statistics, return the site features now
if self.stats is None:
return vals

# Compute the requested statistics
stats = []
for op in vals:
for stat in self.stats:
stats.append(PropertyStats().calc_stat(op, stat))

# If desired, compute covariances
if self.covariance:
if len(s) == 1:
stats.extend([0] * int(len(vals) * (len(vals) - 1) / 2))
else:
covar = np.cov(vals)
tri_ind = np.triu_indices(len(vals), 1)
stats.extend(covar[tri_ind].tolist())

return stats

def feature_labels(self):
if not hasattr(self, "elements_") or self.elements_ is None:
raise Exception("You must run 'fit' first!")

labels = []
for e in self.elements_:
e_labels = [f"{e} {l}" for l in super().feature_labels()]
for l in e_labels:
labels.append(l)

return labels

def implementors(self):
return ["Jack Sundberg"]
115 changes: 114 additions & 1 deletion matminer/featurizers/structure/tests/test_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import numpy as np

from matminer.featurizers.site import SiteElementalProperty
from matminer.featurizers.structure.sites import SiteStatsFingerprint
from matminer.featurizers.structure.sites import (
SiteStatsFingerprint,
PartialsSiteStatsFingerprint,
)
from matminer.featurizers.structure.tests.base import StructureFeaturesTest


Expand Down Expand Up @@ -111,5 +114,115 @@ def test_ward_prb_2017_efftcn(self):
self.assertArrayAlmostEqual([12, 12, 0, 12, 0], features)


class PartialStructureSitesFeaturesTest(StructureFeaturesTest):
def test_partialsitestatsfingerprint(self):
# Test matrix.
op_struct_fp = PartialsSiteStatsFingerprint.from_preset("OPSiteFingerprint", stats=None)

op_struct_fp.fit([self.diamond])
opvals = op_struct_fp.featurize(self.diamond)
_ = op_struct_fp.feature_labels()
self.assertAlmostEqual(opvals[10][0], 0.9995, places=7)
self.assertAlmostEqual(opvals[10][1], 0.9995, places=7)

op_struct_fp.fit([self.nacl])
opvals = op_struct_fp.featurize(self.nacl)
self.assertAlmostEqual(opvals[18][0], 0.9995, places=7)
self.assertAlmostEqual(opvals[18][1], 0.9995, places=7)

op_struct_fp.fit([self.cscl])
opvals = op_struct_fp.featurize(self.cscl)
self.assertAlmostEqual(opvals[22][0], 0.9995, places=7)
self.assertAlmostEqual(opvals[22][1], 0.9995, places=7)

# Test stats.
op_struct_fp = PartialsSiteStatsFingerprint.from_preset("OPSiteFingerprint")
op_struct_fp.fit([self.diamond])
opvals = op_struct_fp.featurize(self.diamond)
self.assertAlmostEqual(opvals[0], 0.0005, places=7)
self.assertAlmostEqual(opvals[1], 0, places=7)
self.assertAlmostEqual(opvals[2], 0.0005, places=7)
self.assertAlmostEqual(opvals[3], 0.0, places=7)
self.assertAlmostEqual(opvals[4], 0.0005, places=7)
self.assertAlmostEqual(opvals[18], 0.0805, places=7)
self.assertAlmostEqual(opvals[20], 0.9995, places=7)
self.assertAlmostEqual(opvals[21], 0, places=7)
self.assertAlmostEqual(opvals[22], 0.0075, places=7)
self.assertAlmostEqual(opvals[24], 0.2355, places=7)
self.assertAlmostEqual(opvals[-1], 0.0, places=7)

# Test coordination number
cn_fp = PartialsSiteStatsFingerprint.from_preset("JmolNN", stats=("mean",))
cn_fp.fit([self.diamond])
cn_vals = cn_fp.featurize(self.diamond)
self.assertEqual(cn_vals[0], 4.0)

# Test the covariance
prop_fp = PartialsSiteStatsFingerprint(
SiteElementalProperty(properties=["Number", "AtomicWeight"]),
stats=["mean"],
covariance=True,
)

# Test the feature labels
prop_fp.fit([self.diamond])
labels = prop_fp.feature_labels()
self.assertEqual(3, len(labels))

# Test a structure with all the same type (cov should be zero)
prop_fp.fit([self.diamond])
features = prop_fp.featurize(self.diamond)
self.assertArrayAlmostEqual(features, [6, 12.0107, 0])

# Test a structure with only one atom (cov should be zero too)
prop_fp.fit([self.sc])
features = prop_fp.featurize(self.sc)
self.assertArrayAlmostEqual([13, 26.9815386, 0], features)

# Test a structure with nonzero covariance
prop_fp.fit([self.nacl])
features = prop_fp.featurize(self.nacl)
self.assertArrayAlmostEqual([11, 22.9897693, np.nan, 17, 35.453, np.nan], features)

def test_ward_prb_2017_lpd(self):
"""Test the local property difference attributes from Ward 2017"""
f = PartialsSiteStatsFingerprint.from_preset("LocalPropertyDifference_ward-prb-2017")

# Test diamond
f.fit([self.diamond])
features = f.featurize(self.diamond)
self.assertArrayAlmostEqual(features, [0] * (22 * 5))
features = f.featurize(self.diamond_no_oxi)
self.assertArrayAlmostEqual(features, [0] * (22 * 5))

# Test CsCl
f.fit([self.cscl])
big_face_area = np.sqrt(3) * 3 / 2 * (2 / 4 / 4)
small_face_area = 0.125
big_face_diff = 55 - 17
features = f.featurize(self.cscl)
labels = f.feature_labels()
my_label = "Cs mean local difference in Number"
self.assertAlmostEqual(
(8 * big_face_area * big_face_diff) / (8 * big_face_area + 6 * small_face_area),
features[labels.index(my_label)],
places=3,
)
my_label = "Cs range local difference in Electronegativity"
self.assertAlmostEqual(0, features[labels.index(my_label)], places=3)

def test_ward_prb_2017_efftcn(self):
"""Test the effective coordination number attributes of Ward 2017"""
f = PartialsSiteStatsFingerprint.from_preset("CoordinationNumber_ward-prb-2017")

# Test Ni3Al
f.fit([self.ni3al])
features = f.featurize(self.ni3al)
labels = f.feature_labels()
self.assertAlmostEqual(12, features[labels.index("Al mean CN_VoronoiNN")])
self.assertAlmostEqual(12, features[labels.index("Ni mean CN_VoronoiNN")])
self.assertArrayAlmostEqual([12, 12, 0, 12, 0] * 2, features)


if __name__ == "__main__":
unittest.main()

0 comments on commit ead30fe

Please sign in to comment.