Skip to content

Commit

Permalink
Add types to aiida.orm.utils
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhollas committed Jul 1, 2024
1 parent 24cfbe2 commit 8ce4773
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 24 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ repos:
src/aiida/orm/utils/builders/code.py|
src/aiida/orm/utils/builders/computer.py|
src/aiida/orm/utils/calcjob.py|
src/aiida/orm/utils/node.py|
src/aiida/orm/utils/remote.py|
src/aiida/repository/backend/disk_object_store.py|
src/aiida/repository/backend/sandbox.py|
src/aiida/restapi/common/utils.py|
Expand Down
3 changes: 2 additions & 1 deletion src/aiida/orm/nodes/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
###########################################################################
"""`Data` sub class to be used as a base for data containers that represent base python data types."""

import typing as t
from functools import singledispatch

from aiida.orm.fields import add_field
Expand All @@ -18,7 +19,7 @@


@singledispatch
def to_aiida_type(value):
def to_aiida_type(value: t.Any) -> Data:
"""Turns basic Python types (str, int, float, bool) into the corresponding AiiDA types."""
raise TypeError(f'Cannot convert value of type {type(value)} to AiiDA type.')

Expand Down
2 changes: 1 addition & 1 deletion src/aiida/orm/nodes/data/str.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ class Str(BaseType):


@to_aiida_type.register(str)
def _(value):
def _(value: str) -> Str:
return Str(value)
14 changes: 7 additions & 7 deletions src/aiida/orm/utils/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from aiida.common import exceptions
from aiida.common.utils import strip_prefix
from aiida.orm import Data, Node
from aiida.orm.fields import EntityFieldMeta

__all__ = (
Expand All @@ -23,13 +24,12 @@
)


def load_node_class(type_string):
def load_node_class(type_string: str) -> type[Node]:
"""Return the `Node` sub class that corresponds to the given type string.
:param type_string: the `type` string of the node
:return: a sub class of `Node`
"""
from aiida.orm import Data, Node
from aiida.plugins.entry_point import load_entry_point

if type_string == '':
Expand Down Expand Up @@ -74,7 +74,7 @@ def load_node_class(type_string):
return Data


def get_type_string_from_class(class_module, class_name):
def get_type_string_from_class(class_module: str, class_name: str) -> str:
"""Given the module and name of a class, determine the orm_class_type string, which codifies the
orm class that is to be used. The returned string will always have a terminating period, which
is required to query for the string in the database
Expand Down Expand Up @@ -110,11 +110,11 @@ def get_type_string_from_class(class_module, class_name):
return type_string


def is_valid_node_type_string(type_string, raise_on_false=False):
def is_valid_node_type_string(type_string: str, raise_on_false: bool = False) -> bool:
"""Checks whether type string of a Node is valid.
:param type_string: the plugin_type_string attribute of a Node
:return: True if type string is valid, else false
:return: True if type string is valid, else False
"""
# Currently the type string for the top-level node is empty.
# Change this when a consistent type string hierarchy is introduced.
Expand All @@ -131,7 +131,7 @@ def is_valid_node_type_string(type_string, raise_on_false=False):
return True


def get_query_type_from_type_string(type_string):
def get_query_type_from_type_string(type_string: str) -> str:
"""Take the type string of a Node and create the queryable type string
:param type_string: the plugin_type_string attribute of a Node
Expand All @@ -155,5 +155,5 @@ class AbstractNodeMeta(EntityFieldMeta):

def __new__(mcs, name, bases, namespace, **kwargs): # noqa: N804
newcls = super().__new__(mcs, name, bases, namespace, **kwargs)
newcls._logger = logging.getLogger(f"{namespace['__module__']}.{name}")
newcls._logger = logging.getLogger(f"{namespace['__module__']}.{name}") # type: ignore[attr-defined]
return newcls
34 changes: 21 additions & 13 deletions src/aiida/orm/utils/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@
"""Utilities for operations on files on remote computers."""

import os
import typing as t

from aiida.orm.nodes.data.remote.base import RemoteData

if t.TYPE_CHECKING:
from collections.abc import Sequence

def clean_remote(transport, path):
from aiida import orm
from aiida.orm.implementation import StorageBackend
from aiida.transports import Transport


def clean_remote(transport: Transport, path: str) -> None:
"""Recursively remove a remote folder, with the given absolute path, and all its contents. The path should be
made accessible through the transport channel, which should already be open
Expand All @@ -39,15 +47,15 @@ def clean_remote(transport, path):


def get_calcjob_remote_paths(
pks=None,
past_days=None,
older_than=None,
computers=None,
user=None,
backend=None,
exit_status=None,
only_not_cleaned=False,
):
pks: list[int] | None = None,
past_days: int | None = None,
older_than: int | None = None,
computers: Sequence[orm.Computer] | None = None,
user: orm.User | None = None,
backend: StorageBackend | None = None,
exit_status: int | None = None,
only_not_cleaned: bool = False,
) -> dict[str, list[RemoteData]] | None:
"""Return a mapping of computer uuids to a list of remote paths, for a given set of calcjobs. The set of
calcjobs will be determined by a query with filters based on the pks, past_days, older_than,
computers and user arguments.
Expand All @@ -67,7 +75,7 @@ def get_calcjob_remote_paths(
from aiida.common import timezone
from aiida.orm import CalcJobNode

filters_calc = {}
filters_calc: dict[str, t.Any] = {}
filters_computer = {}
filters_remote = {}

Expand Down Expand Up @@ -110,12 +118,12 @@ def get_calcjob_remote_paths(
RemoteData, tag='remote', project=['*'], edge_filters={'label': 'remote_folder'}, filters=filters_remote
)
query.append(orm.Computer, with_node='calc', tag='computer', project=['uuid'], filters=filters_computer)
query.append(orm.User, with_node='calc', filters={'email': user.email})
query.append(orm.User, with_node='calc', filters={'email': user.email}) # type: ignore[union-attr]

if query.count() == 0:
return None

path_mapping = {}
path_mapping: dict[str, list[RemoteData]] = {}

for remote_data, computer_uuid in query.iterall():
path_mapping.setdefault(computer_uuid, []).append(remote_data)
Expand Down

0 comments on commit 8ce4773

Please sign in to comment.