Skip to content

Commit

Permalink
feat: extract find_quantum_number_transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Mar 31, 2022
1 parent fa68acb commit bc7a938
Showing 1 changed file with 77 additions and 65 deletions.
142 changes: 77 additions & 65 deletions src/qrules/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,76 +530,20 @@ def create_edge_settings(edge_id: int) -> EdgeSettings:

return graph_settings

def find_solutions( # pylint: disable=too-many-branches
self,
problem_sets: Dict[float, List[ProblemSet]],
def find_solutions(
self, problem_sets: Dict[float, List[ProblemSet]]
) -> "ReactionInfo":
# pylint: disable=too-many-locals
"""Check for solutions for a specific set of interaction settings."""
results: Dict[float, _SolutionContainer] = {}
logging.info(
"Number of interaction settings groups being processed: %d",
len(problem_sets),
)
total = sum(map(len, problem_sets.values()))
progress_bar = tqdm(
total=total,
desc="Propagating quantum numbers",
disable=logging.getLogger().level > logging.WARNING,
)
for strength, problems in sorted(problem_sets.items(), reverse=True):
results = self._find_particle_transitions(problem_sets)
for strength, result in results.items():
logging.info(
"processing interaction settings group with "
f"strength {strength}",
)
logging.info(f"{len(problems)} entries in this group")
logging.info(f"running with {self.__number_of_threads} threads...")

qn_problems = [x.to_qn_problem_set() for x in problems]

# Because of pickling problems of Generic classes (in this case
# MutableTransition), multithreaded code has to work with
# QNProblemSet's and QNResult's. So the appropriate conversions
# have to be done before and after
temp_qn_results: List[Tuple[QNProblemSet, QNResult]] = []
if self.__number_of_threads > 1:
with Pool(self.__number_of_threads) as pool:
for qn_result in pool.imap_unordered(
self._solve, qn_problems, chunksize=1
):
temp_qn_results.append(qn_result)
progress_bar.update()
else:
for problem in qn_problems:
temp_qn_results.append(self._solve(problem))
progress_bar.update()
for temp_qn_result in temp_qn_results:
temp_result = self.__convert_result(
temp_qn_result[0].topology,
temp_qn_result[1],
)
if strength not in results:
results[strength] = temp_result
else:
results[strength].extend(
temp_result, intersect_violations=True
)
if (
results[strength].solutions
and self.reaction_mode == SolvingMode.FAST
):
break
progress_bar.close()

for key, result in results.items():
logging.info(
f"number of solutions for strength ({key}) "
f"after qn solving: {len(result.solutions)}",
f"Number of solutions for strength {strength} after"
f"QN solving: {len(result.solutions)}",
)

final_result = _SolutionContainer()
for temp_result in results.values():
final_result.extend(temp_result)
for particle_result in results.values():
final_result.extend(particle_result)

# remove duplicate solutions, which only differ in the interaction qns
final_solutions = remove_duplicate_solutions(
Expand Down Expand Up @@ -650,14 +594,82 @@ def find_solutions( # pylint: disable=too-many-branches
]
return ReactionInfo(transitions, self.formalism)

def _find_particle_transitions(
self, problem_sets: Dict[float, List[ProblemSet]]
) -> Dict[float, _SolutionContainer]:
qn_results = self.find_quantum_number_transitions(problem_sets)
results: Dict[float, _SolutionContainer] = {}
for strength, qn_solutions in qn_results.items():
for qn_problem_set, qn_result in qn_solutions:
particle_result = self.__convert_to_particle_definitions(
qn_problem_set.topology,
qn_result,
)
if strength not in results:
results[strength] = particle_result
else:
results[strength].extend(
particle_result,
intersect_violations=True,
)
return results

def find_quantum_number_transitions(
self, problem_sets: Dict[float, List[ProblemSet]]
) -> Dict[float, List[Tuple[QNProblemSet, QNResult]]]:
"""Find allowed transitions purely in terms of quantum number sets."""
qn_results: Dict[
float, List[Tuple[QNProblemSet, QNResult]]
] = defaultdict(list)
logging.info(
"Number of interaction settings groups being processed: %d",
len(problem_sets),
)
total = sum(map(len, problem_sets.values()))
progress_bar = tqdm(
total=total,
desc="Propagating quantum numbers",
disable=logging.getLogger().level > logging.WARNING,
)
for strength, problems in sorted(problem_sets.items(), reverse=True):
logging.info(
"processing interaction settings group with "
f"strength {strength}",
)
logging.info(f"{len(problems)} entries in this group")
logging.info(f"running with {self.__number_of_threads} threads...")

qn_problems = [x.to_qn_problem_set() for x in problems]

# Because of pickling problems of Generic classes (in this case
# MutableTransition), multithreaded code has to work with
# QNProblemSet's and QNResult's. So the appropriate conversions
# have to be done before and after
if self.__number_of_threads > 1:
with Pool(self.__number_of_threads) as pool:
for qn_solution in pool.imap_unordered(
self._solve, qn_problems, chunksize=1
):
qn_results[strength].append(qn_solution)
progress_bar.update()
else:
for problem in qn_problems:
qn_solution = self._solve(problem)
qn_results[strength].append(qn_solution)
progress_bar.update()
if qn_results[strength] and self.reaction_mode == SolvingMode.FAST:
break
progress_bar.close()
return qn_results

def _solve(
self, qn_problem_set: QNProblemSet
) -> Tuple[QNProblemSet, QNResult]:
solver = CSPSolver(self.__allowed_intermediate_particles)
solutions = solver.find_solutions(qn_problem_set)
return qn_problem_set, solutions

def __convert_result(
def __convert_to_particle_definitions(
self, topology: Topology, qn_result: QNResult
) -> _SolutionContainer:
"""Converts a `.QNResult` with a `.Topology` into `.ReactionInfo`.
Expand Down

0 comments on commit bc7a938

Please sign in to comment.