diff --git a/src/qrules/transition.py b/src/qrules/transition.py index cfd197b2..cc4e36e1 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -530,76 +530,20 @@ def create_edge_settings(edge_id: int) -> EdgeSettings: return graph_settings - def find_solutions( # pylint: disable=too-many-branches - self, - problem_sets: Dict[float, List[ProblemSet]], + def find_solutions( + self, problem_sets: Dict[float, List[ProblemSet]] ) -> "ReactionInfo": - # pylint: disable=too-many-locals """Check for solutions for a specific set of interaction settings.""" - results: Dict[float, _SolutionContainer] = {} - logging.info( - "Number of interaction settings groups being processed: %d", - len(problem_sets), - ) - total = sum(map(len, problem_sets.values())) - progress_bar = tqdm( - total=total, - desc="Propagating quantum numbers", - disable=logging.getLogger().level > logging.WARNING, - ) - for strength, problems in sorted(problem_sets.items(), reverse=True): + results = self._find_particle_transitions(problem_sets) + for strength, result in results.items(): logging.info( - "processing interaction settings group with " - f"strength {strength}", - ) - logging.info(f"{len(problems)} entries in this group") - logging.info(f"running with {self.__number_of_threads} threads...") - - qn_problems = [x.to_qn_problem_set() for x in problems] - - # Because of pickling problems of Generic classes (in this case - # MutableTransition), multithreaded code has to work with - # QNProblemSet's and QNResult's. So the appropriate conversions - # have to be done before and after - temp_qn_results: List[Tuple[QNProblemSet, QNResult]] = [] - if self.__number_of_threads > 1: - with Pool(self.__number_of_threads) as pool: - for qn_result in pool.imap_unordered( - self._solve, qn_problems, chunksize=1 - ): - temp_qn_results.append(qn_result) - progress_bar.update() - else: - for problem in qn_problems: - temp_qn_results.append(self._solve(problem)) - progress_bar.update() - for temp_qn_result in temp_qn_results: - temp_result = self.__convert_result( - temp_qn_result[0].topology, - temp_qn_result[1], - ) - if strength not in results: - results[strength] = temp_result - else: - results[strength].extend( - temp_result, intersect_violations=True - ) - if ( - results[strength].solutions - and self.reaction_mode == SolvingMode.FAST - ): - break - progress_bar.close() - - for key, result in results.items(): - logging.info( - f"number of solutions for strength ({key}) " - f"after qn solving: {len(result.solutions)}", + f"Number of solutions for strength {strength} after" + f"QN solving: {len(result.solutions)}", ) final_result = _SolutionContainer() - for temp_result in results.values(): - final_result.extend(temp_result) + for particle_result in results.values(): + final_result.extend(particle_result) # remove duplicate solutions, which only differ in the interaction qns final_solutions = remove_duplicate_solutions( @@ -650,6 +594,74 @@ def find_solutions( # pylint: disable=too-many-branches ] return ReactionInfo(transitions, self.formalism) + def _find_particle_transitions( + self, problem_sets: Dict[float, List[ProblemSet]] + ) -> Dict[float, _SolutionContainer]: + qn_results = self.find_quantum_number_transitions(problem_sets) + results: Dict[float, _SolutionContainer] = {} + for strength, qn_solutions in qn_results.items(): + for qn_problem_set, qn_result in qn_solutions: + particle_result = self.__convert_to_particle_definitions( + qn_problem_set.topology, + qn_result, + ) + if strength not in results: + results[strength] = particle_result + else: + results[strength].extend( + particle_result, + intersect_violations=True, + ) + return results + + def find_quantum_number_transitions( + self, problem_sets: Dict[float, List[ProblemSet]] + ) -> Dict[float, List[Tuple[QNProblemSet, QNResult]]]: + """Find allowed transitions purely in terms of quantum number sets.""" + qn_results: Dict[ + float, List[Tuple[QNProblemSet, QNResult]] + ] = defaultdict(list) + logging.info( + "Number of interaction settings groups being processed: %d", + len(problem_sets), + ) + total = sum(map(len, problem_sets.values())) + progress_bar = tqdm( + total=total, + desc="Propagating quantum numbers", + disable=logging.getLogger().level > logging.WARNING, + ) + for strength, problems in sorted(problem_sets.items(), reverse=True): + logging.info( + "processing interaction settings group with " + f"strength {strength}", + ) + logging.info(f"{len(problems)} entries in this group") + logging.info(f"running with {self.__number_of_threads} threads...") + + qn_problems = [x.to_qn_problem_set() for x in problems] + + # Because of pickling problems of Generic classes (in this case + # MutableTransition), multithreaded code has to work with + # QNProblemSet's and QNResult's. So the appropriate conversions + # have to be done before and after + if self.__number_of_threads > 1: + with Pool(self.__number_of_threads) as pool: + for qn_solution in pool.imap_unordered( + self._solve, qn_problems, chunksize=1 + ): + qn_results[strength].append(qn_solution) + progress_bar.update() + else: + for problem in qn_problems: + qn_solution = self._solve(problem) + qn_results[strength].append(qn_solution) + progress_bar.update() + if qn_results[strength] and self.reaction_mode == SolvingMode.FAST: + break + progress_bar.close() + return qn_results + def _solve( self, qn_problem_set: QNProblemSet ) -> Tuple[QNProblemSet, QNResult]: @@ -657,7 +669,7 @@ def _solve( solutions = solver.find_solutions(qn_problem_set) return qn_problem_set, solutions - def __convert_result( + def __convert_to_particle_definitions( self, topology: Topology, qn_result: QNResult ) -> _SolutionContainer: """Converts a `.QNResult` with a `.Topology` into `.ReactionInfo`.