Skip to content

Commit

Permalink
Implementation of Parallelization to analysis.gnm (#4700)
Browse files Browse the repository at this point in the history
* Fixes #4672
* Changes made in this Pull Request:
    - Parallelization of the class GNMAnalysis in analysis.gnm.py
    - Addition of parallelization tests (including fixtures in analysis/conftest.py)
    - update of CHANGELOG

---------

Co-authored-by: Oliver Beckstein <orbeckst@gmail.com>
  • Loading branch information
talagayev and orbeckst authored Sep 9, 2024
1 parent 7618e05 commit b3208b3
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 10 deletions.
4 changes: 3 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ The rules for this file:
??/??/?? IAlibay, HeetVekariya, marinegor, lilyminium, RMeli,
ljwoods2, aditya292002, pstaerk, PicoCentauri, BFedder,
tyler.je.reddy, SampurnaM, leonwehrhan, kainszs, orionarcher,
yuxuanzhuang, PythonFZ, laksh-krishna-sharma, orbeckst, MattTDavies
yuxuanzhuang, PythonFZ, laksh-krishna-sharma, orbeckst, MattTDavies,
talagayev

* 2.8.0

Expand Down Expand Up @@ -55,6 +56,7 @@ Fixes
Enhancements
* Introduce parallelization API to `AnalysisBase` and to `analysis.rms.RMSD` class
(Issue #4158, PR #4304)
* Enables parallelization for analysis.gnm.GNMAnalysis (Issue #4672)
* explicitly mark `analysis.pca.PCA` as not parallelizable (Issue #4680)
* Improve error message for `AtomGroup.unwrap()` when bonds are not present.(Issue #4436, PR #4642)
* Add `analysis.DSSP` module for protein secondary structure assignment, based on [pydssp](https://github.com/ShintaroMinami/PyDSSP)
Expand Down
22 changes: 21 additions & 1 deletion package/MDAnalysis/analysis/gnm.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@

import numpy as np

from .base import AnalysisBase
from .base import AnalysisBase, ResultsGroup


from MDAnalysis.analysis.base import Results
Expand Down Expand Up @@ -245,8 +245,19 @@ class GNMAnalysis(AnalysisBase):
Use :class:`~MDAnalysis.analysis.AnalysisBase` as parent class and
store results as attributes ``times``, ``eigenvalues`` and
``eigenvectors`` of the ``results`` attribute.
.. versionchanged:: 2.8.0
Enabled **parallel execution** with the ``multiprocessing`` and ``dask``
backends; use the new method :meth:`get_supported_backends` to see all
supported backends.
"""

_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
return ("serial", "multiprocessing", "dask")

def __init__(self,
universe,
select='protein and name CA',
Expand Down Expand Up @@ -348,6 +359,15 @@ def _conclude(self):
self.results.eigenvalues = np.asarray(self.results.eigenvalues)
self.results.eigenvectors = np.asarray(self.results.eigenvectors)

def _get_aggregator(self):
return ResultsGroup(
lookup={
"eigenvectors": ResultsGroup.ndarray_hstack,
"eigenvalues": ResultsGroup.ndarray_hstack,
"times": ResultsGroup.ndarray_hstack,
}
)


class closeContactGNMAnalysis(GNMAnalysis):
r"""GNMAnalysis only using close contacts.
Expand Down
6 changes: 6 additions & 0 deletions testsuite/MDAnalysisTests/analysis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from MDAnalysis.analysis.rms import RMSD, RMSF
from MDAnalysis.lib.util import is_installed
from MDAnalysis.analysis.gnm import GNMAnalysis


def params_for_cls(cls, exclude: list[str] = None):
Expand Down Expand Up @@ -87,3 +88,8 @@ def client_RMSD(request):
@pytest.fixture(scope='module', params=params_for_cls(RMSF))
def client_RMSF(request):
return request.param


@pytest.fixture(scope='module', params=params_for_cls(GNMAnalysis))
def client_GNMAnalysis(request):
return request.param
16 changes: 8 additions & 8 deletions testsuite/MDAnalysisTests/analysis/test_gnm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ def universe():
return mda.Universe(GRO, XTC)


def test_gnm(universe, tmpdir):
def test_gnm(universe, tmpdir, client_GNMAnalysis):
output = os.path.join(str(tmpdir), 'output.txt')
gnm = mda.analysis.gnm.GNMAnalysis(universe, ReportVector=output)
gnm.run()
gnm.run(**client_GNMAnalysis)
result = gnm.results
assert len(result.times) == 10
assert_almost_equal(gnm.results.times, np.arange(0, 1000, 100), decimal=4)
Expand All @@ -51,9 +51,9 @@ def test_gnm(universe, tmpdir):
4.2058769e-15, 3.9839431e-15])


def test_gnm_run_step(universe):
def test_gnm_run_step(universe, client_GNMAnalysis):
gnm = mda.analysis.gnm.GNMAnalysis(universe)
gnm.run(step=3)
gnm.run(step=3, **client_GNMAnalysis)
result = gnm.results
assert len(result.times) == 4
assert_almost_equal(gnm.results.times, np.arange(0, 1200, 300), decimal=4)
Expand Down Expand Up @@ -88,9 +88,9 @@ def test_gnm_SVD_fail(universe):
mda.analysis.gnm.GNMAnalysis(universe).run(stop=1)


def test_closeContactGNMAnalysis(universe):
def test_closeContactGNMAnalysis(universe, client_GNMAnalysis):
gnm = mda.analysis.gnm.closeContactGNMAnalysis(universe, weights="size")
gnm.run(stop=2)
gnm.run(stop=2, **client_GNMAnalysis)
result = gnm.results
assert len(result.times) == 2
assert_almost_equal(gnm.results.times, (0, 100), decimal=4)
Expand All @@ -114,9 +114,9 @@ def test_closeContactGNMAnalysis(universe):
0.0, 0.0, -2.263157894736841, -0.24333213169614382])


def test_closeContactGNMAnalysis_weights_None(universe):
def test_closeContactGNMAnalysis_weights_None(universe, client_GNMAnalysis):
gnm = mda.analysis.gnm.closeContactGNMAnalysis(universe, weights=None)
gnm.run(stop=2)
gnm.run(stop=2, **client_GNMAnalysis)
result = gnm.results
assert len(result.times) == 2
assert_almost_equal(gnm.results.times, (0, 100), decimal=4)
Expand Down

0 comments on commit b3208b3

Please sign in to comment.