Skip to content

Commit

Permalink
Refactor backend handling to separate function _get_backend to use …
Browse files Browse the repository at this point in the history
…with `doped`
  • Loading branch information
kavanase committed Aug 29, 2023
1 parent 4582b42 commit 009dbdb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 39 deletions.
30 changes: 18 additions & 12 deletions shakenbreak/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ def _install_custom_font():
warnings.warn(warning_msg)


def _get_backend(save_format: str) -> str:
"""Try use pycairo as backend if installed, and save_format is pdf."""
backend = None
if "pdf" in save_format:
try:
import cairo

backend = "cairo"
except ImportError:
warnings.warn(
"pycairo not installed. Defaulting to matplotlib's pdf backend, so default "
"ShakeNBreak fonts may not be used – try setting `save_format` to 'png' or "
"`pip install pycairo` if you want ShakeNBreak's default font."
)
return backend


# Helper functions for formatting plots
def _verify_data_directories_exist(
output_path: str,
Expand Down Expand Up @@ -501,18 +518,7 @@ def _save_plot(
)

# use pycairo as backend if installed and save_format is pdf:
backend = None
if "pdf" in save_format:
try:
import cairo

backend = "cairo"
except ImportError:
warnings.warn(
"pycairo not installed. Defaulting to matplotlib's pdf backend, so default "
"ShakeNBreak fonts may not be used – try setting `save_format` to 'png' or "
"`pip install pycairo` if you want ShakeNBreak's default font."
)
backend = _get_backend(save_format)

fig.savefig(
plot_filepath,
Expand Down
27 changes: 0 additions & 27 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ def test_snb_generate(self):
"""Implicitly, the `snb-generate` tests also test the functionality of
`input.identify_defect()`
"""
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
runner = CliRunner()
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
result = runner.invoke(
snb,
[
Expand All @@ -244,7 +242,6 @@ def test_snb_generate(self):
],
catch_exceptions=False,
)
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
self.assertEqual(result.exit_code, 0)
self.assertIn(
f"Auto site-matching identified {self.VASP_CDTE_DATA_DIR}/CdTe_V_Cd_POSCAR "
Expand Down Expand Up @@ -298,23 +295,18 @@ def test_snb_generate(self):
self.V_Cd_minus0pt5_struc_rattled,
)

print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
kpoints = Kpoints.from_file(f"{V_Cd_Bond_Distortion_folder}/KPOINTS")
self.assertEqual(kpoints.kpts, [[1, 1, 1]])

print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
if _potcars_available():
print(f"Hre: Line: {inspect.currentframe().f_lineno}") # print current line number
assert filecmp.cmp(f"{V_Cd_Bond_Distortion_folder}/INCAR", self.V_Cd_INCAR_file)

# check if POTCARs have been written:
potcar = Potcar.from_file(f"{V_Cd_Bond_Distortion_folder}/POTCAR")
assert set(potcar.as_dict()["symbols"]) == {"Cd", "Te"}
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number

# Test recognises distortion_metadata.json:
if_present_rm(f"{defect_name}_0") # but distortion_metadata.json still present
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
result = runner.invoke(
snb,
[
Expand All @@ -328,7 +320,6 @@ def test_snb_generate(self):
],
catch_exceptions=False,
) # non-verbose this time
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
self.assertEqual(result.exit_code, 0)
self.assertNotIn(
"Auto site-matching identified"
Expand Down Expand Up @@ -375,9 +366,7 @@ def test_snb_generate(self):
)

# test defect_index option:
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
self.tearDown()
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
result = runner.invoke(
snb,
[
Expand All @@ -393,7 +382,6 @@ def test_snb_generate(self):
"-v",
],
)
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
self.assertEqual(result.exit_code, 0)
self.assertNotIn("Auto site-matching", result.output)
self.assertIn("Oxidation states were not explicitly set", result.output)
Expand Down Expand Up @@ -467,14 +455,11 @@ def test_snb_generate(self):
with open("distortion_metadata.json", "r") as metadata_file:
metadata = json.load(metadata_file)
np.testing.assert_equal(metadata, wrong_site_V_Cd_dict)
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number

# test warning with defect_coords option but wrong site: (matches Cd site in bulk)
# using Int_Cd because V_Cd is at (0,0,0) so fractional and Cartesian coordinates the same
self.tearDown()
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
with warnings.catch_warnings(record=True) as w:
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
result = runner.invoke(
snb,
[
Expand All @@ -493,7 +478,6 @@ def test_snb_generate(self):
],
catch_exceptions=False,
)
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
self.assertEqual(result.exit_code, 0)
warning_message = (
"Coordinates (0.0, 0.0, 0.0) were specified for (auto-determined) interstitial "
Expand All @@ -520,9 +504,7 @@ def test_snb_generate(self):
)

# test defect_coords working even when slightly off correct site
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
self.tearDown()
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
with warnings.catch_warnings(record=True) as w:
result = runner.invoke(
snb,
Expand All @@ -541,11 +523,8 @@ def test_snb_generate(self):
"-v",
],
)
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
self.assertEqual(result.exit_code, 0)
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
if w:
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
# Check no problems in identifying the defect site
self.assertFalse(
any(str(warning.message) == warning_message for warning in w)
Expand All @@ -566,12 +545,10 @@ def test_snb_generate(self):
Structure.from_file(f"{defect_name}_0/Bond_Distortion_-60.0%/POSCAR"),
self.Int_Cd_2_minus0pt6_struc_rattled,
)
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number

# test defect_coords working even when significantly off (~2.2 Å) correct site,
# with rattled bulk
self.tearDown()
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
with warnings.catch_warnings(record=True) as w:
rattled_bulk = rattle(
self.CdTe_bulk_struc, stdev=0.25, d_min=2.25
Expand Down Expand Up @@ -618,7 +595,6 @@ def test_snb_generate(self):

# test defect_coords working even when slightly off correct site with V_Cd and rattled bulk
self.tearDown()
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
with warnings.catch_warnings(record=True) as w:
rattled_bulk = rattle(
self.CdTe_bulk_struc, stdev=0.25, d_min=2.25
Expand Down Expand Up @@ -756,7 +732,6 @@ def test_snb_generate(self):
np.testing.assert_equal(metadata, spec_coords_V_Cd_dict)

# test defect ID with tricky DX centre defect
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
result = runner.invoke(
snb,
[
Expand Down Expand Up @@ -808,7 +783,6 @@ def test_snb_generate(self):

# test padding functionality:
# default padding = 1
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number
result = runner.invoke(
snb,
[
Expand Down Expand Up @@ -876,7 +850,6 @@ def test_snb_generate(self):
self.assertTrue(os.path.exists(f"{defect_name}_-6"))
self.assertFalse(os.path.exists(f"{defect_name}_+5"))
self.assertFalse(os.path.exists(f"{defect_name}_-7"))
print(f"Line: {inspect.currentframe().f_lineno}") # print current line number

def test_snb_generate_config(self):
# test config file:
Expand Down

0 comments on commit 009dbdb

Please sign in to comment.