From eb73c7d8af4dc6aeb2cf4e47e688a40b62beedf4 Mon Sep 17 00:00:00 2001 From: Marnik Bercx Date: Mon, 6 Dec 2021 14:58:25 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20constructor=20of=20ba?= =?UTF-8?q?se=20data=20types=20(#5165)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 👌 IMPROVE: constructor of base data types Adapt the constructor of the `Dict` and `List` data types so the value no longer needs to be provided as a keyword argument, but can simply be passed as the first input. Remove the `*args` from the `BaseType` constructor, instead specifying the `value` input argument. Otherwise users could just pass multiple positional arguments and it would only use the first without raising an error. Also fixes two issues with the `List` class: * The `remove` method was not working as prescribed. Instead of removing the first element in the list with the specified value, it simply deleted the element with index `value`. * The `set_list()` method didn't make a copy of the `list` input, so any modification done to the `List` instance would also affect the original `list`. If the same list is used to initialise several `List` instances, adapting one would affect the other. Finally, add tests for the `List` class. * Refactor data tests --- aiida/orm/nodes/data/base.py | 9 +- aiida/orm/nodes/data/dict.py | 12 +- aiida/orm/nodes/data/list.py | 16 +- tests/orm/{data => nodes}/__init__.py | 0 tests/orm/{node => nodes/data}/__init__.py | 0 tests/orm/{ => nodes}/data/test_array.py | 0 .../orm/{ => nodes}/data/test_array_bands.py | 0 tests/orm/nodes/data/test_base.py | 207 +++++++++++++ tests/orm/{ => nodes}/data/test_cif.py | 0 tests/orm/{ => nodes}/data/test_data.py | 0 tests/orm/{ => nodes}/data/test_dict.py | 25 +- tests/orm/{ => nodes}/data/test_folder.py | 0 tests/orm/{ => nodes}/data/test_jsonable.py | 0 tests/orm/{ => nodes}/data/test_kpoints.py | 0 tests/orm/nodes/data/test_list.py | 216 +++++++++++++ tests/orm/{ => nodes}/data/test_orbital.py | 0 tests/orm/{ => nodes}/data/test_remote.py | 0 .../orm/{ => nodes}/data/test_remote_stash.py | 0 tests/orm/{ => nodes}/data/test_singlefile.py | 0 tests/orm/{ => nodes}/data/test_structure.py | 0 .../{ => nodes}/data/test_to_aiida_type.py | 0 tests/orm/{ => nodes}/data/test_trajectory.py | 0 tests/orm/{ => nodes}/data/test_upf.py | 0 tests/orm/{node => nodes}/test_calcjob.py | 0 tests/orm/{node => nodes}/test_node.py | 0 tests/orm/{node => nodes}/test_repository.py | 0 tests/test_base_dataclasses.py | 290 ------------------ 27 files changed, 459 insertions(+), 316 deletions(-) rename tests/orm/{data => nodes}/__init__.py (100%) rename tests/orm/{node => nodes/data}/__init__.py (100%) rename tests/orm/{ => nodes}/data/test_array.py (100%) rename tests/orm/{ => nodes}/data/test_array_bands.py (100%) create mode 100644 tests/orm/nodes/data/test_base.py rename tests/orm/{ => nodes}/data/test_cif.py (100%) rename tests/orm/{ => nodes}/data/test_data.py (100%) rename tests/orm/{ => nodes}/data/test_dict.py (87%) rename tests/orm/{ => nodes}/data/test_folder.py (100%) rename tests/orm/{ => nodes}/data/test_jsonable.py (100%) rename tests/orm/{ => nodes}/data/test_kpoints.py (100%) create mode 100644 tests/orm/nodes/data/test_list.py rename tests/orm/{ => nodes}/data/test_orbital.py (100%) rename tests/orm/{ => nodes}/data/test_remote.py (100%) rename tests/orm/{ => nodes}/data/test_remote_stash.py (100%) rename tests/orm/{ => nodes}/data/test_singlefile.py (100%) rename tests/orm/{ => nodes}/data/test_structure.py (100%) rename tests/orm/{ => nodes}/data/test_to_aiida_type.py (100%) rename tests/orm/{ => nodes}/data/test_trajectory.py (100%) rename tests/orm/{ => nodes}/data/test_upf.py (100%) rename tests/orm/{node => nodes}/test_calcjob.py (100%) rename tests/orm/{node => nodes}/test_node.py (100%) rename tests/orm/{node => nodes}/test_repository.py (100%) delete mode 100644 tests/test_base_dataclasses.py diff --git a/aiida/orm/nodes/data/base.py b/aiida/orm/nodes/data/base.py index 86858e14ce..070296ad0d 100644 --- a/aiida/orm/nodes/data/base.py +++ b/aiida/orm/nodes/data/base.py @@ -24,7 +24,7 @@ def to_aiida_type(value): class BaseType(Data): """`Data` sub class to be used as a base for data containers that represent base python data types.""" - def __init__(self, *args, **kwargs): + def __init__(self, value=None, **kwargs): try: getattr(self, '_type') except AttributeError: @@ -32,12 +32,7 @@ def __init__(self, *args, **kwargs): super().__init__(**kwargs) - try: - value = args[0] - except IndexError: - value = self._type() # pylint: disable=no-member - - self.value = value + self.value = value or self._type() # pylint: disable=no-member @property def value(self): diff --git a/aiida/orm/nodes/data/dict.py b/aiida/orm/nodes/data/dict.py index 73820513af..6cd542ca65 100644 --- a/aiida/orm/nodes/data/dict.py +++ b/aiida/orm/nodes/data/dict.py @@ -46,17 +46,17 @@ class Dict(Data): Finally, all dictionary mutations will be forbidden once the node is stored. """ - def __init__(self, **kwargs): - """Store a dictionary as a `Node` instance. + def __init__(self, value=None, **kwargs): + """Initialise a ``Dict`` node instance. - Usual rules for attribute names apply, in particular, keys cannot start with an underscore, or a `ValueError` + Usual rules for attribute names apply, in particular, keys cannot start with an underscore, or a ``ValueError`` will be raised. Initial attributes can be changed, deleted or added as long as the node is not stored. - :param dict: the dictionary to set + :param value: dictionary to initialise the ``Dict`` node from """ - dictionary = kwargs.pop('dict', None) + dictionary = value or kwargs.pop('dict', None) super().__init__(**kwargs) if dictionary: self.set_dict(dictionary) @@ -135,4 +135,4 @@ def dict(self): @to_aiida_type.register(dict) def _(value): - return Dict(dict=value) + return Dict(value) diff --git a/aiida/orm/nodes/data/list.py b/aiida/orm/nodes/data/list.py index 37ae846f09..cb05920a48 100644 --- a/aiida/orm/nodes/data/list.py +++ b/aiida/orm/nodes/data/list.py @@ -21,8 +21,12 @@ class List(Data, MutableSequence): _LIST_KEY = 'list' - def __init__(self, **kwargs): - data = kwargs.pop('list', []) + def __init__(self, value=None, **kwargs): + """Initialise a ``List`` node instance. + + :param value: list to initialise the ``List`` node from + """ + data = value or kwargs.pop('list', []) super().__init__(**kwargs) self.set_list(data) @@ -75,7 +79,11 @@ def insert(self, i, value): # pylint: disable=arguments-renamed self.set_list(data) def remove(self, value): - del self[value] + data = self.get_list() + item = data.remove(value) + if not self._using_list_reference(): + self.set_list(data) + return item def pop(self, **kwargs): # pylint: disable=arguments-differ """Remove and return item at index (default last).""" @@ -123,7 +131,7 @@ def set_list(self, data): """ if not isinstance(data, list): raise TypeError('Must supply list type') - self.set_attribute(self._LIST_KEY, data) + self.set_attribute(self._LIST_KEY, data.copy()) def _using_list_reference(self): """ diff --git a/tests/orm/data/__init__.py b/tests/orm/nodes/__init__.py similarity index 100% rename from tests/orm/data/__init__.py rename to tests/orm/nodes/__init__.py diff --git a/tests/orm/node/__init__.py b/tests/orm/nodes/data/__init__.py similarity index 100% rename from tests/orm/node/__init__.py rename to tests/orm/nodes/data/__init__.py diff --git a/tests/orm/data/test_array.py b/tests/orm/nodes/data/test_array.py similarity index 100% rename from tests/orm/data/test_array.py rename to tests/orm/nodes/data/test_array.py diff --git a/tests/orm/data/test_array_bands.py b/tests/orm/nodes/data/test_array_bands.py similarity index 100% rename from tests/orm/data/test_array_bands.py rename to tests/orm/nodes/data/test_array_bands.py diff --git a/tests/orm/nodes/data/test_base.py b/tests/orm/nodes/data/test_base.py new file mode 100644 index 0000000000..adb564f42e --- /dev/null +++ b/tests/orm/nodes/data/test_base.py @@ -0,0 +1,207 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name +"""Tests for :class:`aiida.orm.nodes.data.base.BaseType` classes.""" + +import operator + +import pytest + +from aiida.orm import Bool, Float, Int, NumericType, Str, load_node + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize( + 'node_type, default, value', [ + (Bool, False, True), + (Int, 0, 5), + (Float, 0.0, 5.5), + (Str, '', 'a'), + ] +) +def test_create(node_type, default, value): + """Test the creation of the ``BaseType`` nodes.""" + + node = node_type() + assert node.value == default + + node = node_type(value) + assert node.value == value + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('node_type', [Bool, Float, Int, Str]) +def test_store_load(node_type): + """Test ``BaseType`` node storing and loading.""" + node = node_type() + node.store() + loaded = load_node(node.pk) + assert node.value == loaded.value + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_modulo(): + """Test ``Int`` modulus operation.""" + term_a = Int(12) + term_b = Int(10) + + assert term_a % term_b == 2 + assert isinstance(term_a % term_b, NumericType) + assert term_a % 10 == 2 + assert isinstance(term_a % 10, NumericType) + assert 12 % term_b == 2 + assert isinstance(12 % term_b, NumericType) + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('node_type, a, b', [ + (Int, 3, 5), + (Float, 1.2, 5.5), +]) +def test_add(node_type, a, b): + """Test addition for ``Int`` and ``Float`` nodes.""" + node_a = node_type(a) + node_b = node_type(b) + + result = node_a + node_b + assert isinstance(result, node_type) + assert result.value == a + b + + # Node and native (both ways) + result = node_a + b + assert isinstance(result, node_type) + assert result.value == a + b + + result = a + node_b + assert isinstance(result, node_type) + assert result.value == a + b + + # Inplace + result = node_type(a) + result += node_b + assert isinstance(result, node_type) + assert result.value == a + b + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('node_type, a, b', [ + (Int, 3, 5), + (Float, 1.2, 5.5), +]) +def test_multiplication(node_type, a, b): + """Test floats multiplication.""" + node_a = node_type(a) + node_b = node_type(b) + + # Check multiplication + result = node_a * node_b + assert isinstance(result, node_type) + assert result.value == a * b + + # Check multiplication Node and native (both ways) + result = node_a * b + assert isinstance(result, node_type) + assert result.value == a * b + + result = a * node_b + assert isinstance(result, node_type) + assert result.value == a * b + + # Inplace + result = node_type(a) + result *= node_b + assert isinstance(result, node_type) + assert result.value == a * b + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('node_type, a, b', [ + (Int, 3, 5), + (Float, 1.2, 5.5), +]) +@pytest.mark.usefixtures('clear_database_before_test') +def test_division(node_type, a, b): + """Test the ``BaseType`` normal division operator.""" + node_a = node_type(a) + node_b = node_type(b) + + result = node_a / node_b + assert result == a / b + assert isinstance(result, Float) # Should be a `Float` for both node types + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('node_type, a, b', [ + (Int, 3, 5), + (Float, 1.2, 5.5), +]) +@pytest.mark.usefixtures('clear_database_before_test') +def test_division_integer(node_type, a, b): + """Test the ``Int`` integer division operator.""" + node_a = node_type(a) + node_b = node_type(b) + + result = node_a // node_b + assert result == a // b + assert isinstance(result, node_type) + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('node_type, base, power', [ + (Int, 5, 2), + (Float, 3.5, 3), +]) +def test_power(node_type, base, power): + """Test power operator.""" + node_base = node_type(base) + node_power = node_type(power) + + result = node_base**node_power + assert result == base**power + assert isinstance(result, node_type) + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('node_type, a, b', [ + (Int, 5, 2), + (Float, 3.5, 3), +]) +def test_modulus(node_type, a, b): + """Test modulus operator.""" + node_a = node_type(a) + node_b = node_type(b) + + assert node_a % node_b == a % b + assert isinstance(node_a % node_b, node_type) + + assert node_a % b == a % b + assert isinstance(node_a % b, node_type) + + assert a % node_b == a % b + assert isinstance(a % node_b, node_type) + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize( + 'opera', [ + operator.add, operator.mul, operator.pow, operator.lt, operator.le, operator.gt, operator.ge, operator.iadd, + operator.imul + ] +) +def test_operator(opera): + """Test operations between Int and Float objects.""" + node_a = Float(2.2) + node_b = Int(3) + + for node_x, node_y in [(node_a, node_b), (node_b, node_a)]: + res = opera(node_x, node_y) + c_val = opera(node_x.value, node_y.value) + assert res._type == type(c_val) # pylint: disable=protected-access + assert res == opera(node_x.value, node_y.value) diff --git a/tests/orm/data/test_cif.py b/tests/orm/nodes/data/test_cif.py similarity index 100% rename from tests/orm/data/test_cif.py rename to tests/orm/nodes/data/test_cif.py diff --git a/tests/orm/data/test_data.py b/tests/orm/nodes/data/test_data.py similarity index 100% rename from tests/orm/data/test_data.py rename to tests/orm/nodes/data/test_data.py diff --git a/tests/orm/data/test_dict.py b/tests/orm/nodes/data/test_dict.py similarity index 87% rename from tests/orm/data/test_dict.py rename to tests/orm/nodes/data/test_dict.py index 36ee77da4f..7a27b91fe6 100644 --- a/tests/orm/data/test_dict.py +++ b/tests/orm/nodes/data/test_dict.py @@ -22,21 +22,21 @@ def dictionary(): @pytest.mark.usefixtures('clear_database_before_test') def test_keys(dictionary): """Test the ``keys`` method.""" - node = Dict(dict=dictionary) + node = Dict(dictionary) assert sorted(node.keys()) == sorted(dictionary.keys()) @pytest.mark.usefixtures('clear_database_before_test') def test_get_dict(dictionary): """Test the ``get_dict`` method.""" - node = Dict(dict=dictionary) + node = Dict(dictionary) assert node.get_dict() == dictionary @pytest.mark.usefixtures('clear_database_before_test') def test_dict_property(dictionary): """Test the ``dict`` property.""" - node = Dict(dict=dictionary) + node = Dict(dictionary) assert node.dict.value == dictionary['value'] assert node.dict.nested == dictionary['nested'] @@ -44,7 +44,7 @@ def test_dict_property(dictionary): @pytest.mark.usefixtures('clear_database_before_test') def test_get_item(dictionary): """Test the ``__getitem__`` method.""" - node = Dict(dict=dictionary) + node = Dict(dictionary) assert node['value'] == dictionary['value'] assert node['nested'] == dictionary['nested'] @@ -56,7 +56,7 @@ def test_set_item(dictionary): * ``__setitem__`` directly on the node * ``__setattr__`` through the ``AttributeManager`` returned by the ``dict`` property """ - node = Dict(dict=dictionary) + node = Dict(dictionary) node['value'] = 2 assert node['value'] == 2 @@ -72,7 +72,7 @@ def test_correct_raises(dictionary): * ``node['inexistent']`` should raise ``KeyError`` * ``node.dict.inexistent`` should raise ``AttributeError`` """ - node = Dict(dict=dictionary) + node = Dict(dictionary) with pytest.raises(KeyError): _ = node['inexistent_key'] @@ -89,8 +89,8 @@ def test_eq(dictionary): compare equal to another node that has the same content. This is a hot issue and is being discussed in the following ticket: https://github.com/aiidateam/aiida-core/issues/1917 """ - node = Dict(dict=dictionary) - clone = Dict(dict=dictionary) + node = Dict(dictionary) + clone = Dict(dictionary) assert node is node # pylint: disable=comparison-with-itself assert node == dictionary @@ -101,8 +101,15 @@ def test_eq(dictionary): # wouldn't happen unless, by accident, two different nodes get the same UUID, the probability of which is minimal. # Note that we have to set the UUID directly through the database model instance of the backend entity, since it is # forbidden to change it through the front-end or backend entity instance, for good reasons. - other = Dict(dict={}) + other = Dict({}) other.backend_entity._dbmodel.uuid = node.uuid # pylint: disable=protected-access assert other.uuid == node.uuid assert other.dict != node.dict assert node == other + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_initialise_with_dict_kwarg(dictionary): + """Test that the ``Dict`` node can be initialized with the ``dict`` keyword argument for backwards compatibility.""" + node = Dict(dict=dictionary) + assert sorted(node.keys()) == sorted(dictionary.keys()) diff --git a/tests/orm/data/test_folder.py b/tests/orm/nodes/data/test_folder.py similarity index 100% rename from tests/orm/data/test_folder.py rename to tests/orm/nodes/data/test_folder.py diff --git a/tests/orm/data/test_jsonable.py b/tests/orm/nodes/data/test_jsonable.py similarity index 100% rename from tests/orm/data/test_jsonable.py rename to tests/orm/nodes/data/test_jsonable.py diff --git a/tests/orm/data/test_kpoints.py b/tests/orm/nodes/data/test_kpoints.py similarity index 100% rename from tests/orm/data/test_kpoints.py rename to tests/orm/nodes/data/test_kpoints.py diff --git a/tests/orm/nodes/data/test_list.py b/tests/orm/nodes/data/test_list.py new file mode 100644 index 0000000000..dd7f2309ce --- /dev/null +++ b/tests/orm/nodes/data/test_list.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=redefined-outer-name +"""Tests for :class:`aiida.orm.nodes.data.list.List` class.""" +import pytest + +from aiida.common.exceptions import ModificationNotAllowed +from aiida.orm import List, load_node + + +@pytest.fixture +def listing(): + return ['a', 2, True] + + +@pytest.fixture +def int_listing(): + return [2, 1, 3] + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_creation(): + """Test the creation of an empty ``List`` node.""" + node = List() + assert len(node) == 0 + with pytest.raises(IndexError): + node[0] # pylint: disable=pointless-statement + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_mutability(): + """Test list's mutability before and after storage.""" + node = List() + node.append(5) + node.store() + + # Test all mutable calls are now disallowed + with pytest.raises(ModificationNotAllowed): + node.append(5) + with pytest.raises(ModificationNotAllowed): + node.extend([5]) + with pytest.raises(ModificationNotAllowed): + node.insert(0, 2) + with pytest.raises(ModificationNotAllowed): + node.remove(5) + with pytest.raises(ModificationNotAllowed): + node.pop() + with pytest.raises(ModificationNotAllowed): + node.sort() + with pytest.raises(ModificationNotAllowed): + node.reverse() + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_store_load(listing): + """Test load_node on just stored object.""" + node = List(listing) + node.store() + + node_loaded = load_node(node.pk) + assert node.get_list() == node_loaded.get_list() + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_special_methods(listing): + """Test the special methods of the ``List`` class.""" + node = List(list=listing) + + # __getitem__ + for i, value in enumerate(listing): + assert node[i] == value + + # __setitem__ + node[0] = 'b' + assert node[0] == 'b' + + # __delitem__ + del node[0] + assert node.get_list() == listing[1:] + + # __len__ + assert len(node) == 2 + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_equality(listing): + """Test that two ``List`` nodes with equal content compare equal.""" + node1 = List(list=listing) + node2 = List(list=listing) + + assert node1 == node2 + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_append(listing): + """Test the ``List.append()`` method.""" + + def do_checks(node): + assert len(node) == 1 + assert node[0] == 4 + + node = List() + node.append(4) + do_checks(node) + + # Try the same after storing + node.store() + do_checks(node) + + node = List(list=listing) + node.append('more') + assert node[-1] == 'more' + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_extend(listing): + """Test extend() member function.""" + + def do_checks(node, lst): + assert len(node) == len(lst) + # Do an element wise comparison + for lst_el, node_el in zip(lst, node): + assert lst_el == node_el + + node = List() + node.extend(listing) + do_checks(node, listing) + + # Further extend + node.extend(listing) + do_checks(node, listing * 2) + + # Now try after storing + node.store() + do_checks(node, listing * 2) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_insert(listing): + """Test the ``List.insert()`` method.""" + node = List(list=listing) + node.insert(1, 'new') + assert node[1] == 'new' + assert len(node) == 4 + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_remove(listing): + """Test the ``List.remove()`` method.""" + node = List(list=listing) + node.remove(1) + listing.remove(1) + assert node.get_list() == listing + + with pytest.raises(ValueError, match=r'list.remove\(x\): x not in list'): + node.remove('non-existent') + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_pop(listing): + """Test the ``List.pop()`` method.""" + node = List(list=listing) + node.pop() + assert node.get_list() == listing[:-1] + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_index(listing): + """Test the ``List.index()`` method.""" + node = List(list=listing) + + assert node.index(True) == listing.index(True) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_count(listing): + """Test the ``List.count()`` method.""" + node = List(list=listing) + for value in listing: + assert node.count(value) == listing.count(value) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_sort(listing, int_listing): + """Test the ``List.sort()`` method.""" + node = List(list=int_listing) + node.sort() + int_listing.sort() + assert node.get_list() == int_listing + + node = List(list=listing) + with pytest.raises(TypeError, match=r"'<' not supported between instances of 'int' and 'str'"): + node.sort() + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_reverse(listing): + """Test the ``List.reverse()`` method.""" + node = List(list=listing) + node.reverse() + listing.reverse() + assert node.get_list() == listing + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_initialise_with_list_kwarg(listing): + """Test that the ``List`` node can be initialized with the ``list`` keyword argument for backwards compatibility.""" + node = List(list=listing) + assert node.get_list() == listing diff --git a/tests/orm/data/test_orbital.py b/tests/orm/nodes/data/test_orbital.py similarity index 100% rename from tests/orm/data/test_orbital.py rename to tests/orm/nodes/data/test_orbital.py diff --git a/tests/orm/data/test_remote.py b/tests/orm/nodes/data/test_remote.py similarity index 100% rename from tests/orm/data/test_remote.py rename to tests/orm/nodes/data/test_remote.py diff --git a/tests/orm/data/test_remote_stash.py b/tests/orm/nodes/data/test_remote_stash.py similarity index 100% rename from tests/orm/data/test_remote_stash.py rename to tests/orm/nodes/data/test_remote_stash.py diff --git a/tests/orm/data/test_singlefile.py b/tests/orm/nodes/data/test_singlefile.py similarity index 100% rename from tests/orm/data/test_singlefile.py rename to tests/orm/nodes/data/test_singlefile.py diff --git a/tests/orm/data/test_structure.py b/tests/orm/nodes/data/test_structure.py similarity index 100% rename from tests/orm/data/test_structure.py rename to tests/orm/nodes/data/test_structure.py diff --git a/tests/orm/data/test_to_aiida_type.py b/tests/orm/nodes/data/test_to_aiida_type.py similarity index 100% rename from tests/orm/data/test_to_aiida_type.py rename to tests/orm/nodes/data/test_to_aiida_type.py diff --git a/tests/orm/data/test_trajectory.py b/tests/orm/nodes/data/test_trajectory.py similarity index 100% rename from tests/orm/data/test_trajectory.py rename to tests/orm/nodes/data/test_trajectory.py diff --git a/tests/orm/data/test_upf.py b/tests/orm/nodes/data/test_upf.py similarity index 100% rename from tests/orm/data/test_upf.py rename to tests/orm/nodes/data/test_upf.py diff --git a/tests/orm/node/test_calcjob.py b/tests/orm/nodes/test_calcjob.py similarity index 100% rename from tests/orm/node/test_calcjob.py rename to tests/orm/nodes/test_calcjob.py diff --git a/tests/orm/node/test_node.py b/tests/orm/nodes/test_node.py similarity index 100% rename from tests/orm/node/test_node.py rename to tests/orm/nodes/test_node.py diff --git a/tests/orm/node/test_repository.py b/tests/orm/nodes/test_repository.py similarity index 100% rename from tests/orm/node/test_repository.py rename to tests/orm/nodes/test_repository.py diff --git a/tests/test_base_dataclasses.py b/tests/test_base_dataclasses.py deleted file mode 100644 index 6f30b6c3cc..0000000000 --- a/tests/test_base_dataclasses.py +++ /dev/null @@ -1,290 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Tests for AiiDA base data classes.""" -import operator - -from aiida.backends.testbase import AiidaTestCase -from aiida.common.exceptions import ModificationNotAllowed -from aiida.orm import Bool, Float, Int, List, NumericType, Str, load_node -from aiida.orm.nodes.data.bool import get_false_node, get_true_node - - -class TestList(AiidaTestCase): - """Test AiiDA List class.""" - - def test_creation(self): - node = List() - self.assertEqual(len(node), 0) - with self.assertRaises(IndexError): - node[0] # pylint: disable=pointless-statement - - def test_append(self): - """Test append() member function.""" - - def do_checks(node): - self.assertEqual(len(node), 1) - self.assertEqual(node[0], 4) - - node = List() - node.append(4) - do_checks(node) - - # Try the same after storing - node = List() - node.append(4) - node.store() - do_checks(node) - - def test_extend(self): - """Test extend() member function.""" - lst = [1, 2, 3] - - def do_checks(node): - self.assertEqual(len(node), len(lst)) - # Do an element wise comparison - for lst_, node_ in zip(lst, node): - self.assertEqual(lst_, node_) - - node = List() - node.extend(lst) - do_checks(node) - # Further extend - node.extend(lst) - self.assertEqual(len(node), len(lst) * 2) - - # Do an element wise comparison - for i, _ in enumerate(lst): - self.assertEqual(lst[i], node[i]) - self.assertEqual(lst[i], node[i % len(lst)]) - - # Now try after storing - node = List() - node.extend(lst) - node.store() - do_checks(node) - - def test_mutability(self): - """Test list's mutability before and after storage.""" - node = List() - node.append(5) - node.store() - - # Test all mutable calls are now disallowed - with self.assertRaises(ModificationNotAllowed): - node.append(5) - with self.assertRaises(ModificationNotAllowed): - node.extend([5]) - with self.assertRaises(ModificationNotAllowed): - node.insert(0, 2) - with self.assertRaises(ModificationNotAllowed): - node.remove(0) - with self.assertRaises(ModificationNotAllowed): - node.pop() - with self.assertRaises(ModificationNotAllowed): - node.sort() - with self.assertRaises(ModificationNotAllowed): - node.reverse() - - @staticmethod - def test_store_load(): - """Test load_node on just stored object.""" - node = List(list=[1, 2, 3]) - node.store() - - node_loaded = load_node(node.pk) - assert node.get_list() == node_loaded.get_list() - - -class TestFloat(AiidaTestCase): - """Test Float class.""" - - def setUp(self): - super().setUp() - self.value = Float() - self.all_types = [Int, Float, Bool, Str] - - def test_create(self): - """Creating basic data objects.""" - term_a = Float() - # Check that initial value is zero - self.assertAlmostEqual(term_a.value, 0.0) - - float_ = Float(6.0) - self.assertAlmostEqual(float_.value, 6.) - self.assertAlmostEqual(float_, Float(6.0)) - - int_ = Int() - self.assertAlmostEqual(int_.value, 0) - int_ = Int(6) - self.assertAlmostEqual(int_.value, 6) - self.assertAlmostEqual(float_, int_) - - bool_ = Bool() - self.assertAlmostEqual(bool_.value, False) - bool_ = Bool(False) - self.assertAlmostEqual(bool_.value, False) - self.assertAlmostEqual(bool_.value, get_false_node()) - bool_ = Bool(True) - self.assertAlmostEqual(bool_.value, True) - self.assertAlmostEqual(bool_.value, get_true_node()) - - str_ = Str() - self.assertAlmostEqual(str_.value, '') - str_ = Str('Hello') - self.assertAlmostEqual(str_.value, 'Hello') - - def test_load(self): - """Test object loading.""" - for typ in self.all_types: - node = typ() - node.store() - loaded = load_node(node.pk) - self.assertAlmostEqual(node, loaded) - - def test_add(self): - """Test addition.""" - term_a = Float(4) - term_b = Float(5) - # Check adding two db Floats - res = term_a + term_b - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 9.0) - - # Check adding db Float and native (both ways) - res = term_a + 5.0 - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 9.0) - - res = 5.0 + term_a - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 9.0) - - # Inplace - term_a = Float(4) - term_a += term_b - self.assertAlmostEqual(term_a, 9.0) - - term_a = Float(4) - term_a += 5 - self.assertAlmostEqual(term_a, 9.0) - - def test_mul(self): - """Test floats multiplication.""" - term_a = Float(4) - term_b = Float(5) - # Check adding two db Floats - res = term_a * term_b - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 20.0) - - # Check adding db Float and native (both ways) - res = term_a * 5.0 - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 20) - - res = 5.0 * term_a - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 20.0) - - # Inplace - term_a = Float(4) - term_a *= term_b - self.assertAlmostEqual(term_a, 20) - - term_a = Float(4) - term_a *= 5 - self.assertAlmostEqual(term_a, 20) - - def test_power(self): - """Test power operator.""" - term_a = Float(4) - term_b = Float(2) - - res = term_a**term_b - self.assertAlmostEqual(res.value, 16.) - - def test_division(self): - """Test the normal division operator.""" - term_a = Float(3) - term_b = Float(2) - - self.assertAlmostEqual(term_a / term_b, 1.5) - self.assertIsInstance(term_a / term_b, Float) - - def test_division_integer(self): - """Test the integer division operator.""" - term_a = Float(3) - term_b = Float(2) - - self.assertAlmostEqual(term_a // term_b, 1.0) - self.assertIsInstance(term_a // term_b, Float) - - def test_modulus(self): - """Test modulus operator.""" - term_a = Float(12.0) - term_b = Float(10.0) - - self.assertAlmostEqual(term_a % term_b, 2.0) - self.assertIsInstance(term_a % term_b, NumericType) - self.assertAlmostEqual(term_a % 10.0, 2.0) - self.assertIsInstance(term_a % 10.0, NumericType) - self.assertAlmostEqual(12.0 % term_b, 2.0) - self.assertIsInstance(12.0 % term_b, NumericType) - - -class TestFloatIntMix(AiidaTestCase): - """Test operations between Int and Float objects.""" - - def test_operator(self): - """Test all binary operators.""" - term_a = Float(2.2) - term_b = Int(3) - - for oper in [ - operator.add, operator.mul, operator.pow, operator.lt, operator.le, operator.gt, operator.ge, operator.iadd, - operator.imul - ]: - for term_x, term_y in [(term_a, term_b), (term_b, term_a)]: - res = oper(term_x, term_y) - c_val = oper(term_x.value, term_y.value) - self.assertEqual(res._type, type(c_val)) # pylint: disable=protected-access - self.assertEqual(res, oper(term_x.value, term_y.value)) - - -class TestInt(AiidaTestCase): - """Test Int class.""" - - def test_division(self): - """Test the normal division operator.""" - term_a = Int(3) - term_b = Int(2) - - self.assertAlmostEqual(term_a / term_b, 1.5) - self.assertIsInstance(term_a / term_b, Float) - - def test_division_integer(self): - """Test the integer division operator.""" - term_a = Int(3) - term_b = Int(2) - - self.assertAlmostEqual(term_a // term_b, 1) - self.assertIsInstance(term_a // term_b, Int) - - def test_modulo(self): - """Test modulus operation.""" - term_a = Int(12) - term_b = Int(10) - - self.assertEqual(term_a % term_b, 2) - self.assertIsInstance(term_a % term_b, NumericType) - self.assertEqual(term_a % 10, 2) - self.assertIsInstance(term_a % 10, NumericType) - self.assertEqual(12 % term_b, 2) - self.assertIsInstance(12 % term_b, NumericType)