diff --git a/pygmt/session_management.py b/pygmt/session_management.py index 4bb829835b9..750157679d4 100644 --- a/pygmt/session_management.py +++ b/pygmt/session_management.py @@ -1,7 +1,11 @@ """ Modern mode session management modules. """ +import os +import sys + from pygmt.clib import Session +from pygmt.helpers import unique_name def begin(): @@ -12,6 +16,10 @@ def begin(): Only meant to be used once for creating the global session. """ + # On Windows, need to set GMT_SESSION_NAME to a unique value + if sys.platform == "win32": + os.environ["GMT_SESSION_NAME"] = unique_name() + prefix = "pygmt-session" with Session() as lib: lib.call_module(module="begin", args=prefix) diff --git a/pygmt/tests/test_session_management.py b/pygmt/tests/test_session_management.py index 079c2c4e02c..544c5f037de 100644 --- a/pygmt/tests/test_session_management.py +++ b/pygmt/tests/test_session_management.py @@ -1,7 +1,10 @@ """ Test the session management modules. """ +import multiprocessing as mp import os +from importlib import reload +from pathlib import Path import pytest from pygmt.clib import Session @@ -57,3 +60,29 @@ def test_gmt_compat_6_is_applied(capsys): # Make sure no global "gmt.conf" in the current directory assert not os.path.exists("gmt.conf") begin() # Restart the global session + + +def _gmt_func_wrapper(figname): + """ + A wrapper for running PyGMT scripts with multiprocessing. + + Currently, we have to import pygmt and reload it in each process. Workaround from + https://github.com/GenericMappingTools/pygmt/issues/217#issuecomment-754774875. + """ + import pygmt + + reload(pygmt) + fig = pygmt.Figure() + fig.basemap(region=[10, 70, -3, 8], projection="X8c/6c", frame="afg") + fig.savefig(figname) + + +def test_session_multiprocessing(): + """ + Make sure that multiprocessing is supported if pygmt is re-imported. + """ + prefix = "test_session_multiprocessing" + with mp.Pool(2) as p: + p.map(_gmt_func_wrapper, [f"{prefix}-1.png", f"{prefix}-2.png"]) + Path(f"{prefix}-1.png").unlink() + Path(f"{prefix}-2.png").unlink()