Skip to content

Commit

Permalink
πŸ‘Œ IMPROVE: constructor of base data types (#5165)
Browse files Browse the repository at this point in the history
* πŸ‘Œ IMPROVE: constructor of base data types

Adapt the constructor of the `Dict` and `List` data types so the
value no longer needs to be provided as a keyword argument, but can
simply be passed as the first input.

Remove the `*args` from the `BaseType` constructor, instead specifying
the `value` input argument. Otherwise users could just pass multiple
positional arguments and it would only use the first without raising an
error.

Also fixes two issues with the `List` class:

* The `remove` method was not working as prescribed. Instead of
removing the first element in the list with the specified value, it
simply deleted the element with index `value`.
* The `set_list()` method didn't make a copy of the `list` input, so any
modification done to the `List` instance would also affect the original
`list`. If the same list is used to initialise several `List` instances,
adapting one would affect the other.

Finally, add tests for the `List` class.

* Refactor data tests
  • Loading branch information
mbercx authored Dec 6, 2021
1 parent 08ac107 commit eb73c7d
Show file tree
Hide file tree
Showing 27 changed files with 459 additions and 316 deletions.
9 changes: 2 additions & 7 deletions aiida/orm/nodes/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,15 @@ def to_aiida_type(value):
class BaseType(Data):
"""`Data` sub class to be used as a base for data containers that represent base python data types."""

def __init__(self, *args, **kwargs):
def __init__(self, value=None, **kwargs):
try:
getattr(self, '_type')
except AttributeError:
raise RuntimeError('Derived class must define the `_type` class member')

super().__init__(**kwargs)

try:
value = args[0]
except IndexError:
value = self._type() # pylint: disable=no-member

self.value = value
self.value = value or self._type() # pylint: disable=no-member

@property
def value(self):
Expand Down
12 changes: 6 additions & 6 deletions aiida/orm/nodes/data/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ class Dict(Data):
Finally, all dictionary mutations will be forbidden once the node is stored.
"""

def __init__(self, **kwargs):
"""Store a dictionary as a `Node` instance.
def __init__(self, value=None, **kwargs):
"""Initialise a ``Dict`` node instance.
Usual rules for attribute names apply, in particular, keys cannot start with an underscore, or a `ValueError`
Usual rules for attribute names apply, in particular, keys cannot start with an underscore, or a ``ValueError``
will be raised.
Initial attributes can be changed, deleted or added as long as the node is not stored.
:param dict: the dictionary to set
:param value: dictionary to initialise the ``Dict`` node from
"""
dictionary = kwargs.pop('dict', None)
dictionary = value or kwargs.pop('dict', None)
super().__init__(**kwargs)
if dictionary:
self.set_dict(dictionary)
Expand Down Expand Up @@ -135,4 +135,4 @@ def dict(self):

@to_aiida_type.register(dict)
def _(value):
return Dict(dict=value)
return Dict(value)
16 changes: 12 additions & 4 deletions aiida/orm/nodes/data/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ class List(Data, MutableSequence):

_LIST_KEY = 'list'

def __init__(self, **kwargs):
data = kwargs.pop('list', [])
def __init__(self, value=None, **kwargs):
"""Initialise a ``List`` node instance.
:param value: list to initialise the ``List`` node from
"""
data = value or kwargs.pop('list', [])
super().__init__(**kwargs)
self.set_list(data)

Expand Down Expand Up @@ -75,7 +79,11 @@ def insert(self, i, value): # pylint: disable=arguments-renamed
self.set_list(data)

def remove(self, value):
del self[value]
data = self.get_list()
item = data.remove(value)
if not self._using_list_reference():
self.set_list(data)
return item

def pop(self, **kwargs): # pylint: disable=arguments-differ
"""Remove and return item at index (default last)."""
Expand Down Expand Up @@ -123,7 +131,7 @@ def set_list(self, data):
"""
if not isinstance(data, list):
raise TypeError('Must supply list type')
self.set_attribute(self._LIST_KEY, data)
self.set_attribute(self._LIST_KEY, data.copy())

def _using_list_reference(self):
"""
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
207 changes: 207 additions & 0 deletions tests/orm/nodes/data/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=invalid-name
"""Tests for :class:`aiida.orm.nodes.data.base.BaseType` classes."""

import operator

import pytest

from aiida.orm import Bool, Float, Int, NumericType, Str, load_node


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize(
'node_type, default, value', [
(Bool, False, True),
(Int, 0, 5),
(Float, 0.0, 5.5),
(Str, '', 'a'),
]
)
def test_create(node_type, default, value):
"""Test the creation of the ``BaseType`` nodes."""

node = node_type()
assert node.value == default

node = node_type(value)
assert node.value == value


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type', [Bool, Float, Int, Str])
def test_store_load(node_type):
"""Test ``BaseType`` node storing and loading."""
node = node_type()
node.store()
loaded = load_node(node.pk)
assert node.value == loaded.value


@pytest.mark.usefixtures('clear_database_before_test')
def test_modulo():
"""Test ``Int`` modulus operation."""
term_a = Int(12)
term_b = Int(10)

assert term_a % term_b == 2
assert isinstance(term_a % term_b, NumericType)
assert term_a % 10 == 2
assert isinstance(term_a % 10, NumericType)
assert 12 % term_b == 2
assert isinstance(12 % term_b, NumericType)


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, a, b', [
(Int, 3, 5),
(Float, 1.2, 5.5),
])
def test_add(node_type, a, b):
"""Test addition for ``Int`` and ``Float`` nodes."""
node_a = node_type(a)
node_b = node_type(b)

