From b3e29548beb07ef2bbd9e501eb4ae4014e475ba3 Mon Sep 17 00:00:00 2001 From: Jack Sundberg Date: Mon, 20 Jun 2022 12:35:11 -0400 Subject: [PATCH] fix psitestats tests --- matminer/featurizers/structure/sites.py | 4 +-- .../featurizers/structure/tests/test_sites.py | 35 ++++++++++--------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/matminer/featurizers/structure/sites.py b/matminer/featurizers/structure/sites.py index e91332f98..b993eead4 100644 --- a/matminer/featurizers/structure/sites.py +++ b/matminer/featurizers/structure/sites.py @@ -321,7 +321,7 @@ def featurize(self, s): if not s.is_ordered: raise ValueError("Disordered structure support not built yet") - if self.elements_ is None: + if not hasattr(self, "elements_") or self.elements_ is None: raise Exception("You must run 'fit' first!") output = [] @@ -369,7 +369,7 @@ def compute_pssf(self, s, e): return stats def feature_labels(self): - if self.elements_ is None: + if not hasattr(self, "elements_") or self.elements_ is None: raise Exception("You must run 'fit' first!") labels = [] diff --git a/matminer/featurizers/structure/tests/test_sites.py b/matminer/featurizers/structure/tests/test_sites.py index 0e81f1954..8124d313b 100644 --- a/matminer/featurizers/structure/tests/test_sites.py +++ b/matminer/featurizers/structure/tests/test_sites.py @@ -118,14 +118,19 @@ class PartialStructureSitesFeaturesTest(StructureFeaturesTest): def test_partialsitestatsfingerprint(self): # Test matrix. op_struct_fp = PartialsSiteStatsFingerprint.from_preset("OPSiteFingerprint", stats=None) + op_struct_fp.fit([self.diamond]) opvals = op_struct_fp.featurize(self.diamond) _ = op_struct_fp.feature_labels() self.assertAlmostEqual(opvals[10][0], 0.9995, places=7) self.assertAlmostEqual(opvals[10][1], 0.9995, places=7) + + op_struct_fp.fit([self.nacl]) opvals = op_struct_fp.featurize(self.nacl) self.assertAlmostEqual(opvals[18][0], 0.9995, places=7) self.assertAlmostEqual(opvals[18][1], 0.9995, places=7) + + op_struct_fp.fit([self.cscl]) opvals = op_struct_fp.featurize(self.cscl) self.assertAlmostEqual(opvals[22][0], 0.9995, places=7) self.assertAlmostEqual(opvals[22][1], 0.9995, places=7) @@ -158,57 +163,52 @@ def test_partialsitestatsfingerprint(self): stats=["mean"], covariance=True, ) - prop_fp.fit([self.diamond]) # Test the feature labels + prop_fp.fit([self.diamond]) labels = prop_fp.feature_labels() self.assertEqual(3, len(labels)) # Test a structure with all the same type (cov should be zero) + prop_fp.fit([self.diamond]) features = prop_fp.featurize(self.diamond) self.assertArrayAlmostEqual(features, [6, 12.0107, 0]) # Test a structure with only one atom (cov should be zero too) + prop_fp.fit([self.sc]) features = prop_fp.featurize(self.sc) self.assertArrayAlmostEqual([13, 26.9815386, 0], features) # Test a structure with nonzero covariance + prop_fp.fit([self.nacl]) features = prop_fp.featurize(self.nacl) - self.assertArrayAlmostEqual([14, 29.22138464, 37.38969216], features) - - # Test soap site featurizer - soap_fp = PartialsSiteStatsFingerprint.from_preset("SOAP_formation_energy") - soap_fp.fit([self.sc, self.diamond, self.nacl]) - feats = soap_fp.featurize(self.diamond) - self.assertEqual(len(feats), 9504) - self.assertAlmostEqual(feats[0], 0.4412608, places=5) - self.assertAlmostEqual(feats[1], 0.0) - self.assertAlmostEqual(np.sum(feats), 207.88194724, places=5) + self.assertArrayAlmostEqual([11, 22.9897693, np.nan, 17, 35.453, np.nan], features) def test_ward_prb_2017_lpd(self): """Test the local property difference attributes from Ward 2017""" f = PartialsSiteStatsFingerprint.from_preset("LocalPropertyDifference_ward-prb-2017") - f.fit([self.diamond]) # Test diamond + f.fit([self.diamond]) features = f.featurize(self.diamond) self.assertArrayAlmostEqual(features, [0] * (22 * 5)) features = f.featurize(self.diamond_no_oxi) self.assertArrayAlmostEqual(features, [0] * (22 * 5)) # Test CsCl + f.fit([self.cscl]) big_face_area = np.sqrt(3) * 3 / 2 * (2 / 4 / 4) small_face_area = 0.125 big_face_diff = 55 - 17 features = f.featurize(self.cscl) labels = f.feature_labels() - my_label = "mean local difference in Number" + my_label = "Cs mean local difference in Number" self.assertAlmostEqual( (8 * big_face_area * big_face_diff) / (8 * big_face_area + 6 * small_face_area), features[labels.index(my_label)], places=3, ) - my_label = "range local difference in Electronegativity" + my_label = "Cs range local difference in Electronegativity" self.assertAlmostEqual(0, features[labels.index(my_label)], places=3) def test_ward_prb_2017_efftcn(self): @@ -216,11 +216,12 @@ def test_ward_prb_2017_efftcn(self): f = PartialsSiteStatsFingerprint.from_preset("CoordinationNumber_ward-prb-2017") # Test Ni3Al + f.fit([self.ni3al]) features = f.featurize(self.ni3al) labels = f.feature_labels() - my_label = "mean CN_VoronoiNN" - self.assertAlmostEqual(12, features[labels.index(my_label)]) - self.assertArrayAlmostEqual([12, 12, 0, 12, 0], features) + self.assertAlmostEqual(12, features[labels.index("Al mean CN_VoronoiNN")]) + self.assertAlmostEqual(12, features[labels.index("Ni mean CN_VoronoiNN")]) + self.assertArrayAlmostEqual([12, 12, 0, 12, 0] * 2, features) if __name__ == "__main__":