diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 0110536a..3bae8ae0 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -8,13 +8,26 @@ import string from collections import abc from functools import singledispatch +from inspect import isfunction from numbers import Number -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, cast +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) import attrs from attrs import Attribute, define, field from attrs.converters import default_if_none +from qrules.argument_handling import Rule from qrules.particle import Particle, ParticleWithSpin, Spin from qrules.quantum_numbers import InteractionProperties, _to_fraction from qrules.solving import EdgeSettings, NodeSettings, QNProblemSet, QNResult @@ -356,31 +369,45 @@ def _(obj: InteractionProperties) -> str: def _(settings: Union[EdgeSettings, NodeSettings]) -> str: output = "" if settings.rule_priorities: - output += "RULE PRIORITIES\n" - rule_names = ( - f"{item[0].__name__} - {item[1]}" # type: ignore[union-attr] - for item in settings.rule_priorities.items() + output += "RULES\n" + rule_descriptions = ( + f"{__render_rule(rule)} - {__get_priority(rule, settings.rule_priorities)}" + for rule in settings.conservation_rules ) - sorted_names = sorted(rule_names, key=__extract_priority, reverse=True) + sorted_names = sorted(rule_descriptions, key=__extract_priority, reverse=True) output += "\n".join(sorted_names) if settings.qn_domains: if output: output += "\n" domains = sorted( - f"{item[0].__name__} ∊ {item[1]}" for item in settings.qn_domains.items() + f"{qn.__name__} ∊ {domain}" for qn, domain in settings.qn_domains.items() ) output += "DOMAINS\n" output += "\n".join(domains) return output -def __extract_priority(description: str) -> int: - matches = re.match(r".* \- ([0-9]+)$", description) +def __get_priority(rule: Any, rule_priorities: Dict[Any, int]) -> Union[int, str]: + rule_type = __get_type(rule) + return rule_priorities.get(rule_type, "NA") + + +def __render_rule(rule: Rule) -> str: + return __get_type(rule).__name__ + + +def __get_type(rule: Rule) -> Type[Rule]: + if isfunction(rule): + return rule # type: ignore[return-value] + return type(rule) + + +def __extract_priority(description: str) -> str: + matches = re.match(r".* \- ([0-9]+|NA)$", description) if matches is None: msg = f"{description} does not contain a priority number" raise ValueError(msg) - priority = matches[1] - return int(priority) + return matches[1] @as_string.register(Particle)