Skip to content

Commit

Permalink
feat: implement global settings.NumberOfThreads (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer authored Jan 22, 2022
1 parent 115de13 commit 11c420b
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 23 deletions.
5 changes: 3 additions & 2 deletions src/qrules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,9 @@ def generate_transitions( # pylint: disable=too-many-arguments
- :code:`"nbody"`: Use one central node and connect initial and final
states to it
number_of_threads (int): Number of cores with which to compute the
allowed transitions. Defaults to all cores on the system.
number_of_threads: Number of cores with which to compute the allowed
transitions. Defaults to the current value returned by
:meth:`.settings.NumberOfThreads.get`.
An example (where, for illustrative purposes only, we specify all
arguments) would be:
Expand Down
21 changes: 21 additions & 0 deletions src/qrules/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
>>> qrules.settings.MAX_SPIN_MAGNITUDE = 3
"""

import multiprocessing
from copy import deepcopy
from enum import Enum, auto
from os.path import dirname, join, realpath
Expand Down Expand Up @@ -285,6 +286,26 @@ def _create_domains(particle_db: ParticleCollection) -> Dict[Any, list]:
return domains


class NumberOfThreads:
__n_cores: Optional[int] = None

@classmethod
def get(cls) -> int:
if cls.__n_cores is None:
return multiprocessing.cpu_count()
return cls.__n_cores

@classmethod
def set(cls, n_cores: Optional[int]) -> None: # noqa: A003
"""Set the number of threads; use `None` for all available cores."""
if n_cores is not None and not isinstance(n_cores, int):
raise TypeError(
"Can only set the number of cores to an integer or to None"
" (meaning all available cores)"
)
cls.__n_cores = n_cores


def __positive_halves_domain(
particle_db: ParticleCollection, attr_getter: Callable[[Particle], Any]
) -> List[float]:
Expand Down
22 changes: 12 additions & 10 deletions src/qrules/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"""Find allowed transitions between an initial and final state."""

import logging
import multiprocessing
from collections import abc, defaultdict
from copy import copy, deepcopy
from enum import Enum, auto
Expand Down Expand Up @@ -64,7 +63,11 @@
NodeQuantumNumber,
NodeQuantumNumbers,
)
from .settings import InteractionType, create_interaction_settings
from .settings import (
InteractionType,
NumberOfThreads,
create_interaction_settings,
)
from .solving import (
CSPSolver,
EdgeSettings,
Expand Down Expand Up @@ -250,13 +253,16 @@ def __init__( # pylint: disable=too-many-arguments, too-many-branches, too-many
] = None,
formalism: str = "helicity",
topology_building: str = "isobar",
number_of_threads: Optional[int] = None,
solving_mode: SolvingMode = SolvingMode.FAST,
reload_pdg: bool = False,
mass_conservation_factor: Optional[float] = 3.0,
max_angular_momentum: int = 1,
max_spin_magnitude: float = 2.0,
number_of_threads: Optional[int] = None,
) -> None:
if number_of_threads is not None:
NumberOfThreads.set(number_of_threads)
self.__number_of_threads = NumberOfThreads.get()
if interaction_type_settings is None:
interaction_type_settings = {}
allowed_formalisms = [
Expand All @@ -273,10 +279,6 @@ def __init__( # pylint: disable=too-many-arguments, too-many-branches, too-many
self.__particles = ParticleCollection()
if particle_db is not None:
self.__particles = particle_db
if number_of_threads is None:
self.number_of_threads = multiprocessing.cpu_count()
else:
self.number_of_threads = int(number_of_threads)
self.reaction_mode = str(solving_mode)
self.initial_state = initial_state
self.final_state = final_state
Expand Down Expand Up @@ -556,7 +558,7 @@ def find_solutions( # pylint: disable=too-many-branches
f"strength {strength}",
)
logging.info(f"{len(problems)} entries in this group")
logging.info(f"running with {self.number_of_threads} threads...")
logging.info(f"running with {self.__number_of_threads} threads...")

qn_problems = [x.to_qn_problem_set() for x in problems]

Expand All @@ -565,8 +567,8 @@ def find_solutions( # pylint: disable=too-many-branches
# QNProblemSet's and QNResult's. So the appropriate conversions
# have to be done before and after
temp_qn_results: List[Tuple[QNProblemSet, QNResult]] = []
if self.number_of_threads > 1:
with Pool(self.number_of_threads) as pool:
if self.__number_of_threads > 1:
with Pool(self.__number_of_threads) as pool:
for qn_result in pool.imap_unordered(
self._solve, qn_problems, 1
):
Expand Down
1 change: 0 additions & 1 deletion tests/channels/test_d0_to_ks_kp_km.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ def test_script():
"a(2)(1320)-",
"phi(1020)",
],
number_of_threads=1,
)
assert len(reaction.transition_groups) == 3
assert len(reaction.transition_groups[0]) == 2
Expand Down
2 changes: 0 additions & 2 deletions tests/channels/test_jpsi_to_gamma_pi0_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def test_number_of_solutions(
particle_db=particle_database,
allowed_interaction_types=["strong", "EM"],
allowed_intermediate_particles=allowed_intermediate_particles,
number_of_threads=1,
formalism="helicity",
)
assert len(reaction.transition_groups) == n_topologies
Expand All @@ -57,7 +56,6 @@ def test_id_to_particle_mappings(particle_database):
particle_db=particle_database,
allowed_interaction_types="strong",
allowed_intermediate_particles=["f(0)(980)"],
number_of_threads=1,
formalism="helicity",
)
assert len(reaction.transition_groups) == 1
Expand Down
2 changes: 0 additions & 2 deletions tests/channels/test_y_to_d0_d0bar_pi0_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def test_simple(formalism, n_solutions, particle_database):
particle_db=particle_database,
formalism=formalism,
allowed_interaction_types="strong",
number_of_threads=1,
)
assert len(reaction.transition_groups) == 1
assert len(reaction.transitions) == n_solutions
Expand All @@ -39,7 +38,6 @@ def test_full(formalism, n_solutions, particle_database):
particle_db=particle_database,
allowed_intermediate_particles=["D*"],
formalism=formalism,
number_of_threads=1,
)
stm.set_allowed_interaction_types([InteractionType.STRONG])
stm.add_final_state_grouping([["D0", "pi0"], ["D~0", "pi0"]])
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

from qrules import load_default_particles
from qrules.particle import ParticleCollection
from qrules.settings import NumberOfThreads

# Ensure consistent test coverage when running pytest multithreaded
# https://github.com/ComPWA/qrules/issues/11
NumberOfThreads.set(1)


@pytest.fixture(scope="session")
Expand Down
1 change: 0 additions & 1 deletion tests/unit/test_parity_prefactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def test_parity_prefactor(
test_input.initial_state,
test_input.final_state,
allowed_intermediate_particles=test_input.intermediate_states,
number_of_threads=1,
)
stm.add_final_state_grouping(test_input.final_state_grouping)
stm.set_allowed_interaction_types([InteractionType.EM])
Expand Down
4 changes: 0 additions & 4 deletions tests/unit/test_system_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def test_external_edge_initialization(
final_state,
particle_database,
formalism="helicity",
number_of_threads=1,
)

stm.set_allowed_interaction_types([InteractionType.STRONG])
Expand Down Expand Up @@ -366,7 +365,6 @@ def test_edge_swap(particle_database, initial_state, final_state):
final_state,
particle_database,
formalism="helicity",
number_of_threads=1,
)
stm.set_allowed_interaction_types([InteractionType.STRONG])

Expand Down Expand Up @@ -412,7 +410,6 @@ def test_match_external_edges(particle_database, initial_state, final_state):
final_state,
particle_database,
formalism="helicity",
number_of_threads=1,
)

stm.set_allowed_interaction_types([InteractionType.STRONG])
Expand Down Expand Up @@ -494,7 +491,6 @@ def test_external_edge_identical_particle_combinatorics(
final_state,
particle_database,
formalism="helicity",
number_of_threads=1,
)
stm.set_allowed_interaction_types([InteractionType.STRONG])
for group in final_state_groupings:
Expand Down
1 change: 0 additions & 1 deletion tests/unit/test_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def test_allowed_intermediate_particles(self):
stm = StateTransitionManager(
initial_state=[("J/psi(1S)", [-1, +1])],
final_state=["p", "p~", "eta"],
number_of_threads=1,
)
particle_name = "N(753)"
with pytest.raises(
Expand Down

0 comments on commit 11c420b

Please sign in to comment.