diff --git a/package/CHANGELOG b/package/CHANGELOG index f2264922088..aecde6e6468 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -17,7 +17,8 @@ The rules for this file: ??/??/?? IAlibay, HeetVekariya, marinegor, lilyminium, RMeli, ljwoods2, aditya292002, pstaerk, PicoCentauri, BFedder, tyler.je.reddy, SampurnaM, leonwehrhan, kainszs, orionarcher, - yuxuanzhuang, PythonFZ, laksh-krishna-sharma, orbeckst, MattTDavies + yuxuanzhuang, PythonFZ, laksh-krishna-sharma, orbeckst, MattTDavies, + talagayev * 2.8.0 @@ -55,6 +56,7 @@ Fixes Enhancements * Introduce parallelization API to `AnalysisBase` and to `analysis.rms.RMSD` class (Issue #4158, PR #4304) + * Enables parallelization for analysis.gnm.GNMAnalysis (Issue #4672) * explicitly mark `analysis.pca.PCA` as not parallelizable (Issue #4680) * Improve error message for `AtomGroup.unwrap()` when bonds are not present.(Issue #4436, PR #4642) * Add `analysis.DSSP` module for protein secondary structure assignment, based on [pydssp](https://github.com/ShintaroMinami/PyDSSP) diff --git a/package/MDAnalysis/analysis/gnm.py b/package/MDAnalysis/analysis/gnm.py index 86a62fa7b9f..510fb887d01 100644 --- a/package/MDAnalysis/analysis/gnm.py +++ b/package/MDAnalysis/analysis/gnm.py @@ -92,7 +92,7 @@ import numpy as np -from .base import AnalysisBase +from .base import AnalysisBase, ResultsGroup from MDAnalysis.analysis.base import Results @@ -245,8 +245,19 @@ class GNMAnalysis(AnalysisBase): Use :class:`~MDAnalysis.analysis.AnalysisBase` as parent class and store results as attributes ``times``, ``eigenvalues`` and ``eigenvectors`` of the ``results`` attribute. + + .. versionchanged:: 2.8.0 + Enabled **parallel execution** with the ``multiprocessing`` and ``dask`` + backends; use the new method :meth:`get_supported_backends` to see all + supported backends. """ + _analysis_algorithm_is_parallelizable = True + + @classmethod + def get_supported_backends(cls): + return ("serial", "multiprocessing", "dask") + def __init__(self, universe, select='protein and name CA', @@ -348,6 +359,15 @@ def _conclude(self): self.results.eigenvalues = np.asarray(self.results.eigenvalues) self.results.eigenvectors = np.asarray(self.results.eigenvectors) + def _get_aggregator(self): + return ResultsGroup( + lookup={ + "eigenvectors": ResultsGroup.ndarray_hstack, + "eigenvalues": ResultsGroup.ndarray_hstack, + "times": ResultsGroup.ndarray_hstack, + } + ) + class closeContactGNMAnalysis(GNMAnalysis): r"""GNMAnalysis only using close contacts. diff --git a/testsuite/MDAnalysisTests/analysis/conftest.py b/testsuite/MDAnalysisTests/analysis/conftest.py index 55bae7e6bd8..75d62284b7b 100644 --- a/testsuite/MDAnalysisTests/analysis/conftest.py +++ b/testsuite/MDAnalysisTests/analysis/conftest.py @@ -8,6 +8,7 @@ ) from MDAnalysis.analysis.rms import RMSD, RMSF from MDAnalysis.lib.util import is_installed +from MDAnalysis.analysis.gnm import GNMAnalysis def params_for_cls(cls, exclude: list[str] = None): @@ -87,3 +88,8 @@ def client_RMSD(request): @pytest.fixture(scope='module', params=params_for_cls(RMSF)) def client_RMSF(request): return request.param + + +@pytest.fixture(scope='module', params=params_for_cls(GNMAnalysis)) +def client_GNMAnalysis(request): + return request.param diff --git a/testsuite/MDAnalysisTests/analysis/test_gnm.py b/testsuite/MDAnalysisTests/analysis/test_gnm.py index 6521c08eb86..d8a547a5428 100644 --- a/testsuite/MDAnalysisTests/analysis/test_gnm.py +++ b/testsuite/MDAnalysisTests/analysis/test_gnm.py @@ -38,10 +38,10 @@ def universe(): return mda.Universe(GRO, XTC) -def test_gnm(universe, tmpdir): +def test_gnm(universe, tmpdir, client_GNMAnalysis): output = os.path.join(str(tmpdir), 'output.txt') gnm = mda.analysis.gnm.GNMAnalysis(universe, ReportVector=output) - gnm.run() + gnm.run(**client_GNMAnalysis) result = gnm.results assert len(result.times) == 10 assert_almost_equal(gnm.results.times, np.arange(0, 1000, 100), decimal=4) @@ -51,9 +51,9 @@ def test_gnm(universe, tmpdir): 4.2058769e-15, 3.9839431e-15]) -def test_gnm_run_step(universe): +def test_gnm_run_step(universe, client_GNMAnalysis): gnm = mda.analysis.gnm.GNMAnalysis(universe) - gnm.run(step=3) + gnm.run(step=3, **client_GNMAnalysis) result = gnm.results assert len(result.times) == 4 assert_almost_equal(gnm.results.times, np.arange(0, 1200, 300), decimal=4) @@ -88,9 +88,9 @@ def test_gnm_SVD_fail(universe): mda.analysis.gnm.GNMAnalysis(universe).run(stop=1) -def test_closeContactGNMAnalysis(universe): +def test_closeContactGNMAnalysis(universe, client_GNMAnalysis): gnm = mda.analysis.gnm.closeContactGNMAnalysis(universe, weights="size") - gnm.run(stop=2) + gnm.run(stop=2, **client_GNMAnalysis) result = gnm.results assert len(result.times) == 2 assert_almost_equal(gnm.results.times, (0, 100), decimal=4) @@ -114,9 +114,9 @@ def test_closeContactGNMAnalysis(universe): 0.0, 0.0, -2.263157894736841, -0.24333213169614382]) -def test_closeContactGNMAnalysis_weights_None(universe): +def test_closeContactGNMAnalysis_weights_None(universe, client_GNMAnalysis): gnm = mda.analysis.gnm.closeContactGNMAnalysis(universe, weights=None) - gnm.run(stop=2) + gnm.run(stop=2, **client_GNMAnalysis) result = gnm.results assert len(result.times) == 2 assert_almost_equal(gnm.results.times, (0, 100), decimal=4)