diff --git a/.github/workflows/requirements-cron.yml b/.github/workflows/requirements-cron.yml index 095551a7..367d8dae 100644 --- a/.github/workflows/requirements-cron.yml +++ b/.github/workflows/requirements-cron.yml @@ -2,7 +2,7 @@ name: Requirements (scheduled) on: schedule: - - cron: "0 2 */14 * *" + - cron: "0 2 * * 1" workflow_dispatch: jobs: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b83a06c7..c45fccd8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -183,5 +183,6 @@ repos: - --rcfile=.pylintrc - --score=no language: system + require_serial: true types: - python diff --git a/.pylintrc b/.pylintrc index 270e263a..c32808b8 100644 --- a/.pylintrc +++ b/.pylintrc @@ -15,13 +15,20 @@ ignore-patterns= [MESSAGES CONTROL] disable= duplicate-code, # https://github.com/PyCQA/pylint/issues/214 + invalid-unary-operand-type, # conflicts with attrs.field logging-fstring-interpolation, missing-class-docstring, # pydocstyle missing-function-docstring, # pydocstyle missing-module-docstring, # pydocstyle + no-member, # conflicts with attrs.field + not-an-iterable, # conflicts with attrs.field + not-callable, # conflicts with attrs.field redefined-builtin, # flake8-built too-few-public-methods, # data containers (attrs) and interface classes unspecified-encoding, # http://pylint.pycqa.org/en/latest/whatsnew/2.10.html + unsubscriptable-object, # conflicts with attrs.field + unsupported-assignment-operation, # conflicts with attrs.field + unsupported-membership-test, # conflicts with attrs.field unused-import, # https://www.flake8rules.com/rules/F401 [SIMILARITIES] diff --git a/docs/_relink_references.py b/docs/_relink_references.py index f897ad2d..65cb1ac0 100644 --- a/docs/_relink_references.py +++ b/docs/_relink_references.py @@ -1,4 +1,3 @@ -# cspell:ignore docutils # pylint: disable=import-error, import-outside-toplevel # pyright: reportMissingImports=false """Abbreviated the annotations generated by sphinx-autodoc. @@ -16,28 +15,22 @@ from sphinx.environment import BuildEnvironment -def _replace_link(text: str) -> str: - replacements = { - "a set-like object providing a view on D's items": "typing.ItemsView", - "a set-like object providing a view on D's keys": "typing.KeysView", - "an object providing a view on D's values": "typing.ValuesView", - "typing_extensions.Protocol": "typing.Protocol", - } - for old, new in replacements.items(): - if text == old: - return new - return text +__TARGET_SUBSTITUTIONS = { + "a set-like object providing a view on D's items": "typing.ItemsView", + "a set-like object providing a view on D's keys": "typing.KeysView", + "an object providing a view on D's values": "typing.ValuesView", + "typing_extensions.Protocol": "typing.Protocol", +} +__REF_TYPE_SUBSTITUTIONS = { + "None": "obj", +} def _new_type_to_xref( - text: str, env: "BuildEnvironment" = None + target: str, + env: "BuildEnvironment" = None, + suppress_prefix: bool = False, ) -> "pending_xref": - """Convert a type string to a cross reference node.""" - if text == "None": - reftype = "obj" - else: - reftype = "class" - if env: kwargs = { "py:module": env.ref_context.get("py:module"), @@ -46,8 +39,12 @@ def _new_type_to_xref( else: kwargs = {} - text = _replace_link(text) - short_text = text.split(".")[-1] + target = __TARGET_SUBSTITUTIONS.get(target, target) + reftype = __REF_TYPE_SUBSTITUTIONS.get(target, "class") + if suppress_prefix: + short_text = target.split(".")[-1] + else: + short_text = target from docutils.nodes import Text from sphinx.addnodes import pending_xref @@ -57,7 +54,7 @@ def _new_type_to_xref( Text(short_text), refdomain="py", reftype=reftype, - reftarget=text, + reftarget=target, **kwargs, ) diff --git a/docs/conf.py b/docs/conf.py index 0fad07d9..16986856 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -223,6 +223,7 @@ def fetch_logo(url: str, output_path: str) -> None: ("py:class", "NoneType"), ("py:class", "StateTransitionGraph"), ("py:class", "ValueType"), + ("py:class", "json.encoder.JSONEncoder"), ("py:class", "typing_extensions.Protocol"), ("py:obj", "qrules.topology._K"), ("py:obj", "qrules.topology._V"), diff --git a/docs/usage/conservation.ipynb b/docs/usage/conservation.ipynb index 97c1d083..c61d84e0 100644 --- a/docs/usage/conservation.ipynb +++ b/docs/usage/conservation.ipynb @@ -77,7 +77,7 @@ }, "outputs": [], "source": [ - "import attr\n", + "import attrs\n", "import graphviz\n", "\n", "import qrules\n", @@ -377,7 +377,7 @@ "metadata": {}, "outputs": [], "source": [ - "new_interaction = attr.evolve(transition.interactions[node_id], l_magnitude=2)\n", + "new_interaction = attrs.evolve(transition.interactions[node_id], l_magnitude=2)\n", "new_interaction" ] }, @@ -396,7 +396,7 @@ "source": [ "new_interaction_dict = dict(transition.interactions) # make mutable\n", "new_interaction_dict[node_id] = new_interaction\n", - "new_transition = attr.evolve(transition, interactions=new_interaction_dict)" + "new_transition = attrs.evolve(transition, interactions=new_interaction_dict)" ] }, { diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index da75b08c..2127df3d 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -29,7 +29,7 @@ Union, ) -import attr +import attrs from . import io from .combinatorics import InitialFacts, StateDefinition, create_initial_facts @@ -237,7 +237,7 @@ def check_edge_qn_conservation() -> Set[FrozenSet[str]]: initial_facts_list = [] for ls_combi in ls_combinations: for facts_combination in initial_facts: - new_facts = attr.evolve( + new_facts = attrs.evolve( facts_combination, node_props={node_id: ls_combi}, ) diff --git a/src/qrules/_implementers.py b/src/qrules/_implementers.py index 4967e869..9224f7fd 100644 --- a/src/qrules/_implementers.py +++ b/src/qrules/_implementers.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar -import attr +import attrs if TYPE_CHECKING: try: @@ -17,12 +17,12 @@ def implement_pretty_repr() -> Callable[ [Type[_DecoratedClass]], Type[_DecoratedClass] ]: - """Implement a pretty :code:`repr` in a `attr` decorated class.""" + """Implement a pretty :code:`repr` in a class decorated by `attrs`.""" def decorator( decorated_class: Type[_DecoratedClass], ) -> Type[_DecoratedClass]: - if not attr.has(decorated_class): + if not attrs.has(decorated_class): raise TypeError( "Can only implement a pretty repr for a class created with" " attrs" @@ -34,7 +34,7 @@ def repr_pretty(self: Any, p: "PrettyPrinter", cycle: bool) -> None: p.text(f"{class_name}(...)") else: with p.group(indent=2, open=f"{class_name}("): - for field in attr.fields(type(self)): + for field in attrs.fields(type(self)): if not field.init: continue value = getattr(self, field.name) diff --git a/src/qrules/_system_control.py b/src/qrules/_system_control.py index c909df92..3b9a2c3e 100644 --- a/src/qrules/_system_control.py +++ b/src/qrules/_system_control.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Type -import attr +import attrs from .particle import Particle, ParticleCollection, ParticleWithSpin from .quantum_numbers import ( @@ -34,10 +34,10 @@ def create_edge_properties( qn_name: qn_type for qn_name, qn_type in EdgeQuantumNumbers.__dict__.items() if not qn_name.startswith("__") - } # Note using attr.fields does not work here because init=False + } # Note using attrs.fields does not work here because init=False property_map: GraphEdgePropertyMap = {} isospin = None - for qn_name, value in attr.asdict(particle, recurse=False).items(): + for qn_name, value in attrs.asdict(particle, recurse=False).items(): if isinstance(value, Parity): value = value.value if qn_name in edge_qn_mapping: @@ -65,9 +65,9 @@ def create_node_properties( qn_name: qn_type for qn_name, qn_type in NodeQuantumNumbers.__dict__.items() if not qn_name.startswith("__") - } # Note using attr.fields does not work here because init=False + } # Note using attrs.fields does not work here because init=False property_map: GraphNodePropertyMap = {} - for qn_name, value in attr.asdict(node_props).items(): + for qn_name, value in attrs.asdict(node_props).items(): if value is None: continue if qn_name in node_qn_mapping: @@ -117,11 +117,11 @@ def create_interaction_properties( converted_solution = {k.__name__: v for k, v in qn_solution.items()} kw_args = { x.name: converted_solution[x.name] - for x in attr.fields(InteractionProperties) + for x in attrs.fields(InteractionProperties) if x.name in converted_solution } - return attr.evolve(InteractionProperties(), **kw_args) + return attrs.evolve(InteractionProperties(), **kw_args) def filter_interaction_types( @@ -234,7 +234,7 @@ def _remove_qns_from_graph( # pylint: disable=too-many-branches new_node_props = {} for node_id in graph.topology.nodes: node_props = graph.get_node_props(node_id) - new_node_props[node_id] = attr.evolve( + new_node_props[node_id] = attrs.evolve( node_props, **{x.__name__: None for x in qn_list} ) @@ -279,10 +279,10 @@ def __call__( node_props1: InteractionProperties, node_props2: InteractionProperties, ) -> bool: - return attr.evolve( + return attrs.evolve( node_props1, **{x.__name__: None for x in self.__ignored_qn_list}, - ) == attr.evolve( + ) == attrs.evolve( node_props2, **{x.__name__: None for x in self.__ignored_qn_list}, ) diff --git a/src/qrules/argument_handling.py b/src/qrules/argument_handling.py index 593cdaa0..8c200f19 100644 --- a/src/qrules/argument_handling.py +++ b/src/qrules/argument_handling.py @@ -21,7 +21,7 @@ Union, ) -import attr +import attrs from .conservation_rules import ( ConservationRule, @@ -154,7 +154,7 @@ def __init__(self, class_type: type) -> None: ) if _is_edge_quantum_number(class_field.type) else _ValueExtractor[NodeQuantumNumber](class_field.type) - for class_field in attr.fields(class_type) + for class_field in attrs.fields(class_type) } def __call__( @@ -206,10 +206,10 @@ def __create_requirements_check( qn_type = input_type.__args__[0] # type: ignore[attr-defined] is_list = True - if attr.has(qn_type): + if attrs.has(qn_type): class_field_types = [ class_field.type - for class_field in attr.fields(qn_type) + for class_field in attrs.fields(qn_type) if not _is_optional(class_field.type) ] qn_check_function: Callable[ @@ -239,7 +239,7 @@ def __create_argument_builder( qn_type = input_type.__args__[0] # type: ignore[attr-defined] is_list = True - if attr.has(qn_type): + if attrs.has(qn_type): arg_builder: Callable[..., Any] = _CompositeArgumentCreator( qn_type ) @@ -322,8 +322,8 @@ def get_required_qns( if _is_sequence_type(input_type): class_type = input_type.__args__[0] - if attr.has(class_type): - for class_field in attr.fields(class_type): + if attrs.has(class_type): + for class_field in attrs.fields(class_type): field_type = ( class_field.type.__args__[0] # type: ignore[union-attr] if _is_optional(class_field.type) diff --git a/src/qrules/combinatorics.py b/src/qrules/combinatorics.py index 35263559..09ee0f1a 100644 --- a/src/qrules/combinatorics.py +++ b/src/qrules/combinatorics.py @@ -24,7 +24,7 @@ Union, ) -import attr +from attrs import field, frozen from qrules._implementers import implement_pretty_repr from qrules.particle import Particle, ParticleCollection @@ -38,10 +38,10 @@ @implement_pretty_repr() -@attr.frozen +@frozen class InitialFacts: - edge_props: Dict[int, ParticleWithSpin] = attr.ib(factory=dict) - node_props: Dict[int, InteractionProperties] = attr.ib(factory=dict) + edge_props: Dict[int, ParticleWithSpin] = field(factory=dict) + node_props: Dict[int, InteractionProperties] = field(factory=dict) class _KinematicRepresentation: diff --git a/src/qrules/conservation_rules.py b/src/qrules/conservation_rules.py index f24d86e7..e3b936b9 100644 --- a/src/qrules/conservation_rules.py +++ b/src/qrules/conservation_rules.py @@ -53,8 +53,8 @@ from functools import reduce from typing import Any, Callable, List, Optional, Set, Tuple, Type, Union -import attr -from attr.converters import optional +from attrs import define, field, frozen +from attrs.converters import optional from .quantum_numbers import EdgeQuantumNumbers as EdgeQN from .quantum_numbers import NodeQuantumNumbers as NodeQN @@ -189,13 +189,13 @@ def parity_conservation( return True -@attr.frozen +@frozen class HelicityParityEdgeInput: - parity: EdgeQN.parity = attr.ib(converter=EdgeQN.parity) - spin_magnitude: EdgeQN.spin_magnitude = attr.ib( + parity: EdgeQN.parity = field(converter=EdgeQN.parity) + spin_magnitude: EdgeQN.spin_magnitude = field( converter=EdgeQN.spin_magnitude ) - spin_projection: EdgeQN.spin_projection = attr.ib( + spin_projection: EdgeQN.spin_projection = field( converter=EdgeQN.spin_projection ) @@ -239,21 +239,21 @@ def parity_conservation_helicity( return True -@attr.frozen +@frozen class CParityEdgeInput: - spin_magnitude: EdgeQN.spin_magnitude = attr.ib( + spin_magnitude: EdgeQN.spin_magnitude = field( converter=EdgeQN.spin_magnitude ) - pid: EdgeQN.pid = attr.ib(converter=EdgeQN.pid) - c_parity: Optional[EdgeQN.c_parity] = attr.ib( + pid: EdgeQN.pid = field(converter=EdgeQN.pid) + c_parity: Optional[EdgeQN.c_parity] = field( converter=EdgeQN.c_parity, default=None ) -@attr.frozen +@frozen class CParityNodeInput: - l_magnitude: NodeQN.l_magnitude = attr.ib(converter=NodeQN.l_magnitude) - s_magnitude: NodeQN.s_magnitude = attr.ib(converter=NodeQN.s_magnitude) + l_magnitude: NodeQN.l_magnitude = field(converter=NodeQN.l_magnitude) + s_magnitude: NodeQN.s_magnitude = field(converter=NodeQN.s_magnitude) def c_parity_conservation( @@ -303,24 +303,24 @@ def _get_c_parity_multiparticle( return c_parity_in == c_parity_out -@attr.frozen +@frozen class GParityEdgeInput: - isospin_magnitude: EdgeQN.isospin_magnitude = attr.ib( + isospin_magnitude: EdgeQN.isospin_magnitude = field( converter=EdgeQN.isospin_magnitude ) - spin_magnitude: EdgeQN.spin_magnitude = attr.ib( + spin_magnitude: EdgeQN.spin_magnitude = field( converter=EdgeQN.spin_magnitude ) - pid: EdgeQN.pid = attr.ib(converter=EdgeQN.pid) - g_parity: Optional[EdgeQN.g_parity] = attr.ib( + pid: EdgeQN.pid = field(converter=EdgeQN.pid) + g_parity: Optional[EdgeQN.g_parity] = field( converter=EdgeQN.g_parity, default=None ) -@attr.frozen +@frozen class GParityNodeInput: - l_magnitude: NodeQN.l_magnitude = attr.ib(converter=NodeQN.l_magnitude) - s_magnitude: NodeQN.s_magnitude = attr.ib(converter=NodeQN.s_magnitude) + l_magnitude: NodeQN.l_magnitude = field(converter=NodeQN.l_magnitude) + s_magnitude: NodeQN.s_magnitude = field(converter=NodeQN.s_magnitude) def g_parity_conservation( @@ -404,15 +404,15 @@ def check_g_parity_isobar( return True -@attr.frozen +@frozen class IdenticalParticleSymmetryOutEdgeInput: - spin_magnitude: EdgeQN.spin_magnitude = attr.ib( + spin_magnitude: EdgeQN.spin_magnitude = field( converter=EdgeQN.spin_magnitude ) - spin_projection: EdgeQN.spin_projection = attr.ib( + spin_projection: EdgeQN.spin_projection = field( converter=EdgeQN.spin_projection ) - pid: EdgeQN.pid = attr.ib(converter=EdgeQN.pid) + pid: EdgeQN.pid = field(converter=EdgeQN.pid) def identical_particle_symmetrization( @@ -466,10 +466,10 @@ def _check_particles_identical( return True -@attr.frozen +@frozen class _Spin: - magnitude: float = attr.ib() - projection: float = attr.ib() + magnitude: float + projection: float def _is_clebsch_gordan_coefficient_zero( @@ -493,18 +493,18 @@ def _is_clebsch_gordan_coefficient_zero( return False -@attr.frozen +@frozen class SpinNodeInput: - l_magnitude: NodeQN.l_magnitude = attr.ib(converter=NodeQN.l_magnitude) - l_projection: NodeQN.l_projection = attr.ib(converter=NodeQN.l_projection) - s_magnitude: NodeQN.s_magnitude = attr.ib(converter=NodeQN.s_magnitude) - s_projection: NodeQN.s_projection = attr.ib(converter=NodeQN.s_projection) + l_magnitude: NodeQN.l_magnitude = field(converter=NodeQN.l_magnitude) + l_projection: NodeQN.l_projection = field(converter=NodeQN.l_projection) + s_magnitude: NodeQN.s_magnitude = field(converter=NodeQN.s_magnitude) + s_projection: NodeQN.s_projection = field(converter=NodeQN.s_projection) -@attr.frozen +@frozen class SpinMagnitudeNodeInput: - l_magnitude: NodeQN.l_magnitude = attr.ib(converter=NodeQN.l_magnitude) - s_magnitude: NodeQN.s_magnitude = attr.ib(converter=NodeQN.s_magnitude) + l_magnitude: NodeQN.l_magnitude = field(converter=NodeQN.l_magnitude) + s_magnitude: NodeQN.s_magnitude = field(converter=NodeQN.s_magnitude) def ls_spin_validity(spin_input: SpinNodeInput) -> bool: @@ -639,12 +639,12 @@ def __spin_couplings(spin1: _Spin, spin2: _Spin) -> Set[_Spin]: } -@attr.define +@define class IsoSpinEdgeInput: - isospin_magnitude: EdgeQN.isospin_magnitude = attr.ib( + isospin_magnitude: EdgeQN.isospin_magnitude = field( converter=EdgeQN.isospin_magnitude ) - isospin_projection: EdgeQN.isospin_projection = attr.ib( + isospin_projection: EdgeQN.isospin_projection = field( converter=EdgeQN.isospin_projection ) @@ -699,12 +699,12 @@ def isospin_conservation( ) -@attr.define +@define class SpinEdgeInput: - spin_magnitude: EdgeQN.spin_magnitude = attr.ib( + spin_magnitude: EdgeQN.spin_magnitude = field( converter=EdgeQN.spin_magnitude ) - spin_projection: EdgeQN.spin_projection = attr.ib( + spin_projection: EdgeQN.spin_projection = field( converter=EdgeQN.spin_projection ) @@ -873,35 +873,35 @@ def helicity_conservation( return False -@attr.frozen +@frozen class GellMannNishijimaInput: # pylint: disable=too-many-instance-attributes - charge: EdgeQN.charge = attr.ib(converter=EdgeQN.charge) - isospin_projection: Optional[EdgeQN.isospin_projection] = attr.ib( + charge: EdgeQN.charge = field(converter=EdgeQN.charge) + isospin_projection: Optional[EdgeQN.isospin_projection] = field( converter=optional(EdgeQN.isospin_projection), default=None ) - strangeness: Optional[EdgeQN.strangeness] = attr.ib( + strangeness: Optional[EdgeQN.strangeness] = field( converter=optional(EdgeQN.strangeness), default=None ) - charmness: Optional[EdgeQN.charmness] = attr.ib( + charmness: Optional[EdgeQN.charmness] = field( converter=optional(EdgeQN.charmness), default=None ) - bottomness: Optional[EdgeQN.bottomness] = attr.ib( + bottomness: Optional[EdgeQN.bottomness] = field( converter=optional(EdgeQN.bottomness), default=None ) - topness: Optional[EdgeQN.topness] = attr.ib( + topness: Optional[EdgeQN.topness] = field( converter=optional(EdgeQN.topness), default=None ) - baryon_number: Optional[EdgeQN.baryon_number] = attr.ib( + baryon_number: Optional[EdgeQN.baryon_number] = field( converter=optional(EdgeQN.baryon_number), default=None ) - electron_lepton_number: Optional[EdgeQN.electron_lepton_number] = attr.ib( + electron_lepton_number: Optional[EdgeQN.electron_lepton_number] = field( converter=optional(EdgeQN.electron_lepton_number), default=None ) - muon_lepton_number: Optional[EdgeQN.muon_lepton_number] = attr.ib( + muon_lepton_number: Optional[EdgeQN.muon_lepton_number] = field( converter=optional(EdgeQN.muon_lepton_number), default=None ) - tau_lepton_number: Optional[EdgeQN.tau_lepton_number] = attr.ib( + tau_lepton_number: Optional[EdgeQN.tau_lepton_number] = field( converter=optional(EdgeQN.tau_lepton_number), default=None ) @@ -956,12 +956,10 @@ def calculate_hypercharge( return True -@attr.frozen +@frozen class MassEdgeInput: - mass: EdgeQN.mass = attr.ib(converter=EdgeQN.mass) - width: Optional[EdgeQN.width] = attr.ib( - converter=EdgeQN.width, default=None - ) + mass: EdgeQN.mass = field(converter=EdgeQN.mass) + width: Optional[EdgeQN.width] = field(converter=EdgeQN.width, default=None) class MassConservation: diff --git a/src/qrules/io/__init__.py b/src/qrules/io/__init__.py index edeade86..9d3a794d 100644 --- a/src/qrules/io/__init__.py +++ b/src/qrules/io/__init__.py @@ -11,7 +11,7 @@ from pathlib import Path from typing import Any, Dict, Optional -import attr +import attrs import yaml from qrules.particle import Particle, ParticleCollection @@ -37,10 +37,10 @@ def asdict(instance: object) -> dict: instance, (ReactionInfo, State, StateTransition, StateTransitionCollection), ): - return attr.asdict( + return attrs.asdict( instance, recurse=True, - filter=lambda attr, _: attr.init, + filter=lambda a, _: a.init, value_serializer=_dict._value_serializer, ) if isinstance(instance, StateTransitionGraph): @@ -74,11 +74,11 @@ def fromdict(definition: dict) -> object: __REQUIRED_PARTICLE_FIELDS = { field.name - for field in attr.fields(Particle) - if field.default == attr.NOTHING + for field in attrs.fields(Particle) + if field.default == attrs.NOTHING } __REQUIRED_TOPOLOGY_FIELDS = { - field.name for field in attr.fields(Topology) if field.init + field.name for field in attrs.fields(Topology) if field.init } @@ -205,7 +205,7 @@ def write(instance: object, filename: str) -> None: with open(filename, "w") as stream: file_extension = _get_file_extension(filename) if file_extension == "json": - json.dump(asdict(instance), stream, indent=2) + json.dump(asdict(instance), stream, indent=2, cls=JSONSetEncoder) return if file_extension in ["yaml", "yml"]: yaml.dump( @@ -236,3 +236,19 @@ def _get_file_extension(filename: str) -> str: raise ValueError(f'No file extension in file name "{filename}"') extension = extension[1:] return extension + + +class JSONSetEncoder(json.JSONEncoder): + """`~json.JSONEncoder` that supports `set` and `frozenset`. + + >>> import json + >>> instance = {"val1": {1, 2, 3}, "val2": frozenset({2, 3, 4, 5})} + >>> json.dumps(instance, cls=JSONSetEncoder) + '{"val1": [1, 2, 3], "val2": [2, 3, 4, 5]}' + """ + + # https://stackoverflow.com/a/8230505 + def default(self, o: Any) -> Any: + if isinstance(o, (frozenset, set)): + return list(o) + return super().default(o) diff --git a/src/qrules/io/_dict.py b/src/qrules/io/_dict.py index c0a23095..89cd18d0 100644 --- a/src/qrules/io/_dict.py +++ b/src/qrules/io/_dict.py @@ -6,7 +6,7 @@ from os.path import dirname, realpath from typing import Any, Dict -import attr +import attrs from qrules.particle import ( Parity, @@ -30,11 +30,11 @@ def from_particle_collection(particles: ParticleCollection) -> dict: def from_particle(particle: Particle) -> dict: - return attr.asdict( + return attrs.asdict( particle, recurse=True, value_serializer=_value_serializer, - filter=lambda attr, value: attr.default != value, + filter=lambda attribute, value: attribute.default != value, ) @@ -52,7 +52,7 @@ def from_stg(graph: StateTransitionGraph[ParticleWithSpin]) -> dict: node_props_def = {} for i in topology.nodes: node_prop = graph.get_node_props(i) - node_props_def[i] = attr.asdict( + node_props_def[i] = attrs.asdict( node_prop, filter=lambda a, v: a.init and a.default != v ) return { @@ -63,7 +63,7 @@ def from_stg(graph: StateTransitionGraph[ParticleWithSpin]) -> dict: def from_topology(topology: Topology) -> dict: - return attr.asdict( + return attrs.asdict( topology, recurse=True, value_serializer=_value_serializer, @@ -72,7 +72,7 @@ def from_topology(topology: Topology) -> dict: def _value_serializer( # pylint: disable=unused-argument - inst: type, field: attr.Attribute, value: Any + inst: type, field: attrs.Attribute, value: Any ) -> Any: if isinstance(value, abc.Mapping): if all(map(lambda p: isinstance(p, Particle), value.values())): diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 02151c8e..c3d57583 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -18,7 +18,7 @@ Union, ) -import attr +import attrs from qrules.combinatorics import InitialFacts from qrules.particle import Particle, ParticleCollection, ParticleWithSpin @@ -511,7 +511,7 @@ def _strip_projections( for node_id in graph.topology.nodes: node_props = graph.get_node_props(node_id) if node_props: - new_node_props[node_id] = attr.evolve( + new_node_props[node_id] = attrs.evolve( node_props, l_projection=None, s_projection=None ) return StateTransitionGraph[Particle]( diff --git a/src/qrules/particle.py b/src/qrules/particle.py index b7379b0a..6f2c1ae8 100644 --- a/src/qrules/particle.py +++ b/src/qrules/particle.py @@ -32,9 +32,10 @@ Union, ) -import attr -from attr.converters import optional -from attr.validators import instance_of +import attrs +from attrs import field, frozen +from attrs.converters import optional +from attrs.validators import instance_of from .conservation_rules import GellMannNishijimaInput, gellmann_nishijima from .quantum_numbers import Parity, _to_fraction @@ -57,12 +58,12 @@ def _to_float(value: SupportsFloat) -> float: @total_ordering -@attr.frozen(eq=False, hash=True, order=False) +@frozen(eq=False, hash=True, order=False) class Spin: """Safe, immutable data container for spin **with projection**.""" - magnitude: float = attr.ib(converter=_to_float) - projection: float = attr.ib(converter=_to_float) + magnitude: float = field(converter=_to_float) + projection: float = field(converter=_to_float) def __attrs_post_init__(self) -> None: if self.magnitude % 0.5 != 0.0: @@ -99,7 +100,7 @@ def __float__(self) -> float: def __gt__(self, other: Any) -> bool: if isinstance(other, Spin): - return attr.astuple(self) > attr.astuple(other) + return attrs.astuple(self) > attrs.astuple(other) return self.magnitude > other def __neg__(self) -> "Spin": @@ -126,7 +127,7 @@ def _to_spin(value: Union[Spin, Tuple[float, float]]) -> Spin: @total_ordering -@attr.frozen(kw_only=True, order=False, repr=True) +@frozen(kw_only=True, order=False, repr=True) class Particle: # pylint: disable=too-many-instance-attributes """Immutable container of data defining a physical particle. @@ -147,34 +148,30 @@ class Particle: # pylint: disable=too-many-instance-attributes """ # Labels - name: str = attr.ib(eq=False) - pid: int = attr.ib(eq=False) - latex: Optional[str] = attr.ib(eq=False, default=None) + name: str = field(eq=False) + pid: int = field(eq=False) + latex: Optional[str] = field(eq=False, default=None) # Unique properties - spin: float = attr.ib(converter=float) - mass: float = attr.ib(converter=float) - width: float = attr.ib(converter=float, default=0.0) - charge: int = attr.ib(default=0) - isospin: Optional[Spin] = attr.ib( - converter=optional(_to_spin), default=None - ) - strangeness: int = attr.ib(default=0, validator=instance_of(int)) - charmness: int = attr.ib(default=0, validator=instance_of(int)) - bottomness: int = attr.ib(default=0, validator=instance_of(int)) - topness: int = attr.ib(default=0, validator=instance_of(int)) - baryon_number: int = attr.ib(default=0, validator=instance_of(int)) - electron_lepton_number: int = attr.ib( - default=0, validator=instance_of(int) - ) - muon_lepton_number: int = attr.ib(default=0, validator=instance_of(int)) - tau_lepton_number: int = attr.ib(default=0, validator=instance_of(int)) - parity: Optional[Parity] = attr.ib( + spin: float = field(converter=float) + mass: float = field(converter=float) + width: float = field(converter=float, default=0.0) + charge: int = field(default=0) + isospin: Optional[Spin] = field(converter=optional(_to_spin), default=None) + strangeness: int = field(default=0, validator=instance_of(int)) + charmness: int = field(default=0, validator=instance_of(int)) + bottomness: int = field(default=0, validator=instance_of(int)) + topness: int = field(default=0, validator=instance_of(int)) + baryon_number: int = field(default=0, validator=instance_of(int)) + electron_lepton_number: int = field(default=0, validator=instance_of(int)) + muon_lepton_number: int = field(default=0, validator=instance_of(int)) + tau_lepton_number: int = field(default=0, validator=instance_of(int)) + parity: Optional[Parity] = field( converter=optional(_to_parity), default=None ) - c_parity: Optional[Parity] = attr.ib( + c_parity: Optional[Parity] = field( converter=optional(_to_parity), default=None ) - g_parity: Optional[Parity] = attr.ib( + g_parity: Optional[Parity] = field( converter=optional(_to_parity), default=None ) @@ -242,11 +239,11 @@ def _repr_pretty_(self, p: "PrettyPrinter", cycle: bool) -> None: p.text(f"{class_name}(...)") else: with p.group(indent=2, open=f"{class_name}("): - for field in attr.fields(type(self)): - value = getattr(self, field.name) - if value != field.default: + for attribute in attrs.fields(type(self)): + value = getattr(self, attribute.name) + if value != attribute.default: p.breakable() - p.text(f"{field.name}=") + p.text(f"{attribute.name}=") if isinstance(value, Parity): p.text(_to_fraction(int(value), render_plus=True)) else: diff --git a/src/qrules/quantum_numbers.py b/src/qrules/quantum_numbers.py index 6137985f..1ca570a5 100644 --- a/src/qrules/quantum_numbers.py +++ b/src/qrules/quantum_numbers.py @@ -12,13 +12,14 @@ from functools import total_ordering from typing import Any, Generator, NewType, Optional, Union -import attr -from attr.validators import instance_of +import attrs +from attrs import field, frozen +from attrs.validators import instance_of from qrules._implementers import implement_pretty_repr -def _check_plus_minus(_: Any, __: attr.Attribute, value: Any) -> None: +def _check_plus_minus(_: Any, __: attrs.Attribute, value: Any) -> None: if not isinstance(value, int): raise TypeError( f"Input for {Parity.__name__} has to be of type {int.__name__}," @@ -29,9 +30,9 @@ def _check_plus_minus(_: Any, __: attr.Attribute, value: Any) -> None: @total_ordering -@attr.frozen(eq=False, hash=True, order=False, repr=False) +@frozen(eq=False, hash=True, order=False, repr=False) class Parity: - value: int = attr.ib(validator=[instance_of(int), _check_plus_minus]) + value: int = field(validator=[instance_of(int), _check_plus_minus]) def __eq__(self, other: object) -> bool: if isinstance(other, Parity): @@ -60,17 +61,17 @@ def _to_fraction(value: Union[float, int], render_plus: bool = False) -> str: return label -@attr.frozen(init=False) +@frozen(init=False) class EdgeQuantumNumbers: # pylint: disable=too-many-instance-attributes """Definition of quantum numbers for edges. This class defines the types that are used in the :mod:`.conservation_rules`, for instance in `.additive_quantum_number_rule`. You can also create data classes (see - `attr.s`) with data members that are typed as the data members of - `.EdgeQuantumNumbers` (see for example `.HelicityParityEdgeInput`) and use - them in conservation rules that satisfy the appropriate rule protocol (see - `.ConservationRule`, `.EdgeQNConservationRule`). + :func:`attrs.define`) with data members that are typed as the data members + of `.EdgeQuantumNumbers` (see for example `.HelicityParityEdgeInput`) and + use them in conservation rules that satisfy the appropriate rule protocol + (see `.ConservationRule`, `.EdgeQNConservationRule`). """ pid = NewType("pid", int) @@ -124,7 +125,7 @@ class EdgeQuantumNumbers: # pylint: disable=too-many-instance-attributes ] -@attr.frozen(init=False) +@frozen(init=False) class NodeQuantumNumbers: """Definition of quantum numbers for interaction nodes.""" @@ -164,7 +165,7 @@ def _to_optional_int(optional_int: Optional[int]) -> Optional[int]: @implement_pretty_repr() -@attr.frozen(order=True) +@frozen(order=True) class InteractionProperties: """Immutable data structure containing interaction properties. @@ -181,19 +182,19 @@ class represents the properties that are carried collectively by the edges class serves as an interface to the user. """ - l_magnitude: Optional[int] = attr.ib( # L cannot be half integer + l_magnitude: Optional[int] = field( # L cannot be half integer default=None, converter=_to_optional_int ) - l_projection: Optional[int] = attr.ib( + l_projection: Optional[int] = field( default=None, converter=_to_optional_int ) - s_magnitude: Optional[float] = attr.ib( + s_magnitude: Optional[float] = field( default=None, converter=_to_optional_float ) - s_projection: Optional[float] = attr.ib( + s_projection: Optional[float] = field( default=None, converter=_to_optional_float ) - parity_prefactor: Optional[float] = attr.ib( + parity_prefactor: Optional[float] = field( default=None, converter=_to_optional_float ) diff --git a/src/qrules/solving.py b/src/qrules/solving.py index 3b06d2fd..11789ce9 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -29,7 +29,8 @@ Union, ) -import attr +import attrs +from attrs import define, field, frozen from constraint import ( BacktrackingSolver, Constraint, @@ -58,17 +59,17 @@ @implement_pretty_repr() -@attr.define +@define class EdgeSettings: """Solver settings for a specific edge of a graph.""" - conservation_rules: Set[GraphElementRule] = attr.ib(factory=set) - rule_priorities: Dict[GraphElementRule, int] = attr.ib(factory=dict) - qn_domains: Dict[Any, list] = attr.ib(factory=dict) + conservation_rules: Set[GraphElementRule] = field(factory=set) + rule_priorities: Dict[GraphElementRule, int] = field(factory=dict) + qn_domains: Dict[Any, list] = field(factory=dict) @implement_pretty_repr() -@attr.define +@define class NodeSettings: """Container class for the interaction settings. @@ -82,28 +83,28 @@ class NodeSettings: - strength scale parameter (higher value means stronger force) """ - conservation_rules: Set[Rule] = attr.ib(factory=set) - rule_priorities: Dict[Rule, int] = attr.ib(factory=dict) - qn_domains: Dict[Any, list] = attr.ib(factory=dict) + conservation_rules: Set[Rule] = field(factory=set) + rule_priorities: Dict[Rule, int] = field(factory=dict) + qn_domains: Dict[Any, list] = field(factory=dict) interaction_strength: float = 1.0 @implement_pretty_repr() -@attr.define +@define class GraphSettings: - edge_settings: Dict[int, EdgeSettings] = attr.ib(factory=dict) - node_settings: Dict[int, NodeSettings] = attr.ib(factory=dict) + edge_settings: Dict[int, EdgeSettings] = field(factory=dict) + node_settings: Dict[int, NodeSettings] = field(factory=dict) @implement_pretty_repr() -@attr.define +@define class GraphElementProperties: - edge_props: Dict[int, GraphEdgePropertyMap] = attr.ib(factory=dict) - node_props: Dict[int, GraphNodePropertyMap] = attr.ib(factory=dict) + edge_props: Dict[int, GraphEdgePropertyMap] = field(factory=dict) + node_props: Dict[int, GraphNodePropertyMap] = field(factory=dict) @implement_pretty_repr() -@attr.frozen +@frozen class QNProblemSet: """Particle reaction problem set, defined as a graph like data structure. @@ -117,16 +118,16 @@ class QNProblemSet: topology """ - topology: Topology = attr.ib() - initial_facts: GraphElementProperties = attr.ib() - solving_settings: GraphSettings = attr.ib() + topology: Topology + initial_facts: GraphElementProperties + solving_settings: GraphSettings @implement_pretty_repr() -@attr.frozen +@frozen class QuantumNumberSolution: - node_quantum_numbers: Dict[int, GraphNodePropertyMap] = attr.ib() - edge_quantum_numbers: Dict[int, GraphEdgePropertyMap] = attr.ib() + node_quantum_numbers: Dict[int, GraphNodePropertyMap] + edge_quantum_numbers: Dict[int, GraphEdgePropertyMap] def _convert_violated_rules_to_names( @@ -174,21 +175,21 @@ def get_name(rule: Any) -> str: @implement_pretty_repr() -@attr.define(on_setattr=attr.setters.frozen) +@define(on_setattr=attrs.setters.frozen) class QNResult: """Defines a result to a problem set processed by the solving code.""" - solutions: List[QuantumNumberSolution] = attr.ib(factory=list) - not_executed_node_rules: Dict[int, Set[str]] = attr.ib( + solutions: List[QuantumNumberSolution] = field(factory=list) + not_executed_node_rules: Dict[int, Set[str]] = field( factory=lambda: defaultdict(set) ) - violated_node_rules: Dict[int, Set[str]] = attr.ib( + violated_node_rules: Dict[int, Set[str]] = field( factory=lambda: defaultdict(set) ) - not_executed_edge_rules: Dict[int, Set[str]] = attr.ib( + not_executed_edge_rules: Dict[int, Set[str]] = field( factory=lambda: defaultdict(set) ) - violated_edge_rules: Dict[int, Set[str]] = attr.ib( + violated_edge_rules: Dict[int, Set[str]] = field( factory=lambda: defaultdict(set) ) @@ -275,7 +276,7 @@ def _merge_particle_candidates_with_solutions( for k, v in current_new_solution.edge_quantum_numbers.items() } new_edge_qns[int_edge_id].update(particle_edge) - temp_solution = attr.evolve( + temp_solution = attrs.evolve( current_new_solution, edge_quantum_numbers=new_edge_qns, ) @@ -461,18 +462,18 @@ def _create_variable_string( return str(element_id) + "-" + qn_type.__name__ -@attr.define +@define class _VariableContainer: - ingoing_edge_variables: Set[_EdgeVariableInfo] = attr.ib(factory=set) - fixed_ingoing_edge_variables: Dict[int, GraphEdgePropertyMap] = attr.ib( + ingoing_edge_variables: Set[_EdgeVariableInfo] = field(factory=set) + fixed_ingoing_edge_variables: Dict[int, GraphEdgePropertyMap] = field( factory=dict ) - outgoing_edge_variables: Set[_EdgeVariableInfo] = attr.ib(factory=set) - fixed_outgoing_edge_variables: Dict[int, GraphEdgePropertyMap] = attr.ib( + outgoing_edge_variables: Set[_EdgeVariableInfo] = field(factory=set) + fixed_outgoing_edge_variables: Dict[int, GraphEdgePropertyMap] = field( factory=dict ) - node_variables: Set[_NodeVariableInfo] = attr.ib(factory=set) - fixed_node_variables: GraphNodePropertyMap = attr.ib(factory=dict) + node_variables: Set[_NodeVariableInfo] = field(factory=set) + fixed_node_variables: GraphNodePropertyMap = field(factory=dict) class CSPSolver(Solver): diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 71a37e04..16f409e8 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -37,7 +37,9 @@ ValuesView, ) -import attr +import attrs +from attrs import define, field, frozen +from attrs.validators import instance_of from qrules._implementers import implement_pretty_repr @@ -143,14 +145,14 @@ def _to_optional_int(optional_int: Optional[int]) -> Optional[int]: return int(optional_int) -@attr.frozen(order=True) +@frozen(order=True) class Edge: """Struct-like definition of an edge, used in `Topology`.""" - originating_node_id: Optional[int] = attr.ib( + originating_node_id: Optional[int] = field( default=None, converter=_to_optional_int ) - ending_node_id: Optional[int] = attr.ib( + ending_node_id: Optional[int] = field( default=None, converter=_to_optional_int ) @@ -167,7 +169,7 @@ def _to_frozenset(iterable: Iterable[int]) -> FrozenSet[int]: @implement_pretty_repr() -@attr.frozen(order=True) +@frozen(order=True) class Topology: """Directed Feynman-like graph without edge or node properties. @@ -178,12 +180,12 @@ class Topology: like a Feynman-diagram. """ - nodes: FrozenSet[int] = attr.ib(converter=_to_frozenset) - edges: FrozenDict[int, Edge] = attr.ib(converter=FrozenDict) + nodes: FrozenSet[int] = field(converter=_to_frozenset) + edges: FrozenDict[int, Edge] = field(converter=FrozenDict) - incoming_edge_ids: FrozenSet[int] = attr.ib(init=False, repr=False) - outgoing_edge_ids: FrozenSet[int] = attr.ib(init=False, repr=False) - intermediate_edge_ids: FrozenSet[int] = attr.ib(init=False, repr=False) + incoming_edge_ids: FrozenSet[int] = field(init=False, repr=False) + outgoing_edge_ids: FrozenSet[int] = field(init=False, repr=False) + intermediate_edge_ids: FrozenSet[int] = field(init=False, repr=False) def __attrs_post_init__(self) -> None: self.__verify() @@ -349,7 +351,7 @@ def relabel_edges(self, old_to_new_id: Mapping[int, int]) -> "Topology": new_edges = { old_to_new_id.get(i, i): edge for i, edge in self.edges.items() } - return attr.evolve(self, edges=new_edges) + return attrs.evolve(self, edges=new_edges) def swap_edges(self, edge_id1: int, edge_id2: int) -> "Topology": return self.relabel_edges({edge_id1: edge_id2, edge_id2: edge_id1}) @@ -374,10 +376,10 @@ def __get_originating_node(edge_id: int) -> Optional[int]: ] -@attr.define(kw_only=True) +@define(kw_only=True) class _MutableTopology: - edges: Dict[int, Edge] = attr.ib(factory=dict, converter=dict) - nodes: Set[int] = attr.ib(factory=set, converter=set) + edges: Dict[int, Edge] = field(factory=dict, converter=dict) + nodes: Set[int] = field(factory=set, converter=set) def freeze(self) -> Topology: return Topology( @@ -457,16 +459,12 @@ def attach_edges_to_node_outgoing( ) -@attr.define +@define class InteractionNode: """Helper class for the `.SimpleStateTransitionTopologyBuilder`.""" - number_of_ingoing_edges: int = attr.ib( - validator=attr.validators.instance_of(int) - ) - number_of_outgoing_edges: int = attr.ib( - validator=attr.validators.instance_of(int) - ) + number_of_ingoing_edges: int = field(validator=instance_of(int)) + number_of_outgoing_edges: int = field(validator=instance_of(int)) def __attrs_post_init__(self) -> None: if self.number_of_ingoing_edges < 1: diff --git a/src/qrules/transition.py b/src/qrules/transition.py index 3a9e7a0a..75f03346 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -26,8 +26,9 @@ overload, ) -import attr -from attr.validators import instance_of +import attrs +from attrs import define, field, frozen +from attrs.validators import instance_of from tqdm.auto import tqdm from qrules._implementers import implement_pretty_repr @@ -103,18 +104,18 @@ class SolvingMode(Enum): @implement_pretty_repr() -@attr.define(on_setattr=attr.setters.frozen) +@define(on_setattr=attrs.setters.frozen) class ExecutionInfo: - not_executed_node_rules: Dict[int, Set[str]] = attr.ib( + not_executed_node_rules: Dict[int, Set[str]] = field( factory=lambda: defaultdict(set) ) - violated_node_rules: Dict[int, Set[str]] = attr.ib( + violated_node_rules: Dict[int, Set[str]] = field( factory=lambda: defaultdict(set) ) - not_executed_edge_rules: Dict[int, Set[str]] = attr.ib( + not_executed_edge_rules: Dict[int, Set[str]] = field( factory=lambda: defaultdict(set) ) - violated_edge_rules: Dict[int, Set[str]] = attr.ib( + violated_edge_rules: Dict[int, Set[str]] = field( factory=lambda: defaultdict(set) ) @@ -146,14 +147,14 @@ def clear(self) -> None: self.violated_edge_rules.clear() -@attr.frozen +@frozen class _SolutionContainer: """Defines a result of a `.ProblemSet`.""" - solutions: List[StateTransitionGraph[ParticleWithSpin]] = attr.ib( + solutions: List[StateTransitionGraph[ParticleWithSpin]] = field( factory=list ) - execution_info: ExecutionInfo = attr.ib(default=ExecutionInfo()) + execution_info: ExecutionInfo = field(default=ExecutionInfo()) def __attrs_post_init__(self) -> None: if self.solutions and ( @@ -180,7 +181,7 @@ def extend( @implement_pretty_repr() -@attr.define +@define class ProblemSet: """Particle reaction problem set, defined as a graph like data structure. @@ -192,9 +193,9 @@ class ProblemSet: rules and the quantum number domains. """ - topology: Topology = attr.ib() - initial_facts: InitialFacts = attr.ib() - solving_settings: GraphSettings = attr.ib() + topology: Topology + initial_facts: InitialFacts + solving_settings: GraphSettings def to_qn_problem_set(self) -> QNProblemSet: node_props = { @@ -736,20 +737,20 @@ def _strip_spin(state_definition: Sequence[StateDefinition]) -> List[str]: @implement_pretty_repr() -@attr.frozen(order=True) +@frozen(order=True) class State: - particle: Particle = attr.ib(validator=instance_of(Particle)) - spin_projection: float = attr.ib(converter=_to_float) + particle: Particle = field(validator=instance_of(Particle)) + spin_projection: float = field(converter=_to_float) @implement_pretty_repr() -@attr.frozen(order=True) +@frozen(order=True) class StateTransition: """Frozen instance of a `.StateTransitionGraph` of a particle with spin.""" - topology: Topology = attr.ib(validator=instance_of(Topology)) - states: FrozenDict[int, State] = attr.ib(converter=FrozenDict) - interactions: FrozenDict[int, InteractionProperties] = attr.ib( + topology: Topology = field(validator=instance_of(Topology)) + states: FrozenDict[int, State] = field(converter=FrozenDict) + interactions: FrozenDict[int, InteractionProperties] = field( converter=FrozenDict ) @@ -824,16 +825,16 @@ def _to_sorted_tuple( return tuple(sorted(iterable)) -@attr.frozen +@frozen class StateTransitionCollection(abc.Sequence): """`.StateTransition` instances with the same `.Topology` and edge IDs.""" - transitions: Tuple[StateTransition, ...] = attr.ib( + transitions: Tuple[StateTransition, ...] = field( converter=_to_sorted_tuple ) - topology: Topology = attr.ib(init=False, repr=False) - initial_state: FrozenDict[int, Particle] = attr.ib(init=False, repr=False) - final_state: FrozenDict[int, Particle] = attr.ib(init=False, repr=False) + topology: Topology = field(init=False, repr=False) + initial_state: FrozenDict[int, Particle] = field(init=False, repr=False) + final_state: FrozenDict[int, Particle] = field(init=False, repr=False) def __attrs_post_init__(self) -> None: if not any(self.transitions): @@ -935,19 +936,19 @@ def _to_tuple( return tuple(iterable) -@attr.frozen(eq=False, hash=True) +@frozen(eq=False, hash=True) class ReactionInfo: """`StateTransitionCollection` instances, grouped by `.Topology`.""" - transition_groups: Tuple[StateTransitionCollection, ...] = attr.ib( + transition_groups: Tuple[StateTransitionCollection, ...] = field( converter=_to_tuple ) - transitions: Tuple[StateTransition, ...] = attr.ib( + transitions: Tuple[StateTransition, ...] = field( init=False, repr=False, eq=False ) - initial_state: FrozenDict[int, Particle] = attr.ib(init=False, repr=False) - final_state: FrozenDict[int, Particle] = attr.ib(init=False, repr=False) - formalism: str = attr.ib(validator=instance_of(str)) + initial_state: FrozenDict[int, Particle] = field(init=False, repr=False) + final_state: FrozenDict[int, Particle] = field(init=False, repr=False) + formalism: str = field(validator=instance_of(str)) def __attrs_post_init__(self) -> None: if len(self.transition_groups) == 0: diff --git a/tests/unit/conservation_rules/test_duck_typing.py b/tests/unit/conservation_rules/test_duck_typing.py index dc703a07..f2bff129 100644 --- a/tests/unit/conservation_rules/test_duck_typing.py +++ b/tests/unit/conservation_rules/test_duck_typing.py @@ -8,7 +8,7 @@ import inspect from typing import Set, Type -import attr +import attrs from qrules import conservation_rules from qrules.particle import Particle @@ -101,8 +101,8 @@ def test_get_members(): def __get_members(class_type: Type) -> Set[str]: use_attrs = class_type not in {EdgeQuantumNumbers, NodeQuantumNumbers} - if use_attrs and attr.has(class_type): - return {f.name for f in attr.fields(class_type)} + if use_attrs and attrs.has(class_type): + return {f.name for f in attrs.fields(class_type)} return { a.name for a in inspect.classify_class_attrs(class_type) diff --git a/tests/unit/io/test_io.py b/tests/unit/io/test_io.py index 79c62ec9..be7b8638 100644 --- a/tests/unit/io/test_io.py +++ b/tests/unit/io/test_io.py @@ -14,8 +14,9 @@ def through_dict(instance): + # Check JSON serialization asdict = io.asdict(instance) - asdict = json.loads(json.dumps(asdict)) # check JSON serialization + asdict = json.loads(json.dumps(asdict, cls=io.JSONSetEncoder)) return io.fromdict(asdict) diff --git a/tests/unit/test_particle.py b/tests/unit/test_particle.py index 3fcffc1b..29e2fc11 100644 --- a/tests/unit/test_particle.py +++ b/tests/unit/test_particle.py @@ -4,7 +4,7 @@ from copy import deepcopy import pytest -from attr.exceptions import FrozenInstanceError +from attrs.exceptions import FrozenInstanceError from IPython.lib.pretty import pretty from qrules.particle import ( diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index 9d3a19cd..688ab640 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -2,8 +2,8 @@ # pyright: reportUnusedImport=false import typing -import attr import pytest +from attrs.exceptions import FrozenInstanceError from IPython.lib.pretty import pretty from qrules.topology import ( # noqa: F401 @@ -58,13 +58,13 @@ def test_get_connected_nodes(self): @typing.no_type_check def test_immutability(self): edge = Edge(1, 2) - with pytest.raises(attr.exceptions.FrozenInstanceError): + with pytest.raises(FrozenInstanceError): edge.originating_node_id = None - with pytest.raises(attr.exceptions.FrozenInstanceError): + with pytest.raises(FrozenInstanceError): edge.originating_node_id += 1 - with pytest.raises(attr.exceptions.FrozenInstanceError): + with pytest.raises(FrozenInstanceError): edge.ending_node_id = None - with pytest.raises(attr.exceptions.FrozenInstanceError): + with pytest.raises(FrozenInstanceError): edge.ending_node_id += 1 @@ -220,13 +220,13 @@ def test_getters(self, two_to_three_decay: Topology): @typing.no_type_check def test_immutability(self, two_to_three_decay: Topology): - with pytest.raises(attr.exceptions.FrozenInstanceError): + with pytest.raises(FrozenInstanceError): two_to_three_decay.edges = {0: Edge(None, None)} with pytest.raises(TypeError): two_to_three_decay.edges[0] = Edge(None, None) - with pytest.raises(attr.exceptions.FrozenInstanceError): + with pytest.raises(FrozenInstanceError): two_to_three_decay.edges[0].ending_node_id = None - with pytest.raises(attr.exceptions.FrozenInstanceError): + with pytest.raises(FrozenInstanceError): two_to_three_decay.nodes = {0, 1} with pytest.raises(AttributeError): two_to_three_decay.nodes.add(2)