result = node_a + node_b
assert isinstance(result, node_type)
assert result.value == a + b

# Node and native (both ways)
result = node_a + b
assert isinstance(result, node_type)
assert result.value == a + b

result = a + node_b
assert isinstance(result, node_type)
assert result.value == a + b

# Inplace
result = node_type(a)
result += node_b
assert isinstance(result, node_type)
assert result.value == a + b


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, a, b', [
(Int, 3, 5),
(Float, 1.2, 5.5),
])
def test_multiplication(node_type, a, b):
"""Test floats multiplication."""
node_a = node_type(a)
node_b = node_type(b)

# Check multiplication
result = node_a * node_b
assert isinstance(result, node_type)
assert result.value == a * b

# Check multiplication Node and native (both ways)
result = node_a * b
assert isinstance(result, node_type)
assert result.value == a * b

result = a * node_b
assert isinstance(result, node_type)
assert result.value == a * b

# Inplace
result = node_type(a)
result *= node_b
assert isinstance(result, node_type)
assert result.value == a * b


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, a, b', [
(Int, 3, 5),
(Float, 1.2, 5.5),
])
@pytest.mark.usefixtures('clear_database_before_test')
def test_division(node_type, a, b):
"""Test the ``BaseType`` normal division operator."""
node_a = node_type(a)
node_b = node_type(b)

result = node_a / node_b
assert result == a / b
assert isinstance(result, Float) # Should be a `Float` for both node types


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, a, b', [
(Int, 3, 5),
(Float, 1.2, 5.5),
])
@pytest.mark.usefixtures('clear_database_before_test')
def test_division_integer(node_type, a, b):
"""Test the ``Int`` integer division operator."""
node_a = node_type(a)
node_b = node_type(b)

result = node_a // node_b
assert result == a // b
assert isinstance(result, node_type)


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, base, power', [
(Int, 5, 2),
(Float, 3.5, 3),
])
def test_power(node_type, base, power):
"""Test power operator."""
node_base = node_type(base)
node_power = node_type(power)

result = node_base**node_power
assert result == base**power
assert isinstance(result, node_type)


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, a, b', [
(Int, 5, 2),
(Float, 3.5, 3),
])
def test_modulus(node_type, a, b):
"""Test modulus operator."""
node_a = node_type(a)
node_b = node_type(b)

