Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add partial site stats fingerprint #809

Merged
merged 4 commits into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 156 additions & 15 deletions matminer/featurizers/structure/sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
Structure featurizers based on aggregating site features.
"""

import itertools
import numpy as np
from pymatgen.analysis.local_env import VoronoiNN
from pymatgen.core.periodic_table import Element, Specie

from matminer.featurizers.base import BaseFeaturizer
from matminer.featurizers.site import (
Expand Down Expand Up @@ -147,8 +149,8 @@ def citations(self):
def implementors(self):
return ["Nils E. R. Zimmermann", "Alireza Faghaninia", "Anubhav Jain", "Logan Ward", "Alex Dunn"]

@staticmethod
def from_preset(preset, **kwargs):
@classmethod
def from_preset(cls, preset, **kwargs):
"""
Create a SiteStatsFingerprint class according to a preset

Expand All @@ -158,37 +160,37 @@ def from_preset(preset, **kwargs):
"""

if preset == "SOAP_formation_energy":
return SiteStatsFingerprint(SOAP.from_preset("formation_energy"), **kwargs)
return cls(SOAP.from_preset("formation_energy"), **kwargs)

elif preset == "CrystalNNFingerprint_cn":
return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("cn", cation_anion=False), **kwargs)
return cls(CrystalNNFingerprint.from_preset("cn", cation_anion=False), **kwargs)

elif preset == "CrystalNNFingerprint_cn_cation_anion":
return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("cn", cation_anion=True), **kwargs)
return cls(CrystalNNFingerprint.from_preset("cn", cation_anion=True), **kwargs)

elif preset == "CrystalNNFingerprint_ops":
return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("ops", cation_anion=False), **kwargs)
return cls(CrystalNNFingerprint.from_preset("ops", cation_anion=False), **kwargs)

elif preset == "CrystalNNFingerprint_ops_cation_anion":
return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("ops", cation_anion=True), **kwargs)
return cls(CrystalNNFingerprint.from_preset("ops", cation_anion=True), **kwargs)

elif preset == "OPSiteFingerprint":
return SiteStatsFingerprint(OPSiteFingerprint(), **kwargs)
return cls(OPSiteFingerprint(), **kwargs)

elif preset == "LocalPropertyDifference_ward-prb-2017":
return SiteStatsFingerprint(
return cls(
LocalPropertyDifference.from_preset("ward-prb-2017"),
stats=["minimum", "maximum", "range", "mean", "avg_dev"],
)

elif preset == "CoordinationNumber_ward-prb-2017":
return SiteStatsFingerprint(
return cls(
CoordinationNumber(nn=VoronoiNN(weight="area"), use_weights="effective"),
stats=["minimum", "maximum", "range", "mean", "avg_dev"],
)

elif preset == "Composition-dejong2016_AD":
return SiteStatsFingerprint(
return cls(
LocalPropertyDifference(
properties=[
"Number",
Expand All @@ -204,7 +206,7 @@ def from_preset(preset, **kwargs):
)

elif preset == "Composition-dejong2016_SD":
return SiteStatsFingerprint(
return cls(
LocalPropertyDifference(
properties=[
"Number",
Expand All @@ -220,13 +222,13 @@ def from_preset(preset, **kwargs):
)

elif preset == "BondLength-dejong2016":
return SiteStatsFingerprint(
return cls(
AverageBondLength(VoronoiNN()),
stats=["holder_mean::%d" % d for d in range(-4, 4 + 1)] + ["std_dev", "geom_std_dev"],
)

elif preset == "BondAngle-dejong2016":
return SiteStatsFingerprint(
return cls(
AverageBondAngle(VoronoiNN()),
stats=["holder_mean::%d" % d for d in range(-4, 4 + 1)] + ["std_dev", "geom_std_dev"],
)
Expand All @@ -236,8 +238,147 @@ def from_preset(preset, **kwargs):
# One of the various Coordination Number presets:
# MinimumVIRENN, MinimumDistanceNN, JmolNN, VoronoiNN, etc.
try:
return SiteStatsFingerprint(CoordinationNumber.from_preset(preset), **kwargs)
return cls(CoordinationNumber.from_preset(preset), **kwargs)
except Exception:
pass

raise ValueError("Unrecognized preset!")


class PartialsSiteStatsFingerprint(SiteStatsFingerprint):
"""
Computes statistics of properties across all sites in a structure, and
breaks these down by element. This featurizer first uses a site featurizer
class (see site.py for options) to compute features of each site of a
specific element in a structure, and then computes features of the entire
structure by measuring statistics of each attribute.
Features:
- Returns each statistic of each site feature, broken down by element
"""

def __init__(
self,
site_featurizer,
stats=("mean", "std_dev"),
min_oxi=None,
max_oxi=None,
covariance=False,
include_elems=(),
exclude_elems=(),
):
"""
Args:
site_featurizer (BaseFeaturizer): a site-based featurizer
stats ([str]): list of weighted statistics to compute for each feature.
If stats is None, a list is returned for each features
that contains the calculated feature for each site in the
structure.
*Note for nth mode, stat must be 'n*_mode'; e.g. stat='2nd_mode'
min_oxi (int): minimum site oxidation state for inclusion (e.g.,
zero means metals/cations only)
max_oxi (int): maximum site oxidation state for inclusion
covariance (bool): Whether to compute the covariance of site features
"""

self.include_elems = list(include_elems)
self.exclude_elems = list(exclude_elems)
super().__init__(site_featurizer, stats, min_oxi, max_oxi, covariance)

def fit(self, X, y=None):
"""Define the list of elements to be included in the PRDF. By default,
the PRDF will include all of the elements in `X`
Args:
X: (numpy array nx1) structures used in the training set. Each entry
must be Pymatgen Structure objects.
y: *Not used*
fit_kwargs: *not used*
"""

# This method largely copies code from the partial-RDF fingerprint

# Initialize list with included elements
elements = [Element(e) for e in self.include_elems]

# Get all of elements that appear
for structure in X:
for element in structure.composition.elements:
if isinstance(element, Specie):
element = element.element # converts from Specie to Element object
if element not in elements and element.name not in self.exclude_elems:
elements.append(element)

# Store the elements
self.elements_ = [e.symbol for e in sorted(elements)]

def featurize(self, s):
"""
Get PSSF of the input structure.
Args:
s: Pymatgen Structure object.
Returns:
pssf: 1D array of each element's ssf
"""

if not s.is_ordered:
raise ValueError("Disordered structure support not built yet")
if not hasattr(self, "elements_") or self.elements_ is None:
raise Exception("You must run 'fit' first!")

output = []
for e in self.elements_:
pssf_stats = self.compute_pssf(s, e)
output.append(pssf_stats)

return np.hstack(output)

def compute_pssf(self, s, e):

# This code is extremely similar to super().featurize(). The key
# difference is that only one specific element is analyzed.

# Get each feature for each site
vals = [[] for t in self._site_labels]
for i, site in enumerate(s.sites):
if site.specie.symbol == e:
opvalstmp = self.site_featurizer.featurize(s, i)
for j, opval in enumerate(opvalstmp):
if opval is None:
vals[j].append(0.0)
else:
vals[j].append(opval)

# If the user does not request statistics, return the site features now
if self.stats is None:
return vals

# Compute the requested statistics
stats = []
for op in vals:
for stat in self.stats:
stats.append(PropertyStats().calc_stat(op, stat))

# If desired, compute covariances
if self.covariance:
if len(s) == 1:
stats.extend([0] * int(len(vals) * (len(vals) - 1) / 2))
else:
covar = np.cov(vals)
tri_ind = np.triu_indices(len(vals), 1)
stats.extend(covar[tri_ind].tolist())

return stats

def feature_labels(self):
if not hasattr(self, "elements_") or self.elements_ is None:
raise Exception("You must run 'fit' first!")

labels = []
for e in self.elements_:
e_labels = [f"{e} {l}" for l in super().feature_labels()]
for l in e_labels:
labels.append(l)

return labels

def implementors(self):
return ["Jack Sundberg"]
115 changes: 114 additions & 1 deletion matminer/featurizers/structure/tests/test_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import numpy as np

from matminer.featurizers.site import SiteElementalProperty
from matminer.featurizers.structure.sites import SiteStatsFingerprint
from matminer.featurizers.structure.sites import (
SiteStatsFingerprint,
PartialsSiteStatsFingerprint,
)
from matminer.featurizers.structure.tests.base import StructureFeaturesTest


Expand Down Expand Up @@ -111,5 +114,115 @@ def test_ward_prb_2017_efftcn(self):
self.assertArrayAlmostEqual([12, 12, 0, 12, 0], features)


class PartialStructureSitesFeaturesTest(StructureFeaturesTest):
def test_partialsitestatsfingerprint(self):
# Test matrix.
op_struct_fp = PartialsSiteStatsFingerprint.from_preset("OPSiteFingerprint", stats=None)

op_struct_fp.fit([self.diamond])
opvals = op_struct_fp.featurize(self.diamond)
_ = op_struct_fp.feature_labels()
self.assertAlmostEqual(opvals[10][0], 0.9995, places=7)
self.assertAlmostEqual(opvals[10][1], 0.9995, places=7)

op_struct_fp.fit([self.nacl])
opvals = op_struct_fp.featurize(self.nacl)
self.assertAlmostEqual(opvals[18][0], 0.9995, places=7)
self.assertAlmostEqual(opvals[18][1], 0.9995, places=7)

op_struct_fp.fit([self.cscl])
opvals = op_struct_fp.featurize(self.cscl)
self.assertAlmostEqual(opvals[22][0], 0.9995, places=7)
self.assertAlmostEqual(opvals[22][1], 0.9995, places=7)

# Test stats.
op_struct_fp = PartialsSiteStatsFingerprint.from_preset("OPSiteFingerprint")
op_struct_fp.fit([self.diamond])
opvals = op_struct_fp.featurize(self.diamond)
self.assertAlmostEqual(opvals[0], 0.0005, places=7)
self.assertAlmostEqual(opvals[1], 0, places=7)
self.assertAlmostEqual(opvals[2], 0.0005, places=7)
self.assertAlmostEqual(opvals[3], 0.0, places=7)
self.assertAlmostEqual(opvals[4], 0.0005, places=7)
self.assertAlmostEqual(opvals[18], 0.0805, places=7)
self.assertAlmostEqual(opvals[20], 0.9995, places=7)
self.assertAlmostEqual(opvals[21], 0, places=7)
self.assertAlmostEqual(opvals[22], 0.0075, places=7)
self.assertAlmostEqual(opvals[24], 0.2355, places=7)
self.assertAlmostEqual(opvals[-1], 0.0, places=7)

# Test coordination number
cn_fp = PartialsSiteStatsFingerprint.from_preset("JmolNN", stats=("mean",))
cn_fp.fit([self.diamond])
cn_vals = cn_fp.featurize(self.diamond)
self.assertEqual(cn_vals[0], 4.0)

# Test the covariance
prop_fp = PartialsSiteStatsFingerprint(
SiteElementalProperty(properties=["Number", "AtomicWeight"]),
stats=["mean"],
covariance=True,
)

# Test the feature labels
prop_fp.fit([self.diamond])
labels = prop_fp.feature_labels()
self.assertEqual(3, len(labels))

# Test a structure with all the same type (cov should be zero)
prop_fp.fit([self.diamond])
features = prop_fp.featurize(self.diamond)
self.assertArrayAlmostEqual(features, [6, 12.0107, 0])

# Test a structure with only one atom (cov should be zero too)
prop_fp.fit([self.sc])
features = prop_fp.featurize(self.sc)
self.assertArrayAlmostEqual([13, 26.9815386, 0], features)

# Test a structure with nonzero covariance
prop_fp.fit([self.nacl])
features = prop_fp.featurize(self.nacl)
self.assertArrayAlmostEqual([11, 22.9897693, np.nan, 17, 35.453, np.nan], features)

def test_ward_prb_2017_lpd(self):
"""Test the local property difference attributes from Ward 2017"""
f = PartialsSiteStatsFingerprint.from_preset("LocalPropertyDifference_ward-prb-2017")

# Test diamond
f.fit([self.diamond])
features = f.featurize(self.diamond)
self.assertArrayAlmostEqual(features, [0] * (22 * 5))
features = f.featurize(self.diamond_no_oxi)
self.assertArrayAlmostEqual(features, [0] * (22 * 5))

# Test CsCl
f.fit([self.cscl])
big_face_area = np.sqrt(3) * 3 / 2 * (2 / 4 / 4)
small_face_area = 0.125
big_face_diff = 55 - 17
features = f.featurize(self.cscl)
labels = f.feature_labels()
my_label = "Cs mean local difference in Number"
self.assertAlmostEqual(
(8 * big_face_area * big_face_diff) / (8 * big_face_area + 6 * small_face_area),
features[labels.index(my_label)],
places=3,
)
my_label = "Cs range local difference in Electronegativity"
self.assertAlmostEqual(0, features[labels.index(my_label)], places=3)

def test_ward_prb_2017_efftcn(self):
"""Test the effective coordination number attributes of Ward 2017"""
f = PartialsSiteStatsFingerprint.from_preset("CoordinationNumber_ward-prb-2017")

# Test Ni3Al
f.fit([self.ni3al])
features = f.featurize(self.ni3al)
labels = f.feature_labels()
self.assertAlmostEqual(12, features[labels.index("Al mean CN_VoronoiNN")])
self.assertAlmostEqual(12, features[labels.index("Ni mean CN_VoronoiNN")])
self.assertArrayAlmostEqual([12, 12, 0, 12, 0] * 2, features)


if __name__ == "__main__":
unittest.main()