assert node_a % node_b == a % b
assert isinstance(node_a % node_b, node_type)

assert node_a % b == a % b
assert isinstance(node_a % b, node_type)

assert a % node_b == a % b
assert isinstance(a % node_b, node_type)


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize(
'opera', [
operator.add, operator.mul, operator.pow, operator.lt, operator.le, operator.gt, operator.ge, operator.iadd,
operator.imul
]
)
def test_operator(opera):
"""Test operations between Int and Float objects."""
node_a = Float(2.2)
node_b = Int(3)

for node_x, node_y in [(node_a, node_b), (node_b, node_a)]:
res = opera(node_x, node_y)
c_val = opera(node_x.value, node_y.value)
assert res._type == type(c_val) # pylint: disable=protected-access
assert res == opera(node_x.value, node_y.value)
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,29 @@ def dictionary():
@pytest.mark.usefixtures('clear_database_before_test')
def test_keys(dictionary):
"""Test the ``keys`` method."""
node = Dict(dict=dictionary)
node = Dict(dictionary)
assert sorted(node.keys()) == sorted(dictionary.keys())


@pytest.mark.usefixtures('clear_database_before_test')
def test_get_dict(dictionary):
"""Test the ``get_dict`` method."""
node = Dict(dict=dictionary)
node = Dict(dictionary)
assert node.get_dict() == dictionary


@pytest.mark.usefixtures('clear_database_before_test')
def test_dict_property(dictionary):
"""Test the ``dict`` property."""
node = Dict(dict=dictionary)
node = Dict(dictionary)
assert node.dict.value == dictionary['value']
assert node.dict.nested == dictionary['nested']


@pytest.mark.usefixtures('clear_database_before_test')
def test_get_item(dictionary):
"""Test the ``__getitem__`` method."""
node = Dict(dict=dictionary)
node = Dict(dictionary)
assert node['value'] == dictionary['value']
assert node['nested'] == dictionary['nested']

Expand All @@ -56,7 +56,7 @@ def test_set_item(dictionary):
* ``__setitem__`` directly on the node
* ``__setattr__`` through the ``AttributeManager`` returned by the ``dict`` property
"""
node = Dict(dict=dictionary)
node = Dict(dictionary)

node['value'] = 2
assert node['value'] == 2
Expand All @@ -72,7 +72,7 @@ def test_correct_raises(dictionary):
* ``node['inexistent']`` should raise ``KeyError``
* ``node.dict.inexistent`` should raise ``AttributeError``
"""
node = Dict(dict=dictionary)
node = Dict(dictionary)

with pytest.raises(KeyError):
_ = node['inexistent_key']
Expand All @@ -89,8 +89,8 @@ def test_eq(dictionary):
compare equal to another node that has the same content. This is a hot issue and is being discussed in the following
ticket: https://github.com/aiidateam/aiida-core/issues/1917
"""
node = Dict(dict=dictionary)
clone = Dict(dict=dictionary)
node = Dict(dictionary)
clone = Dict(dictionary)

assert node is node # pylint: disable=comparison-with-itself
assert node == dictionary
Expand All @@ -101,8 +101,15 @@ def test_eq(dictionary):
# wouldn't happen unless, by accident, two different nodes get the same UUID, the probability of which is minimal.
# Note that we have to set the UUID directly through the database model instance of the backend entity, since it is
# forbidden to change it through the front-end or backend entity instance, for good reasons.
other = Dict(dict={})
other = Dict({})
other.backend_entity._dbmodel.uuid = node.uuid # pylint: disable=protected-access
assert other.uuid == node.uuid
assert other.dict != node.dict
assert node == other


@pytest.mark.usefixtures('clear_database_before_test')
def test_initialise_with_dict_kwarg(dictionary):
"""Test that the ``Dict`` node can be initialized with the ``dict`` keyword argument for backwards compatibility."""
node = Dict(dict=dictionary)
assert sorted(node.keys()) == sorted(dictionary.keys())
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit eb73c7d

Please sign in to comment.