diff --git a/dptb/data/AtomicData.py b/dptb/data/AtomicData.py new file mode 100644 index 00000000..0163a457 --- /dev/null +++ b/dptb/data/AtomicData.py @@ -0,0 +1,993 @@ +"""AtomicData: neighbor graphs in (periodic) real space. + +Authors: Albert Musaelian +""" + +import warnings +from copy import deepcopy +from typing import Union, Tuple, Dict, Optional, List, Set, Sequence +from collections.abc import Mapping +import os + +import numpy as np +import ase.neighborlist +import ase +from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator +from ase.calculators.calculator import all_properties as ase_all_properties +from ase.stress import voigt_6_to_full_3x3_stress, full_3x3_to_voigt_6_stress + +import torch +import e3nn.o3 + +from . import AtomicDataDict +from .util import _TORCH_INTEGER_DTYPES +from dptb.utils.torch_geometric.data import Data + +# A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case) +PBC = Union[bool, Tuple[bool, bool, bool]] + + +_DEFAULT_LONG_FIELDS: Set[str] = { + AtomicDataDict.EDGE_INDEX_KEY, + AtomicDataDict.ENV_INDEX_KEY, # new + AtomicDataDict.ONSITENV_INDEX_KEY, # new + AtomicDataDict.ATOMIC_NUMBERS_KEY, + AtomicDataDict.ATOM_TYPE_KEY, + AtomicDataDict.BATCH_KEY, +} + +_DEFAULT_NODE_FIELDS: Set[str] = { + AtomicDataDict.POSITIONS_KEY, + AtomicDataDict.NODE_FEATURES_KEY, + AtomicDataDict.NODE_ATTRS_KEY, + AtomicDataDict.ATOMIC_NUMBERS_KEY, + AtomicDataDict.ATOM_TYPE_KEY, + AtomicDataDict.FORCE_KEY, + AtomicDataDict.PER_ATOM_ENERGY_KEY, + AtomicDataDict.NODE_HAMILTONIAN_KEY, + AtomicDataDict.NODE_OVERLAP_KEY, + AtomicDataDict.BATCH_KEY, +} + +_DEFAULT_EDGE_FIELDS: Set[str] = { + AtomicDataDict.EDGE_CELL_SHIFT_KEY, + AtomicDataDict.EDGE_VECTORS_KEY, + AtomicDataDict.EDGE_LENGTH_KEY, + AtomicDataDict.EDGE_ATTRS_KEY, + AtomicDataDict.EDGE_EMBEDDING_KEY, + AtomicDataDict.EDGE_FEATURES_KEY, + AtomicDataDict.EDGE_CUTOFF_KEY, + AtomicDataDict.EDGE_ENERGY_KEY, + AtomicDataDict.EDGE_OVERLAP_KEY, + AtomicDataDict.EDGE_HAMILTONIAN_KEY, + AtomicDataDict.EDGE_TYPE_KEY, +} + +_DEFAULT_ENV_FIELDS: Set[str] = { + AtomicDataDict.ENV_CELL_SHIFT_KEY, + AtomicDataDict.ENV_VECTORS_KEY, + AtomicDataDict.ENV_LENGTH_KEY, + AtomicDataDict.ENV_ATTRS_KEY, + AtomicDataDict.ENV_EMBEDDING_KEY, + AtomicDataDict.ENV_FEATURES_KEY, + AtomicDataDict.ENV_CUTOFF_KEY, +} + +_DEFAULT_ONSITENV_FIELDS: Set[str] = { + AtomicDataDict.ONSITENV_CELL_SHIFT_KEY, + AtomicDataDict.ONSITENV_VECTORS_KEY, + AtomicDataDict.ONSITENV_LENGTH_KEY, + AtomicDataDict.ONSITENV_ATTRS_KEY, + AtomicDataDict.ONSITENV_EMBEDDING_KEY, + AtomicDataDict.ONSITENV_FEATURES_KEY, + AtomicDataDict.ONSITENV_CUTOFF_KEY, +} + +_DEFAULT_GRAPH_FIELDS: Set[str] = { + AtomicDataDict.TOTAL_ENERGY_KEY, + AtomicDataDict.STRESS_KEY, + AtomicDataDict.VIRIAL_KEY, + AtomicDataDict.PBC_KEY, + AtomicDataDict.CELL_KEY, + AtomicDataDict.BATCH_PTR_KEY, + AtomicDataDict.KPOINT_KEY, # new + AtomicDataDict.HAMILTONIAN_KEY, # new + AtomicDataDict.OVERLAP_KEY, # new + AtomicDataDict.ENERGY_EIGENVALUE_KEY, # new + AtomicDataDict.ENERGY_WINDOWS_KEY, # new, + AtomicDataDict.BAND_WINDOW_KEY # new, +} + +_NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS) +_EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS) +_ENV_FIELDS: Set[str] = set(_DEFAULT_ENV_FIELDS) +_ONSITENV_FIELDS: Set[str] = set(_DEFAULT_ONSITENV_FIELDS) +_GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS) +_LONG_FIELDS: Set[str] = set(_DEFAULT_LONG_FIELDS) + + +def register_fields( + node_fields: Sequence[str] = [], + edge_fields: 
Sequence[str] = [], + env_fields: Sequence[str] = [], + onsitenv_fields: Sequence[str] = [], + graph_fields: Sequence[str] = [], + long_fields: Sequence[str] = [], +) -> None: + + r"""Register fields as being per-atom, per-edge, or per-frame. + + Args: + node_fields: fields registered as per-node (per-atom) quantities. + edge_fields: fields registered as per-edge quantities. + env_fields: fields registered as per-environment-edge quantities. + onsitenv_fields: fields registered as per-onsite-environment-edge quantities. + graph_fields: fields registered as per-frame (per-graph) quantities. + long_fields: fields that must be stored with a long (integer) dtype. + """ + + node_fields: set = set(node_fields) + edge_fields: set = set(edge_fields) + env_fields: set = set(env_fields) + onsitenv_fields: set = set(onsitenv_fields) + graph_fields: set = set(graph_fields) + long_fields: set = set(long_fields) + allfields = node_fields.union(edge_fields, graph_fields, env_fields, onsitenv_fields) + assert len(allfields) == len(node_fields) + len(edge_fields) + len(env_fields) + len(onsitenv_fields) + len(graph_fields) + _NODE_FIELDS.update(node_fields) + _EDGE_FIELDS.update(edge_fields) + _ENV_FIELDS.update(env_fields) + _ONSITENV_FIELDS.update(onsitenv_fields) + _GRAPH_FIELDS.update(graph_fields) + _LONG_FIELDS.update(long_fields) + if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _ENV_FIELDS, _ONSITENV_FIELDS, _GRAPH_FIELDS)) < ( + len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_ENV_FIELDS) + len(_ONSITENV_FIELDS) + len(_GRAPH_FIELDS) + ): + raise ValueError( + "At least one key was registered as more than one of node, edge, env, onsitenv, or graph!" + ) + + +def deregister_fields(*fields: Sequence[str]) -> None: + r"""Deregister a field registered with ``register_fields``. + + Silently ignores fields that were never registered to begin with. + + Args: + *fields: fields to deregister. + """ + for f in fields: + assert f not in _DEFAULT_NODE_FIELDS, "Cannot deregister built-in field" + assert f not in _DEFAULT_EDGE_FIELDS, "Cannot deregister built-in field" + assert f not in _DEFAULT_GRAPH_FIELDS, "Cannot deregister built-in field" + assert f not in _DEFAULT_ENV_FIELDS, "Cannot deregister built-in field" + assert f not in _DEFAULT_ONSITENV_FIELDS, "Cannot deregister built-in field" + _NODE_FIELDS.discard(f) + _EDGE_FIELDS.discard(f) + _ENV_FIELDS.discard(f) + _ONSITENV_FIELDS.discard(f) + _GRAPH_FIELDS.discard(f) + + +def _register_field_prefix(prefix: str) -> None: + """Re-register all registered fields as the same type, but with `prefix` added on.""" + assert prefix.endswith("_") + register_fields( + node_fields=[prefix + e for e in _NODE_FIELDS], + edge_fields=[prefix + e for e in _EDGE_FIELDS], + env_fields=[prefix + e for e in _ENV_FIELDS], + onsitenv_fields=[prefix + e for e in _ONSITENV_FIELDS], + graph_fields=[prefix + e for e in _GRAPH_FIELDS], + long_fields=[prefix + e for e in _LONG_FIELDS], + ) + + +def _process_dict(kwargs, ignore_fields=[]): + """Convert a dict of data into correct dtypes/shapes according to key""" + # Deal with _some_ dtype issues + for k, v in kwargs.items(): + if k in ignore_fields: + continue + + if k in _LONG_FIELDS: + # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems) + # int32 would pass later checks, but is actually disallowed by torch + kwargs[k] = torch.as_tensor(v, dtype=torch.long) + elif isinstance(v, bool): + kwargs[k] = torch.as_tensor(v) + elif isinstance(v, np.ndarray): + if np.issubdtype(v.dtype, np.floating): + kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) + else: + kwargs[k] = torch.as_tensor(v) + elif isinstance(v, list): + ele_dtype = np.array(v).dtype + if np.issubdtype(ele_dtype, np.floating): + kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) + else: + kwargs[k] = torch.as_tensor(v) + elif np.issubdtype(type(v), np.floating): + # Force scalars to 
be tensors with a data dimension + # This makes them play well with irreps + kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) + elif isinstance(v, torch.Tensor) and len(v.shape) == 0: + # ^ this tensor is a scalar; we need to give it + # a data dimension to play nice with irreps + kwargs[k] = v + + if AtomicDataDict.BATCH_KEY in kwargs: + num_frames = kwargs[AtomicDataDict.BATCH_KEY].max() + 1 + else: + num_frames = 1 + + for k, v in kwargs.items(): + if k in ignore_fields: + continue + + if len(v.shape) == 0: + kwargs[k] = v.unsqueeze(-1) + v = kwargs[k] + + if k in set.union(_NODE_FIELDS, _EDGE_FIELDS) and len(v.shape) == 1: + kwargs[k] = v.unsqueeze(-1) + v = kwargs[k] + + if ( + k in _NODE_FIELDS + and AtomicDataDict.POSITIONS_KEY in kwargs + and v.shape[0] != kwargs[AtomicDataDict.POSITIONS_KEY].shape[0] + ): + raise ValueError( + f"{k} is a node field but has the wrong dimension {v.shape}" + ) + elif ( + k in _EDGE_FIELDS + and AtomicDataDict.EDGE_INDEX_KEY in kwargs + and v.shape[0] != kwargs[AtomicDataDict.EDGE_INDEX_KEY].shape[1] + ): + raise ValueError( + f"{k} is an edge field but has the wrong dimension {v.shape}" + ) + elif ( + k in _ENV_FIELDS + and AtomicDataDict.ENV_INDEX_KEY in kwargs + and v.shape[0] != kwargs[AtomicDataDict.ENV_INDEX_KEY].shape[1] + ): + raise ValueError( + f"{k} is an env field but has the wrong dimension {v.shape}" + ) + elif ( + k in _ONSITENV_FIELDS + and AtomicDataDict.ONSITENV_INDEX_KEY in kwargs + and v.shape[0] != kwargs[AtomicDataDict.ONSITENV_INDEX_KEY].shape[1] + ): + raise ValueError( + f"{k} is an onsitenv field but has the wrong dimension {v.shape}" + ) + elif k in _GRAPH_FIELDS: + if num_frames > 1 and v.shape[0] != num_frames: + raise ValueError(f"Wrong shape for graph property {k}") + + +class AtomicData(Data): + """A neighbor graph for points in (periodic triclinic) real space. + + For typical cases either ``from_points`` or ``from_ase`` should be used to + construct an AtomicData; they also standardize and check their input much more + thoroughly. + + In general, ``node_features`` are features or input information on the nodes that will be fed through and transformed by the network, while ``node_attrs`` are fixed, inherent _encodings_ of the atoms themselves that remain constant through the network. + For example, a one-hot _encoding_ of atomic species is a node attribute, while some observed instantaneous property of that atom (current partial charge, for example) would be a feature. + + In general, ``torch.Tensor`` arguments should be of consistent dtype. Numpy arrays will be converted to ``torch.Tensor``s; those of floating point dtype will be converted to ``torch.get_default_dtype()`` regardless of their original precision. Scalar values (Python scalars or ``torch.Tensor``s of shape ``()``) are resized to tensors of shape ``[1]``. Per-atom scalar values should be given with shape ``[N_at, 1]``. + + ``AtomicData`` should be used for all data creation and manipulation outside of the model; inside of the model ``AtomicDataDict.Type`` is used. + + Args: + pos (Tensor [n_nodes, 3]): Positions of the nodes. + edge_index (LongTensor [2, n_edges]): ``edge_index[0]`` is the per-edge + index of the source node and ``edge_index[1]`` is the target node. + edge_cell_shift (Tensor [n_edges, 3], optional): which periodic image + of the target point each edge goes to, relative to the source point. + cell (Tensor [1, 3, 3], optional): the periodic cell for + ``edge_cell_shift`` as the three triclinic cell vectors. 
+ node_features (Tensor [n_atom, ...]): the input features of the nodes, optional + node_attrs (Tensor [n_atom, ...]): the attributes of the nodes, for instance the atom type, optional + batch (Tensor [n_atom]): the graph to which the node belongs, optional + atomic_numbers (Tensor [n_atom]): optional. + atom_type (Tensor [n_atom]): optional. + **kwargs: other data, optional. + """ + + def __init__( + self, irreps: Dict[str, e3nn.o3.Irreps] = {}, _validate: bool = True, **kwargs + ): + + # empty init needed by get_example + if len(kwargs) == 0 and len(irreps) == 0: + super().__init__() + return + + # Check the keys + if _validate: + AtomicDataDict.validate_keys(kwargs) + _process_dict(kwargs) + + super().__init__(num_nodes=len(kwargs["pos"]), **kwargs) + + if _validate: + # Validate shapes + assert self.pos.dim() == 2 and self.pos.shape[1] == 3 + assert self.edge_index.dim() == 2 and self.edge_index.shape[0] == 2 + if "edge_cell_shift" in self and self.edge_cell_shift is not None: + assert self.edge_cell_shift.shape == (self.num_edges, 3) + assert self.edge_cell_shift.dtype == self.pos.dtype + # TODO: should we add checks for env too ? + if "cell" in self and self.cell is not None: + assert (self.cell.shape == (3, 3)) or ( + self.cell.dim() == 3 and self.cell.shape[1:] == (3, 3) + ) + assert self.cell.dtype == self.pos.dtype + if "node_features" in self and self.node_features is not None: + assert self.node_features.shape[0] == self.num_nodes + assert self.node_features.dtype == self.pos.dtype + if "node_attrs" in self and self.node_attrs is not None: + assert self.node_attrs.shape[0] == self.num_nodes + assert self.node_attrs.dtype == self.pos.dtype + + if ( + AtomicDataDict.ATOMIC_NUMBERS_KEY in self + and self.atomic_numbers is not None + ): + assert self.atomic_numbers.dtype in _TORCH_INTEGER_DTYPES + if "batch" in self and self.batch is not None: + assert self.batch.dim() == 2 and self.batch.shape[0] == self.num_nodes + # Check that there are the right number of cells + if "cell" in self and self.cell is not None: + cell = self.cell.view(-1, 3, 3) + assert cell.shape[0] == self.batch.max() + 1 + + # Validate irreps + # __*__ is the only way to hide from torch_geometric + self.__irreps__ = AtomicDataDict._fix_irreps_dict(irreps) + for field, irreps in self.__irreps__: + if irreps is not None: + assert self[field].shape[-1] == irreps.dim + + @classmethod + def from_points( + cls, + pos=None, + r_max: float = None, + self_interaction: bool = False, + cell=None, + pbc: Optional[PBC] = None, + er_max: Optional[float] = None, + oer_max: Optional[float] = None, + **kwargs, + ): + """Build neighbor graph from points, optionally with PBC. + + Args: + pos (np.ndarray/torch.Tensor shape [N, 3]): node positions. If Tensor, must be on the CPU. + r_max (float): neighbor cutoff radius. + cell (ase.Cell/ndarray [3,3], optional): periodic cell for the points. Defaults to ``None``. + pbc (bool or 3-tuple of bool, optional): whether to apply periodic boundary conditions to all or each of + the three cell vector directions. Defaults to ``False``. + self_interaction (bool, optional): whether to include self edges for points. Defaults to ``False``. Note + that edges between the same atom in different periodic images are still included. (See + ``strict_self_interaction`` to control this behaviour.) + strict_self_interaction (bool): Whether to include *any* self interaction edges in the graph, even if the + two instances of the atom are in different periodic images. 
Defaults to True, should be True for most + applications. + **kwargs (optional): other fields to add. Keys listed in ``AtomicDataDict.*_KEY` will be treated specially. + """ + if pos is None or r_max is None: + raise ValueError("pos and r_max must be given.") + + if pbc is None: + if cell is not None: + raise ValueError( + "A cell was provided, but pbc weren't. Please explicitly probide PBC." + ) + # there are no PBC if cell and pbc are not provided + pbc = False + + if isinstance(pbc, bool): + pbc = (pbc,) * 3 + else: + assert len(pbc) == 3 + + # TODO: We can only compute the edge vector one times with the largest radial distance among [r_max, er_max, oer_max] + + pos = torch.as_tensor(pos, dtype=torch.get_default_dtype()) + + edge_index, edge_cell_shift, cell = neighbor_list_and_relative_vec( + pos=pos, + r_max=r_max, + self_interaction=self_interaction, + cell=cell, + reduce=False, + atomic_numbers=kwargs.get("atomic_numbers", None), + pbc=pbc, + ) + + # Make torch tensors for data: + if cell is not None: + kwargs[AtomicDataDict.CELL_KEY] = cell.view(3, 3) + kwargs[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = edge_cell_shift + if pbc is not None: + kwargs[AtomicDataDict.PBC_KEY] = torch.as_tensor( + pbc, dtype=torch.bool + ).view(3) + + # add env index + if er_max is not None: + env_index, env_cell_shift, _ = neighbor_list_and_relative_vec( + pos=pos, + r_max=er_max, + self_interaction=self_interaction, + cell=cell, + reduce=False, + atomic_numbers=kwargs.get("atomic_numbers", None), + pbc=pbc, + ) + + if cell is not None: + kwargs[AtomicDataDict.ENV_CELL_SHIFT_KEY] = env_cell_shift + kwargs[AtomicDataDict.ENV_INDEX_KEY] = env_index + + # add onsitenv index + if oer_max is not None: + onsitenv_index, onsitenv_cell_shift, _ = neighbor_list_and_relative_vec( + pos=pos, + r_max=oer_max, + self_interaction=self_interaction, + cell=cell, + reduce=False, + atomic_numbers=kwargs.get("atomic_numbers", None), + pbc=pbc + ) + + if cell is not None: + kwargs[AtomicDataDict.ONSITENV_CELL_SHIFT_KEY] = onsitenv_cell_shift + kwargs[AtomicDataDict.ONSITENV_INDEX_KEY] = onsitenv_index + + return cls(edge_index=edge_index, pos=torch.as_tensor(pos), **kwargs) + + @classmethod + def from_ase( + cls, + atoms, + r_max, + er_max: Optional[float] = None, + oer_max: Optional[float] = None, + key_mapping: Optional[Dict[str, str]] = {}, + include_keys: Optional[list] = [], + **kwargs, + ): + """Build a ``AtomicData`` from an ``ase.Atoms`` object. + + Respects ``atoms``'s ``pbc`` and ``cell``. + + First tries to extract energies and forces from a single-point calculator associated with the ``Atoms`` if one is present and has those fields. + If either is not found, the method will look for ``energy``/``energies`` and ``force``/``forces`` in ``atoms.arrays``. + + `get_atomic_numbers()` will be stored as the atomic_numbers attribute. + + Args: + atoms (ase.Atoms): the input. + r_max (float): neighbor cutoff radius. + features (torch.Tensor shape [N, M], optional): per-atom M-dimensional feature vectors. If ``None`` (the + default), uses a one-hot encoding of the species present in ``atoms``. + include_keys (list): list of additional keys to include in AtomicData aside from the ones defined in + ase.calculators.calculator.all_properties. Optional + key_mapping (dict): rename ase property name to a new string name. Optional + **kwargs (optional): other arguments for the ``AtomicData`` constructor. + + Returns: + A ``AtomicData``. 
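+ Example (a minimal sketch; the bulk-silicon structure and the 5.0 Ang cutoff below are illustrative choices, not values required by this method):: + + from ase.build import bulk + si = bulk("Si", "diamond", a=5.43) + data = AtomicData.from_ase(si, r_max=5.0) + 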
+ """ + # from nequip.ase import NequIPCalculator + + assert "pos" not in kwargs + + default_args = set( + [ + "numbers", + "positions", + ] # ase internal names for position and atomic_numbers + + ["pbc", "cell", "pos", "r_max", "er_max", "oer_max"] # arguments for from_points method + + list(kwargs.keys()) + ) + # the keys that are duplicated in kwargs are removed from the include_keys + include_keys = list( + set(include_keys + ase_all_properties + list(key_mapping.keys())) + - default_args + ) + + km = { + "forces": AtomicDataDict.FORCE_KEY, + "energy": AtomicDataDict.TOTAL_ENERGY_KEY, + } + km.update(key_mapping) + key_mapping = km + + add_fields = {} + + # Get info from atoms.arrays; lowest priority. copy first + add_fields = { + key_mapping.get(k, k): v + for k, v in atoms.arrays.items() + if k in include_keys + } + + # Get info from atoms.info; second lowest priority. + add_fields.update( + { + key_mapping.get(k, k): v + for k, v in atoms.info.items() + if k in include_keys + } + ) + + # if atoms.calc is not None: + + # if isinstance( + # atoms.calc, (SinglePointCalculator, SinglePointDFTCalculator) + # ): + # add_fields.update( + # { + # key_mapping.get(k, k): deepcopy(v) + # for k, v in atoms.calc.results.items() + # if k in include_keys + # } + # ) + # elif isinstance(atoms.calc, NequIPCalculator): + # pass # otherwise the calculator breaks + # else: + # raise NotImplementedError( + # f"`from_ase` does not support calculator {atoms.calc}" + # ) + + add_fields[AtomicDataDict.ATOMIC_NUMBERS_KEY] = atoms.get_atomic_numbers() + + # cell and pbc in kwargs can override the ones stored in atoms + cell = kwargs.pop("cell", atoms.get_cell()) + pbc = kwargs.pop("pbc", atoms.pbc) + + # handle ASE-style 6 element Voigt order stress + for key in (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY): + if key in add_fields: + if add_fields[key].shape == (3, 3): + # it's already 3x3, do nothing else + pass + elif add_fields[key].shape == (6,): + # it's Voigt order + add_fields[key] = voigt_6_to_full_3x3_stress(add_fields[key]) + else: + raise RuntimeError(f"bad shape for {key}") + + return cls.from_points( + pos=atoms.positions, + r_max=r_max, + er_max=er_max, + oer_max=oer_max, + cell=cell, + pbc=pbc, + **kwargs, + **add_fields, + ) + + def to_ase( + self, + type_mapper=None, + extra_fields: List[str] = [], + ) -> Union[List[ase.Atoms], ase.Atoms]: + """Build a (list of) ``ase.Atoms`` object(s) from an ``AtomicData`` object. + + For each unique batch number provided in ``AtomicDataDict.BATCH_KEY``, + an ``ase.Atoms`` object is created. If ``AtomicDataDict.BATCH_KEY`` does not + exist in self, a single ``ase.Atoms`` object is created. + + Args: + type_mapper: if provided, will be used to map ``ATOM_TYPES`` back into + elements, if the configuration of the ``type_mapper`` allows. + extra_fields: fields other than those handled explicitly (currently + those defining the structure as well as energy, per-atom energy, + and forces) to include in the output object. Per-atom (per-node) + quantities will be included in ``arrays``; per-graph and per-edge + quantities will be included in ``info``. + + Returns: + A list of ``ase.Atoms`` objects if ``AtomicDataDict.BATCH_KEY`` is in self + and is not None. Otherwise, a single ``ase.Atoms`` object is returned. + """ + positions = self.pos + edge_index = self[AtomicDataDict.EDGE_INDEX_KEY] + if positions.device != torch.device("cpu"): + raise TypeError( + "Explicitly move this `AtomicData` to CPU using `.to()` before calling `to_ase()`." 
+ ) + if AtomicDataDict.ATOMIC_NUMBERS_KEY in self: + atomic_nums = self.atomic_numbers + elif type_mapper is not None and type_mapper.has_chemical_symbols: + atomic_nums = type_mapper.untransform(self[AtomicDataDict.ATOM_TYPE_KEY]) + else: + warnings.warn( + "AtomicData.to_ase(): self didn't contain atomic numbers... using atom_type as atomic numbers instead, but this means the chemical symbols in ASE (outputs) will be wrong" + ) + atomic_nums = self[AtomicDataDict.ATOM_TYPE_KEY] + pbc = getattr(self, AtomicDataDict.PBC_KEY, None) + cell = getattr(self, AtomicDataDict.CELL_KEY, None) + batch = getattr(self, AtomicDataDict.BATCH_KEY, None) + energy = getattr(self, AtomicDataDict.TOTAL_ENERGY_KEY, None) + energies = getattr(self, AtomicDataDict.PER_ATOM_ENERGY_KEY, None) + force = getattr(self, AtomicDataDict.FORCE_KEY, None) + do_calc = any( + k in self + for k in [ + AtomicDataDict.TOTAL_ENERGY_KEY, + AtomicDataDict.FORCE_KEY, + AtomicDataDict.PER_ATOM_ENERGY_KEY, + AtomicDataDict.STRESS_KEY, + ] + ) + + # exclude those that are special for ASE and that we process seperately + special_handling_keys = [ + AtomicDataDict.POSITIONS_KEY, + AtomicDataDict.CELL_KEY, + AtomicDataDict.PBC_KEY, + AtomicDataDict.ATOMIC_NUMBERS_KEY, + AtomicDataDict.TOTAL_ENERGY_KEY, + AtomicDataDict.FORCE_KEY, + AtomicDataDict.PER_ATOM_ENERGY_KEY, + AtomicDataDict.STRESS_KEY, + ] + assert ( + len(set(extra_fields).intersection(special_handling_keys)) == 0 + ), f"Cannot specify keys handled in special ways ({special_handling_keys}) as `extra_fields` for atoms output--- they are output by default" + + if cell is not None: + cell = cell.view(-1, 3, 3) + if pbc is not None: + pbc = pbc.view(-1, 3) + + if batch is not None: + n_batches = batch.max() + 1 + cell = cell.expand(n_batches, 3, 3) if cell is not None else None + pbc = pbc.expand(n_batches, 3) if pbc is not None else None + else: + n_batches = 1 + + batch_atoms = [] + for batch_idx in range(n_batches): + if batch is not None: + mask = batch == batch_idx + mask = mask.view(-1) + # if both ends of the edge are in the batch, the edge is in the batch + edge_mask = mask[edge_index[0]] & mask[edge_index[1]] + else: + mask = slice(None) + edge_mask = slice(None) + + mol = ase.Atoms( + numbers=atomic_nums[mask].view(-1), # must be flat for ASE + positions=positions[mask], + cell=cell[batch_idx] if cell is not None else None, + pbc=pbc[batch_idx] if pbc is not None else None, + ) + + if do_calc: + fields = {} + if energies is not None: + fields["energies"] = energies[mask].cpu().numpy() + if energy is not None: + fields["energy"] = energy[batch_idx].cpu().numpy() + if force is not None: + fields["forces"] = force[mask].cpu().numpy() + if AtomicDataDict.STRESS_KEY in self: + fields["stress"] = full_3x3_to_voigt_6_stress( + self["stress"].view(-1, 3, 3)[batch_idx].cpu().numpy() + ) + mol.calc = SinglePointCalculator(mol, **fields) + + # add other information + for key in extra_fields: + if key in _NODE_FIELDS: + # mask it + mol.arrays[key] = ( + self[key][mask].cpu().numpy().reshape(mask.sum(), -1) + ) + elif key in _EDGE_FIELDS: + mol.info[key] = ( + self[key][edge_mask].cpu().numpy().reshape(edge_mask.sum(), -1) + ) + elif key == AtomicDataDict.EDGE_INDEX_KEY: + mol.info[key] = self[key][:, edge_mask].cpu().numpy() + elif key in _GRAPH_FIELDS: + mol.info[key] = self[key][batch_idx].cpu().numpy().reshape(-1) + else: + raise RuntimeError( + f"Extra field `{key}` isn't registered as node/edge/graph" + ) + + batch_atoms.append(mol) + + if batch is not None: + return 
batch_atoms + else: + assert len(batch_atoms) == 1 + return batch_atoms[0] + + def get_edge_vectors(data: Data) -> torch.Tensor: + data = AtomicDataDict.with_edge_vectors(AtomicData.to_AtomicDataDict(data)) + return data[AtomicDataDict.EDGE_VECTORS_KEY] + + def get_env_vectors(data: Data) -> torch.Tensor: + data = AtomicDataDict.with_env_vectors(AtomicData.to_AtomicDataDict(data)) + return data[AtomicDataDict.ENV_VECTORS_KEY] + + @staticmethod + def to_AtomicDataDict( + data: Union[Data, Mapping], exclude_keys=tuple() + ) -> AtomicDataDict.Type: + if isinstance(data, Data): + keys = data.keys + elif isinstance(data, Mapping): + keys = data.keys() + else: + raise ValueError(f"Invalid data `{repr(data)}`") + + return { + k: data[k] + for k in keys + if ( + k not in exclude_keys + and data[k] is not None + and isinstance(data[k], torch.Tensor) + ) + } + + @classmethod + def from_AtomicDataDict(cls, data: AtomicDataDict.Type): + # it's an AtomicDataDict, so don't validate-- assume valid: + return cls(_validate=False, **data) + + @property + def irreps(self): + return self.__irreps__ + + def __cat_dim__(self, key, value): + if key == AtomicDataDict.EDGE_INDEX_KEY or key == AtomicDataDict.ENV_INDEX_KEY or key == AtomicDataDict.ONSITENV_INDEX_KEY: + return 1 # always cat in the edge dimension + elif key in _GRAPH_FIELDS: + # graph-level properties and so need a new batch dimension + return None + else: + return 0 # cat along node/edge dimension + + def without_nodes(self, which_nodes): + """Return a copy of ``self`` with ``which_nodes`` removed. + The returned object may share references to some underlying data tensors with ``self``. + Args: + which_nodes (index tensor or boolean mask) + Returns: + A new data object. + """ + which_nodes = torch.as_tensor(which_nodes) + if which_nodes.dtype == torch.bool: + mask = ~which_nodes + else: + mask = torch.ones(self.num_nodes, dtype=torch.bool) + mask[which_nodes] = False + assert mask.shape == (self.num_nodes,) + n_keeping = mask.sum() + + # Only keep edges where both from and to are kept + edge_mask = mask[self.edge_index[0]] & mask[self.edge_index[1]] + if hasattr(self, AtomicDataDict.ENV_INDEX_KEY): + env_mask = mask[self.env_index[0]] & mask[self.env_index[1]] + if hasattr(self, AtomicDataDict.ONSITENV_INDEX_KEY): + onsitenv_mask = mask[self.onsitenv_index[0]] & mask[self.onsitenv_index[1]] + # Create an index mapping: + new_index = torch.full((self.num_nodes,), -1, dtype=torch.long) + new_index[mask] = torch.arange(n_keeping, dtype=torch.long) + + new_dict = {} + for k in self.keys: + if k == AtomicDataDict.EDGE_INDEX_KEY: + new_dict[AtomicDataDict.EDGE_INDEX_KEY] = new_index[ + self.edge_index[:, edge_mask] + ] + elif k == AtomicDataDict.EDGE_CELL_SHIFT_KEY: + new_dict[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = self.edge_cell_shift[ + edge_mask + ] + elif k == AtomicDataDict.CELL_KEY: + new_dict[k] = self[k] + elif k == AtomicDataDict.ENV_INDEX_KEY: + new_dict[AtomicDataDict.ENV_INDEX_KEY] = new_index[ + self.env_index[:, env_mask] + ] + elif k == AtomicDataDict.ENV_CELL_SHIFT_KEY: + new_dict[AtomicDataDict.ENV_CELL_SHIFT_KEY] = self.env_cell_shift[ + env_mask + ] + elif k == AtomicDataDict.ONSITENV_INDEX_KEY: + new_dict[AtomicDataDict.ONSITENV_INDEX_KEY] = new_index[ + self.onsitenv_index[:, onsitenv_mask] + ] + elif k == AtomicDataDict.ONSITENV_CELL_SHIFT_KEY: + new_dict[AtomicDataDict.ONSITENV_CELL_SHIFT_KEY] = self.onsitenv_cell_shift[ + onsitenv_mask + ] + else: + if isinstance(self[k], torch.Tensor) and len(self[k]) == self.num_nodes: + 
new_dict[k] = self[k][mask] + else: + new_dict[k] = self[k] + + new_dict["irreps"] = self.__irreps__ + + return type(self)(**new_dict) + + +_ERROR_ON_NO_EDGES: bool = os.environ.get("NEQUIP_ERROR_ON_NO_EDGES", "true").lower() +assert _ERROR_ON_NO_EDGES in ("true", "false") +_ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true" + + +def neighbor_list_and_relative_vec( + pos, + r_max, + self_interaction=False, + reduce=True, + atomic_numbers=None, + cell=None, + pbc=False, +): + """Create neighbor list and neighbor vectors based on radial cutoff. + + Create neighbor list (``edge_index``) and relative vectors + (``edge_attr``) based on radial cutoff. + + Edges are given by the following convention: + - ``edge_index[0]`` is the *source* (convolution center). + - ``edge_index[1]`` is the *target* (neighbor). + + Thus, ``edge_index`` has the same convention as the relative vectors: + :math:`\\vec{r}_{source, target}` + + If the input positions are a tensor with ``requires_grad == True``, + the output displacement vectors will be correctly attached to the inputs + for autograd. + + All outputs are Tensors on the same device as ``pos``; this allows future + optimization of the neighbor list on the GPU. + + Args: + pos (shape [N, 3]): Positional coordinate; Tensor or numpy array. If Tensor, must be on CPU. + r_max (float): Radial cutoff distance for neighbor finding. + cell (numpy shape [3, 3]): Cell for periodic boundary conditions. Ignored if ``pbc == False``. + pbc (bool or 3-tuple of bool): Whether the system is periodic in each of the three cell dimensions. + self_interaction (bool): Whether or not to include same periodic image self-edges in the neighbor list. + strict_self_interaction (bool): Whether to include *any* self interaction edges in the graph, even if the two + instances of the atom are in different periodic images. Defaults to True, should be True for most applications. + + Returns: + edge_index (torch.tensor shape [2, num_edges]): List of edges. + edge_cell_shift (torch.tensor shape [num_edges, 3]): Relative cell shift + vectors. Returned only if cell is not None. + cell (torch.Tensor [3, 3]): the cell as a tensor on the correct device. + Returned only if cell is not None. + """ + if isinstance(pbc, bool): + pbc = (pbc,) * 3 + + # Either the position or the cell may be on the GPU as tensors + if isinstance(pos, torch.Tensor): + temp_pos = pos.detach().cpu().numpy() + out_device = pos.device + out_dtype = pos.dtype + else: + temp_pos = np.asarray(pos) + out_device = torch.device("cpu") + out_dtype = torch.get_default_dtype() + + # Right now, GPU tensors require a round trip + if out_device.type != "cpu": + warnings.warn( + "Currently, neighborlists require a round trip to the CPU. Please pass CPU tensors if possible." + ) + + # Get a cell on the CPU no matter what + if isinstance(cell, torch.Tensor): + temp_cell = cell.detach().cpu().numpy() + cell_tensor = cell.to(device=out_device, dtype=out_dtype) + elif cell is not None: + temp_cell = np.asarray(cell) + cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype) + else: + # ASE will "complete" this correctly. 
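+ # (Illustrative note: `ase.geometry.complete_cell` below fills in the all-zero lattice vectors so the neighbor-list call always receives a non-singular cell; with pbc=False these placeholder vectors are never used for periodic wrapping.)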
+ temp_cell = np.zeros((3, 3), dtype=temp_pos.dtype) + cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype) + + # ASE dependent part + temp_cell = ase.geometry.complete_cell(temp_cell) + + first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list( + "ijS", + pbc, + temp_cell, + temp_pos, + cutoff=float(r_max), + self_interaction=self_interaction, # we want edges from atom to itself in different periodic images! + use_scaled_positions=False, + ) + + # Eliminate true self-edges that don't cross periodic boundaries + # if not self_interaction: + # bad_edge = first_idex == second_idex + # bad_edge &= np.all(shifts == 0, axis=1) + # keep_edge = ~bad_edge + # if _ERROR_ON_NO_EDGES and (not np.any(keep_edge)): + # raise ValueError( + # f"Every single atom has no neighbors within the cutoff r_max={r_max} (after eliminating self edges, no edges remain in this system)" + # ) + # first_idex = first_idex[keep_edge] + # second_idex = second_idex[keep_edge] + # shifts = shifts[keep_edge] + + """ + bond list is: i, j, shift; but i j shift and j i -shift are the same bond. so we need to remove the duplicate bonds.s + first for i != j; we only keep i < j; then the j i -shift will be removed. + then, for i == j; we only keep i i shift and remove i i -shift. + """ + # 1. for i != j, keep i < j + assert atomic_numbers is not None + atomic_numbers = torch.as_tensor(atomic_numbers, dtype=torch.long) + mask = first_idex <= second_idex + first_idex = first_idex[mask] + second_idex = second_idex[mask] + shifts = shifts[mask] + + # 2. for i == j + + mask = torch.ones(len(first_idex), dtype=torch.bool) + mask[first_idex == second_idex] = False + # get index bool type ~mask for i == j. + o_first_idex = first_idex[~mask] + o_second_idex = second_idex[~mask] + o_shift = shifts[~mask] + o_mask = mask[~mask] # this is all False, with length being the number all the bonds with i == j. + + + # using the dict key to remove the duplicate bonds, because it is O(1) to check if a key is in the dict. + rev_dict = {} + for i in range(len(o_first_idex)): + key = str(o_first_idex[i])+str(o_shift[i]) + key_rev = str(o_first_idex[i])+str(-o_shift[i]) + rev_dict[key] = True + # key_rev is the reverse key of key, if key_rev is in the dict, then the bond is duplicate. + # so, only when key_rev is not in the dict, we keep the bond. that is when rev_dict.get(key_rev, False) is False, we set o_mast = True. + if not (rev_dict.get(key_rev, False) and rev_dict.get(key, False)): + o_mask[i] = True + del rev_dict + del o_first_idex + del o_second_idex + del o_shift + mask[~mask] = o_mask + del o_mask + + first_idex = torch.LongTensor(first_idex[mask], device=out_device) + second_idex = torch.LongTensor(second_idex[mask], device=out_device) + shifts = torch.as_tensor(shifts[mask], dtype=out_dtype, device=out_device) + + if not reduce: + first_idex, second_idex = torch.cat((first_idex, second_idex), dim=0), torch.cat((second_idex, first_idex), dim=0) + shifts = torch.cat((shifts, -shifts), dim=0) + + # Build output: + edge_index = torch.vstack( + (torch.LongTensor(first_idex), torch.LongTensor(second_idex)) + ) + + return edge_index, shifts, cell_tensor diff --git a/dptb/data/AtomicDataDict.py b/dptb/data/AtomicDataDict.py new file mode 100644 index 00000000..6fc45c9b --- /dev/null +++ b/dptb/data/AtomicDataDict.py @@ -0,0 +1,233 @@ +"""nequip.data.jit: TorchScript functions for dealing with AtomicData. 
+ +These TorchScript functions operate on ``Dict[str, torch.Tensor]`` representations +of the ``AtomicData`` class which are produced by ``AtomicData.to_AtomicDataDict()``. + +Authors: Albert Musaelian +""" +from typing import Dict, Any + +import torch +import torch.jit + +from e3nn import o3 + +# Make the keys available in this module +from ._keys import * # noqa: F403, F401 + +# Also import the module to use in TorchScript, this is a hack to avoid bug: +# https://github.com/pytorch/pytorch/issues/52312 +from . import _keys + +# Define a type alias +Type = Dict[str, torch.Tensor] + + +def validate_keys(keys, graph_required=True): + # Validate combinations + if graph_required: + if not (_keys.POSITIONS_KEY in keys and _keys.EDGE_INDEX_KEY in keys): + raise KeyError("At least pos and edge_index must be supplied") + if _keys.EDGE_CELL_SHIFT_KEY in keys and "cell" not in keys: + raise ValueError("If `edge_cell_shift` given, `cell` must be given.") + + +_SPECIAL_IRREPS = [None] + + +def _fix_irreps_dict(d: Dict[str, Any]): + return {k: (i if i in _SPECIAL_IRREPS else o3.Irreps(i)) for k, i in d.items()} + + +def _irreps_compatible(ir1: Dict[str, o3.Irreps], ir2: Dict[str, o3.Irreps]): + return all(ir1[k] == ir2[k] for k in ir1 if k in ir2) + + +@torch.jit.script +def with_edge_vectors(data: Type, with_lengths: bool = True) -> Type: + """Compute the edge displacement vectors for a graph. + + If ``data.pos.requires_grad`` and/or ``data.cell.requires_grad``, this + method will return edge vectors correctly connected in the autograd graph. + + Returns: + Tensor [n_edges, 3] edge displacement vectors + """ + if _keys.EDGE_VECTORS_KEY in data: + if with_lengths and _keys.EDGE_LENGTH_KEY not in data: + data[_keys.EDGE_LENGTH_KEY] = torch.linalg.norm( + data[_keys.EDGE_VECTORS_KEY], dim=-1 + ) + + return data + else: + # Build it dynamically + # Note that this is + # (1) backwardable, because everything (pos, cell, shifts) + # is Tensors. + # (2) works on a Batch constructed from AtomicData + pos = data[_keys.POSITIONS_KEY] + edge_index = data[_keys.EDGE_INDEX_KEY] + edge_vec = pos[edge_index[1]] - pos[edge_index[0]] + if _keys.CELL_KEY in data: + # ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero. + # -1 gives a batch dim no matter what + cell = data[_keys.CELL_KEY].view(-1, 3, 3) + edge_cell_shift = data[_keys.EDGE_CELL_SHIFT_KEY] + if cell.shape[0] > 1: + batch = data[_keys.BATCH_KEY] + # Cell has a batch dimension + # note the ASE cell vectors as rows convention + edge_vec = edge_vec + torch.einsum( + "ni,nij->nj", edge_cell_shift, cell[batch[edge_index[0]]] + ) + # TODO: is there a more efficient way to do the above without + # creating an [n_edge] and [n_edge, 3, 3] tensor? + else: + # Cell has either no batch dimension, or a useless one, + # so we can avoid creating the large intermediate cell tensor. + # Note that we do NOT check that the batch array, if it is present, + # is trivial — but this does need to be consistent. + edge_vec = edge_vec + torch.einsum( + "ni,ij->nj", + edge_cell_shift, + cell.squeeze(0), # remove batch dimension + ) + + data[_keys.EDGE_VECTORS_KEY] = edge_vec + if with_lengths: + data[_keys.EDGE_LENGTH_KEY] = torch.linalg.norm(edge_vec, dim=-1) + return data + +@torch.jit.script +def with_env_vectors(data: Type, with_lengths: bool = True) -> Type: + """Compute the edge displacement vectors for a graph. 
+ + If ``data.pos.requires_grad`` and/or ``data.cell.requires_grad``, this + method will return edge vectors correctly connected in the autograd graph. + + Returns: + Tensor [n_edges, 3] edge displacement vectors + """ + if _keys.ENV_VECTORS_KEY in data: + if with_lengths and _keys.ENV_LENGTH_KEY not in data: + data[_keys.ENV_LENGTH_KEY] = torch.linalg.norm( + data[_keys.ENV_VECTORS_KEY], dim=-1 + ) + return data + else: + # Build it dynamically + # Note that this is + # (1) backwardable, because everything (pos, cell, shifts) + # is Tensors. + # (2) works on a Batch constructed from AtomicData + pos = data[_keys.POSITIONS_KEY] + env_index = data[_keys.ENV_INDEX_KEY] + env_vec = pos[env_index[1]] - pos[env_index[0]] + if _keys.CELL_KEY in data: + # ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero. + # -1 gives a batch dim no matter what + cell = data[_keys.CELL_KEY].view(-1, 3, 3) + env_cell_shift = data[_keys.ENV_CELL_SHIFT_KEY] + if cell.shape[0] > 1: + batch = data[_keys.BATCH_KEY] + # Cell has a batch dimension + # note the ASE cell vectors as rows convention + env_vec = env_vec + torch.einsum( + "ni,nij->nj", env_cell_shift, cell[batch[env_index[0]]] + ) + # TODO: is there a more efficient way to do the above without + # creating an [n_edge] and [n_edge, 3, 3] tensor? + else: + # Cell has either no batch dimension, or a useless one, + # so we can avoid creating the large intermediate cell tensor. + # Note that we do NOT check that the batch array, if it is present, + # is trivial — but this does need to be consistent. + env_vec = env_vec + torch.einsum( + "ni,ij->nj", + env_cell_shift, + cell.squeeze(0), # remove batch dimension + ) + data[_keys.ENV_VECTORS_KEY] = env_vec + if with_lengths: + data[_keys.ENV_LENGTH_KEY] = torch.linalg.norm(env_vec, dim=-1) + return data + +@torch.jit.script +def with_onsitenv_vectors(data: Type, with_lengths: bool = True) -> Type: + """Compute the edge displacement vectors for a graph. + + If ``data.pos.requires_grad`` and/or ``data.cell.requires_grad``, this + method will return edge vectors correctly connected in the autograd graph. + + Returns: + Tensor [n_edges, 3] edge displacement vectors + """ + if _keys.ONSITENV_VECTORS_KEY in data: + if with_lengths and _keys.ONSITENV_LENGTH_KEY not in data: + data[_keys.ONSITENV_LENGTH_KEY] = torch.linalg.norm( + data[_keys.ONSITENV_VECTORS_KEY], dim=-1 + ) + return data + else: + # Build it dynamically + # Note that this is + # (1) backwardable, because everything (pos, cell, shifts) + # is Tensors. + # (2) works on a Batch constructed from AtomicData + pos = data[_keys.POSITIONS_KEY] + env_index = data[_keys.ONSITENV_INDEX_KEY] + env_vec = pos[env_index[1]] - pos[env_index[0]] + if _keys.CELL_KEY in data: + # ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero. + # -1 gives a batch dim no matter what + cell = data[_keys.CELL_KEY].view(-1, 3, 3) + env_cell_shift = data[_keys.ONSITENV_CELL_SHIFT_KEY] + if cell.shape[0] > 1: + batch = data[_keys.BATCH_KEY] + # Cell has a batch dimension + # note the ASE cell vectors as rows convention + env_vec = env_vec + torch.einsum( + "ni,nij->nj", env_cell_shift, cell[batch[env_index[0]]] + ) + # TODO: is there a more efficient way to do the above without + # creating an [n_edge] and [n_edge, 3, 3] tensor? 
+ else: + # Cell has either no batch dimension, or a useless one, + # so we can avoid creating the large intermediate cell tensor. + # Note that we do NOT check that the batch array, if it is present, + # is trivial — but this does need to be consistent. + env_vec = env_vec + torch.einsum( + "ni,ij->nj", + env_cell_shift, + cell.squeeze(0), # remove batch dimension + ) + data[_keys.ONSITENV_VECTORS_KEY] = env_vec + if with_lengths: + data[_keys.ONSITENV_LENGTH_KEY] = torch.linalg.norm(env_vec, dim=-1) + return data + + +@torch.jit.script +def with_batch(data: Type) -> Type: + """Get batch Tensor. + + If this AtomicDataPrimitive has no ``batch``, one of all zeros will be + allocated and returned. + """ + if _keys.BATCH_KEY in data: + return data + else: + pos = data[_keys.POSITIONS_KEY] + batch = torch.zeros(len(pos), dtype=torch.long, device=pos.device) + data[_keys.BATCH_KEY] = batch + # ugly way to make a tensor of [0, len(pos)], but it avoids transfers or casts + data[_keys.BATCH_PTR_KEY] = torch.arange( + start=0, + end=len(pos) + 1, + step=len(pos), + dtype=torch.long, + device=pos.device, + ) + + return data diff --git a/dptb/data/__init__.py b/dptb/data/__init__.py new file mode 100644 index 00000000..2efca699 --- /dev/null +++ b/dptb/data/__init__.py @@ -0,0 +1,49 @@ +from .AtomicData import ( + AtomicData, + PBC, + register_fields, + deregister_fields, + _register_field_prefix, + _NODE_FIELDS, + _EDGE_FIELDS, + _GRAPH_FIELDS, + _LONG_FIELDS, +) +from .dataset import ( + AtomicDataset, + AtomicInMemoryDataset, + NpzDataset, + ASEDataset, + HDF5Dataset, + ABACUSDataset, + ABACUSInMemoryDataset, + DefaultDataset +) +from .dataloader import DataLoader, Collater, PartialSampler +from .build import dataset_from_config +from .test_data import EMTTestDataset + +__all__ = [ + AtomicData, + PBC, + register_fields, + deregister_fields, + _register_field_prefix, + AtomicDataset, + AtomicInMemoryDataset, + NpzDataset, + ASEDataset, + HDF5Dataset, + ABACUSDataset, + ABACUSInMemoryDataset, + DefaultDataset, + DataLoader, + Collater, + PartialSampler, + dataset_from_config, + _NODE_FIELDS, + _EDGE_FIELDS, + _GRAPH_FIELDS, + _LONG_FIELDS, + EMTTestDataset, +] diff --git a/dptb/data/_keys.py b/dptb/data/_keys.py new file mode 100644 index 00000000..187d4c79 --- /dev/null +++ b/dptb/data/_keys.py @@ -0,0 +1,125 @@ +"""Keys for dictionaries/AtomicData objects. + +This is a seperate module to compensate for a TorchScript bug that can only recognize constants when they are accessed as attributes of an imported module. 
+""" + +import sys +from typing import List + +if sys.version_info[1] >= 8: + from typing import Final +else: + from typing_extensions import Final + +# == Define allowed keys as constants == +# The positions of the atoms in the system +POSITIONS_KEY: Final[str] = "pos" +# The [2, n_edge] index tensor giving center -> neighbor relations +EDGE_INDEX_KEY: Final[str] = "edge_index" +# The [2, n_env] index tensor giving center -> neighbor relations +ENV_INDEX_KEY: Final[str] = "env_index" +# The [2, n_onsitenv] index tensor giving center -> neighbor relations +ONSITENV_INDEX_KEY: Final[str] = "onsitenv_index" +# A [n_edge, 3] tensor of how many periodic cells each env crosses in each cell vector +ENV_CELL_SHIFT_KEY: Final[str] = "env_cell_shift" +# A [n_edge, 3] tensor of how many periodic cells each edge crosses in each cell vector +EDGE_CELL_SHIFT_KEY: Final[str] = "edge_cell_shift" +# [n_batch, 3, 3] or [3, 3] tensor where rows are the cell vectors +ONSITENV_CELL_SHIFT_KEY: Final[str] = "onsitenv_cell_shift" +# [n_batch, 3, 3] or [3, 3] tensor where rows are the cell vectors +CELL_KEY: Final[str] = "cell" +# [n_kpoints, 3] or [n_batch, nkpoints, 3] tensor +KPOINT_KEY = "kpoint" + +HAMILTONIAN_KEY = "hamiltonian" + +OVERLAP_KEY = "overlap" +# [n_batch, 3] bool tensor +PBC_KEY: Final[str] = "pbc" +# [n_atom, 1] long tensor +ATOMIC_NUMBERS_KEY: Final[str] = "atomic_numbers" +# [n_atom, 1] long tensor +ATOM_TYPE_KEY: Final[str] = "atom_types" +# [n_batch, n_kpoint, n_orb] +ENERGY_EIGENVALUE_KEY: Final[str] = "eigenvalue" + +# [n_batch, 2] +ENERGY_WINDOWS_KEY = "ewindow" +BAND_WINDOW_KEY = "bwindow" + +BASIC_STRUCTURE_KEYS: Final[List[str]] = [ + POSITIONS_KEY, + EDGE_INDEX_KEY, + EDGE_CELL_SHIFT_KEY, + CELL_KEY, + PBC_KEY, + ATOM_TYPE_KEY, + ATOMIC_NUMBERS_KEY, +] + +# A [n_edge, 3] tensor of displacement vectors associated to edges +EDGE_VECTORS_KEY: Final[str] = "edge_vectors" +# A [n_edge, 3] tensor of displacement vectors associated to envs +ENV_VECTORS_KEY: Final[str] = "env_vectors" +# A [n_edge, 3] tensor of displacement vectors associated to onsitenvs +ONSITENV_VECTORS_KEY: Final[str] = "onsitenv_vectors" +# A [n_edge] tensor of the lengths of EDGE_VECTORS +EDGE_LENGTH_KEY: Final[str] = "edge_lengths" +# A [n_edge] tensor of the lengths of ENV_VECTORS +ENV_LENGTH_KEY: Final[str] = "env_lengths" +# A [n_edge] tensor of the lengths of ONSITENV_VECTORS +ONSITENV_LENGTH_KEY: Final[str] = "onsitenv_lengths" +# [n_edge, dim] (possibly equivariant) attributes of each edge +EDGE_ATTRS_KEY: Final[str] = "edge_attrs" +ENV_ATTRS_KEY: Final[str] = "env_attrs" +ONSITENV_ATTRS_KEY: Final[str] = "onsitenv_attrs" +# [n_edge, dim] invariant embedding of the edges +EDGE_EMBEDDING_KEY: Final[str] = "edge_embedding" +ENV_EMBEDDING_KEY: Final[str] = "env_embedding" +ONSITENV_EMBEDDING_KEY: Final[str] = "onsitenv_embedding" +EDGE_FEATURES_KEY: Final[str] = "edge_features" +ENV_FEATURES_KEY: Final[str] = "env_features" +ONSITENV_FEATURES_KEY: Final[str] = "onsitenv_features" +# [n_edge, 1] invariant of the radial cutoff envelope for each edge, allows reuse of cutoff envelopes +EDGE_CUTOFF_KEY: Final[str] = "edge_cutoff" +# [n_edge, 1] invariant of the radial cutoff envelope for each env edge, allows reuse of cutoff envelopes +ENV_CUTOFF_KEY: Final[str] = "env_cutoff" +# [n_edge, 1] invariant of the radial cutoff envelope for each onsitenv edge, allows reuse of cutoff envelopes +ONSITENV_CUTOFF_KEY: Final[str] = "onsitenv_cutoff" +# edge energy as in Allegro +EDGE_ENERGY_KEY: Final[str] = "edge_energy" 
+EDGE_OVERLAP_KEY: Final[str] = "edge_overlap" +NODE_OVERLAP_KEY: Final[str] = "node_overlap" +EDGE_HAMILTONIAN_KEY: Final[str] = "edge_hamiltonian" +NODE_HAMILTONIAN_KEY: Final[str] = "node_hamiltonian" + +NODE_FEATURES_KEY: Final[str] = "node_features" +NODE_ATTRS_KEY: Final[str] = "node_attrs" +EDGE_TYPE_KEY: Final[str] = "edge_type" + +PER_ATOM_ENERGY_KEY: Final[str] = "atomic_energy" +TOTAL_ENERGY_KEY: Final[str] = "total_energy" +FORCE_KEY: Final[str] = "forces" +PARTIAL_FORCE_KEY: Final[str] = "partial_forces" +STRESS_KEY: Final[str] = "stress" +VIRIAL_KEY: Final[str] = "virial" + +ALL_ENERGY_KEYS: Final[List[str]] = [ + EDGE_ENERGY_KEY, + PER_ATOM_ENERGY_KEY, + TOTAL_ENERGY_KEY, + FORCE_KEY, + PARTIAL_FORCE_KEY, + STRESS_KEY, + VIRIAL_KEY, +] + +BATCH_KEY: Final[str] = "batch" +BATCH_PTR_KEY: Final[str] = "ptr" + +# Make a list of allowed keys +ALLOWED_KEYS: List[str] = [ + getattr(sys.modules[__name__], k) + for k in sys.modules[__name__].__dict__.keys() + if k.endswith("_KEY") +] diff --git a/dptb/data/build.py b/dptb/data/build.py new file mode 100644 index 00000000..f419c15f --- /dev/null +++ b/dptb/data/build.py @@ -0,0 +1,191 @@ +import inspect +import os +from copy import deepcopy +import glob +from importlib import import_module + +from dptb.data.dataset import DefaultDataset +from dptb import data +from dptb.data.transforms import TypeMapper, OrbitalMapper +from dptb.data import AtomicDataset, register_fields +from dptb.utils import instantiate, get_w_prefix +from dptb.utils.tools import j_loader +from dptb.utils.argcheck import normalize_setinfo + + +def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset: + """initialize database based on a config instance + + It needs dataset type name (case insensitive), + and all the parameters needed in the constructor. + + Examples see tests/data/test_dataset.py TestFromConfig + and tests/datasets/test_simplest.py + + Args: + + config (dict, nequip.utils.Config): dict/object that store all the parameters + prefix (str): Optional. 
The prefix of all dataset parameters + + Return: + + dataset (nequip.data.AtomicDataset) + """ + + config_dataset = config.get(prefix, None) + if config_dataset is None: + raise KeyError(f"Dataset with prefix `{prefix}` isn't present in this config!") + + if inspect.isclass(config_dataset): + # user define class + class_name = config_dataset + else: + try: + module_name = ".".join(config_dataset.split(".")[:-1]) + class_name = ".".join(config_dataset.split(".")[-1:]) + class_name = getattr(import_module(module_name), class_name) + except Exception: + # ^ TODO: don't catch all Exception + # default class defined in nequip.data or nequip.dataset + dataset_name = config_dataset.lower() + + class_name = None + for k, v in inspect.getmembers(data, inspect.isclass): + if k.endswith("Dataset"): + if k.lower() == dataset_name: + class_name = v + if k[:-7].lower() == dataset_name: + class_name = v + elif k.lower() == dataset_name: + class_name = v + + if class_name is None: + raise NameError(f"dataset type {dataset_name} does not exists") + + # if dataset r_max is not found, use the universal r_max + atomicdata_options_key = "AtomicData_options" + prefixed_eff_key = f"{prefix}_{atomicdata_options_key}" + config[prefixed_eff_key] = get_w_prefix( + atomicdata_options_key, {}, prefix=prefix, arg_dicts=config + ) + config[prefixed_eff_key]["r_max"] = get_w_prefix( + "r_max", + prefix=prefix, + arg_dicts=[config[prefixed_eff_key], config], + ) + + config[prefixed_eff_key]["er_max"] = get_w_prefix( + "er_max", + prefix=prefix, + arg_dicts=[config[prefixed_eff_key], config], + ) + + config[prefixed_eff_key]["oer_max"] = get_w_prefix( + "oer_max", + prefix=prefix, + arg_dicts=[config[prefixed_eff_key], config], + ) + + # Build a TypeMapper from the config + type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config) + + # Register fields: + # This might reregister fields, but that's OK: + instantiate(register_fields, all_args=config) + + instance, _ = instantiate( + class_name, + prefix=prefix, + positional_args={"type_mapper": type_mapper}, + optional_args=config, + ) + + return instance + + +def build_dataset(set_options, common_options): + + dataset_type = set_options.get("type", "DefaultDataset") + + # input in set_option for Default Dataset: + # "root": main dir storing all trajectory folders. + # that is, each subfolder of root contains a trajectory. + # "prefix": optional, load selected trajectory folders. + # "get_Hamiltonian": load the Hamiltonian file to edges of the graph or not. + # "get_eigenvalues": load the eigenvalues to the graph or not. + # "setinfo": MUST HAVE, the name of the json file used to build dataset. + # Example: + # "train": { + # "type": "DefaultDataset", + # "root": "foo/bar/data_files_here", + # "prefix": "traj", + # "setinfo": "with_pbc.json" + # } + if dataset_type == "DefaultDataset": + # See if we can get a OrbitalMapper. + if "basis" in common_options: + idp = OrbitalMapper(common_options["basis"]) + else: + idp = None + + # Explore the dataset's folder structure. + root = set_options["root"] + prefix = set_options.get("prefix", None) + include_folders = [] + for dir_name in os.listdir(root): + dir_path = os.path.join(root, dir_name) + if os.path.isdir(dir_path): + # If the `processed_dataset` or other folder is here too, they do not have the proper traj data files. + # And we will have problem in generating TrajData! + # So we test it here: the data folder must have `.dat` or `.traj` file. 
+ if glob.glob(os.path.join(dir_path, '*.dat')) or glob.glob(os.path.join(dir_path, '*.traj')): + if prefix is not None: + if dir_name[:len(prefix)] == prefix: + include_folders.append(dir_name) + else: + include_folders.append(dir_name) + + # We need to check the `setinfo.json` very carefully here. + # Different `setinfo` files point to different datasets, + # even if the data files in `root` are basically the same. + info_files = {} + + # See if a public info is provided. + if "info.json" in os.listdir(root): + public_info = j_loader(os.path.join(root, "info.json")) + public_info = normalize_setinfo(public_info) + print("A public `info.json` file is provided, and will be used by the subfolders that do not have their own `info.json` file.") + else: + public_info = None + + # Load the info in each trajectory folder separately. + for file in include_folders: + if "info.json" in os.listdir(os.path.join(root, file)): + # use info provided in this trajectory. + info = j_loader(os.path.join(root, file, "info.json")) + info = normalize_setinfo(info) + info_files[file] = info + elif public_info is not None: + # use public info instead + # yaml will not dump correctly if this is not a deepcopy. + info_files[file] = deepcopy(public_info) + else: + # no info for this file + raise Exception(f"info.json is not properly provided for `{file}`.") + + # We will sort the info_files here. + # The order itself is not important, but must be consistent for the same list. + info_files = {key: info_files[key] for key in sorted(info_files)} + + dataset = DefaultDataset( + root=root, + type_mapper=idp, + get_Hamiltonian=set_options.get("get_Hamiltonian", False), + get_eigenvalues=set_options.get("get_eigenvalues", False), + info_files=info_files + ) + + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}.") + + return dataset diff --git a/dptb/data/dataloader.py b/dptb/data/dataloader.py new file mode 100644 index 00000000..1f383f1a --- /dev/null +++ b/dptb/data/dataloader.py @@ -0,0 +1,163 @@ +from typing import List, Optional, Iterator + +import torch +from torch.utils.data import Sampler + +from dptb.utils.torch_geometric import Batch, Data, Dataset + +class Collater(object): + """Collate a list of ``AtomicData``. + + Args: + exclude_keys: keys to ignore in the input, not copying to the output + """ + + def __init__( + self, + exclude_keys: List[str] = [], + ): + self._exclude_keys = set(exclude_keys) + + @classmethod + def for_dataset( + cls, + dataset, + exclude_keys: List[str] = [], + ): + """Construct a collater appropriate to ``dataset``.""" + return cls( + exclude_keys=exclude_keys, + ) + + def collate(self, batch: List[Data]) -> Batch: + """Collate a list of data""" + out = Batch.from_data_list(batch, exclude_keys=self._exclude_keys) + return out + + def __call__(self, batch: List[Data]) -> Batch: + """Collate a list of data""" + return self.collate(batch) + + @property + def exclude_keys(self): + return list(self._exclude_keys) + + +class DataLoader(torch.utils.data.DataLoader): + def __init__( + self, + dataset, + batch_size: int = 1, + shuffle: bool = False, + exclude_keys: List[str] = [], + **kwargs, + ): + if "collate_fn" in kwargs: + del kwargs["collate_fn"] + + super(DataLoader, self).__init__( + dataset, + batch_size, + shuffle, + collate_fn=Collater.for_dataset(dataset, exclude_keys=exclude_keys), + **kwargs, + ) + + +class PartialSampler(Sampler[int]): + r"""Samples elements without replacement, but divided across a number of calls to `__iter__`. 
+ + To ensure deterministic reproducibility and restartability, dataset permutations are generated + from a combination of the overall seed and the epoch number. As a result, the caller must + tell this sampler the epoch number before each time `__iter__` is called by calling + `my_partial_sampler.step_epoch(epoch_number_about_to_run)` each time. + + This sampler decouples epochs from the dataset size and cycles through the dataset over as + many (partial) epochs as it may take. As a result, the _dataset_ epoch can change partway + through a training epoch. + + Args: + data_source (Dataset): dataset to sample from + shuffle (bool): whether to shuffle the dataset each time the _entire_ dataset is consumed + num_samples_per_epoch (int): number of samples to draw in each call to `__iter__`. + If `None`, defaults to `len(data_source)`. + generator (Generator): Generator used in sampling. + """ + data_source: Dataset + num_samples_per_epoch: int + shuffle: bool + _epoch: int + _prev_epoch: int + + def __init__( + self, + data_source: Dataset, + shuffle: bool = True, + num_samples_per_epoch: Optional[int] = None, + generator=None, + ) -> None: + self.data_source = data_source + self.shuffle = shuffle + if num_samples_per_epoch is None: + num_samples_per_epoch = self.num_samples_total + self.num_samples_per_epoch = num_samples_per_epoch + assert self.num_samples_per_epoch <= self.num_samples_total + assert self.num_samples_per_epoch >= 1 + self.generator = generator + self._epoch = None + self._prev_epoch = None + + @property + def num_samples_total(self) -> int: + # dataset size might change at runtime + return len(self.data_source) + + def step_epoch(self, epoch: int) -> None: + self._epoch = epoch + + def __iter__(self) -> Iterator[int]: + assert self._epoch is not None + assert (self._prev_epoch is None) or (self._epoch == self._prev_epoch + 1) + assert self._epoch >= 0 + + full_epoch_i, start_sample_i = divmod( + # how much data we've already consumed: + self._epoch * self.num_samples_per_epoch, + # how much data there is the dataset: + self.num_samples_total, + ) + + if self.shuffle: + temp_rng = torch.Generator() + # Get new randomness for each _full_ time through the dataset + # This is deterministic w.r.t. 
the combination of dataset seed and epoch number + # Both of which persist across restarts + # (initial_seed() is restored by set_state()) + temp_rng.manual_seed(self.generator.initial_seed() + full_epoch_i) + full_order_this = torch.randperm(self.num_samples_total, generator=temp_rng) + # reseed the generator for the _next_ epoch to get the shuffled order of the + # _next_ dataset epoch to pad out this one for completing any partial batches + # at the end: + temp_rng.manual_seed(self.generator.initial_seed() + full_epoch_i + 1) + full_order_next = torch.randperm(self.num_samples_total, generator=temp_rng) + del temp_rng + else: + full_order_this = torch.arange(self.num_samples_total) + # without shuffling, the next epoch has the same sampling order as this one: + full_order_next = full_order_this + + full_order = torch.cat((full_order_this, full_order_next), dim=0) + del full_order_next, full_order_this + + this_segment_indexes = full_order[ + start_sample_i : start_sample_i + self.num_samples_per_epoch + ] + # because we cycle into indexes from the next dataset epoch, + # we should _always_ be able to get num_samples_per_epoch + assert len(this_segment_indexes) == self.num_samples_per_epoch + yield from this_segment_indexes + + self._prev_epoch = self._epoch + + def __len__(self) -> int: + return self.num_samples_per_epoch diff --git a/dptb/data/dataset/__init__.py b/dptb/data/dataset/__init__.py new file mode 100644 index 00000000..c8dca670 --- /dev/null +++ b/dptb/data/dataset/__init__.py @@ -0,0 +1,21 @@ +from ._base_datasets import AtomicDataset, AtomicInMemoryDataset +from ._ase_dataset import ASEDataset +from ._npz_dataset import NpzDataset +from ._hdf5_dataset import HDF5Dataset +from ._abacus_dataset import ABACUSDataset, ABACUSInMemoryDataset +from ._deeph_dataset import DeePHE3Dataset +from ._default_dataset import DefaultDataset + + +__all__ = [ + DefaultDataset, + DeePHE3Dataset, + ABACUSInMemoryDataset, + ABACUSDataset, + ASEDataset, + AtomicDataset, + AtomicInMemoryDataset, + NpzDataset, + HDF5Dataset + ] + diff --git a/dptb/data/dataset/_abacus_dataset.py b/dptb/data/dataset/_abacus_dataset.py new file mode 100644 index 00000000..bcdb1c62 --- /dev/null +++ b/dptb/data/dataset/_abacus_dataset.py @@ -0,0 +1,109 @@ +from typing import Dict, Any, List, Callable, Union, Optional +import os + +import numpy as np +import h5py + +import torch + +from .. 
import ( + AtomicData, + AtomicDataDict, +) + +from ..transforms import TypeMapper, OrbitalMapper +from ._base_datasets import AtomicDataset, AtomicInMemoryDataset +#from dptb.nn.hamiltonian import E3Hamiltonian +from dptb.data.interfaces.ham_to_feature import ham_block_to_feature + +orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"} + +def _abacus_h5_reader(h5file_path, AtomicData_options): + data = h5py.File(h5file_path, "r") + atomic_data = AtomicData.from_points( + pos = data["pos"][:], + cell = data["cell"][:], + atomic_numbers = data["atomic_numbers"][:], + **AtomicData_options, + ) + if "hamiltonian_blocks" in data: + basis = {} + for key, value in data["basis"].items(): + basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)] + idp = OrbitalMapper(basis) + # e3 = E3Hamiltonian(idp=idp, decompose=True) + ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False)) + # with torch.no_grad(): + # atomic_data = e3(atomic_data.to_dict()) + # atomic_data = AtomicData.from_dict(atomic_data) + + if "eigenvalues" in data and "kpionts" in data: + atomic_data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(data["kpoints"][:], dtype=torch.get_default_dtype()) + atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(data["eigenvalues"][:], dtype=torch.get_default_dtype()) + return atomic_data + +# Lazy loading class, built for large dataset. + +class ABACUSDataset(AtomicDataset): + + def __init__( + self, + root: str, + preprocess_dir: str, + AtomicData_options: Dict[str, Any] = {}, + type_mapper: Optional[TypeMapper] = None, + ): + super().__init__(root=root, type_mapper=type_mapper) + self.preprocess_dir = preprocess_dir + self.file_name = np.loadtxt(os.path.join(self.preprocess_dir, 'AtomicData_file.txt'), dtype=str) + self.AtomicData_options = AtomicData_options + self.num_examples = len(self.file_name) + + def get(self, idx): + name = self.file_name[idx] + h5_file = os.path.join(self.preprocess_dir, name) + atomic_data = _abacus_h5_reader(h5_file, self.AtomicData_options) + return atomic_data + + def len(self) -> int: + return self.num_examples + +# In memory version. + +class ABACUSInMemoryDataset(AtomicInMemoryDataset): + + def __init__( + self, + root: str, + preprocess_dir: str, + url: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + include_frames: Optional[List[int]] = None, + type_mapper: TypeMapper = None, + ): + self.preprocess_dir = preprocess_dir + self.file_name = np.loadtxt(os.path.join(self.preprocess_dir, 'AtomicData_file.txt'), dtype=str) + + super(ABACUSInMemoryDataset, self).__init__( + file_name=self.file_name, + url=url, + root=root, + AtomicData_options=AtomicData_options, + include_frames=include_frames, + type_mapper=type_mapper, + ) + + def get_data(self): + data = [] + for name in self.file_name: + h5_file = os.path.join(self.preprocess_dir, name) + data.append(_abacus_h5_reader(h5_file, self.AtomicData_options)) + return data + + @property + def raw_file_names(self): + return "AtomicData.h5" + + @property + def raw_dir(self): + return self.root \ No newline at end of file diff --git a/dptb/data/dataset/_abacus_dataset_mem.py b/dptb/data/dataset/_abacus_dataset_mem.py new file mode 100644 index 00000000..de4263ae --- /dev/null +++ b/dptb/data/dataset/_abacus_dataset_mem.py @@ -0,0 +1,109 @@ +from typing import Dict, Any, List, Callable, Union, Optional +import os + +import numpy as np +import h5py + +import torch + +from .. 
import ( + AtomicData, + AtomicDataDict, +) +from ..transforms import TypeMapper, OrbitalMapper +from ._base_datasets import AtomicInMemoryDataset +from dptb.nn.hamiltonian import E3Hamiltonian +from dptb.data.interfaces.ham_to_feature import ham_block_to_feature +from dptb.data.interfaces.abacus import recursive_parse + +orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"} + +def _abacus_h5_reader(h5file_path, AtomicData_options): + data = h5py.File(h5file_path, "r") + atomic_data = AtomicData.from_points( + pos = data["pos"][:], + cell = data["cell"][:], + atomic_numbers = data["atomic_numbers"][:], + **AtomicData_options, + ) + if data["hamiltonian_blocks"]: + basis = {} + for key, value in data["basis"].items(): + basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)] + idp = OrbitalMapper(basis) + # e3 = E3Hamiltonian(idp=idp, decompose=True) + ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False)) + # with torch.no_grad(): + # atomic_data = e3(atomic_data.to_dict()) + # atomic_data = AtomicData.from_dict(atomic_data) + + if data.get("eigenvalue") and data.get("kpoint"): + atomic_data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(data["kpoint"][:], dtype=torch.get_default_dtype()) + atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(data["eigenvalue"][:], dtype=torch.get_default_dtype()) + return atomic_data + + +class ABACUSInMemoryDataset(AtomicInMemoryDataset): + + def __init__( + self, + root: str, + abacus_args: Dict[str, Union[str,bool]] = { + "input_dir": None, + "preprocess_dir": None, + "only_overlap": False, + "get_Ham": False, + "add_overlap": False, + "get_eigenvalues": False, + }, + file_name: Optional[str] = None, + url: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + include_frames: Optional[List[int]] = None, + type_mapper: TypeMapper = None, + key_mapping: Dict[str, str] = { + "pos": AtomicDataDict.POSITIONS_KEY, + "energy": AtomicDataDict.TOTAL_ENERGY_KEY, + "atomic_numbers": AtomicDataDict.ATOMIC_NUMBERS_KEY, + "kpoints": AtomicDataDict.KPOINT_KEY, + "eigenvalues": AtomicDataDict.ENERGY_EIGENVALUE_KEY, + }, + ): + if file_name is not None: + self.file_name = file_name + else: + self.abacus_args = abacus_args + assert self.abacus_args.get("input_dir") is not None, "ABACUS calculation results MUST be provided." 
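+            # Illustrative construction sketch (all paths below are hypothetical):
+            #   ABACUSInMemoryDataset(
+            #       root="./",
+            #       abacus_args={"input_dir": "./abacus_calcs", "preprocess_dir": "./preprocess",
+            #                    "get_Ham": True, "get_eigenvalues": True},
+            #       AtomicData_options={"r_max": 5.0},
+            #   )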
+ if self.abacus_args.get("preprocess_dir") is None: + print("Creating new preprocess dictionary...") + os.mkdir(os.path.join(root, "preprocess")) + self.abacus_args["preprocess_dir"] = os.path.join(root, "preprocess") + self.key_mapping = key_mapping + + print("Begin parsing ABACUS output...") + h5_filenames = recursive_parse(**self.abacus_args) + self.file_name = h5_filenames + print("Finished parsing ABACUS output.") + + super(ABACUSInMemoryDataset, self).__init__( + file_name=self.file_name, + url=url, + root=root, + AtomicData_options=AtomicData_options, + include_frames=include_frames, + type_mapper=type_mapper, + ) + + def get_data(self): + data = [] + for h5_file in self.file_name: + data.append(_abacus_h5_reader(h5_file, self.AtomicData_options)) + return data + + @property + def raw_file_names(self): + return "AtomicData.h5" + + @property + def raw_dir(self): + return self.root \ No newline at end of file diff --git a/dptb/data/dataset/_ase_dataset.py b/dptb/data/dataset/_ase_dataset.py new file mode 100644 index 00000000..6200ebe5 --- /dev/null +++ b/dptb/data/dataset/_ase_dataset.py @@ -0,0 +1,238 @@ +import tempfile +import functools +import itertools +from os.path import dirname, basename, abspath +from typing import Dict, Any, List, Union, Optional, Sequence + +import ase +import ase.io + +import torch +import torch.multiprocessing as mp + + +from dptb.utils.multiprocessing import num_tasks +from .. import AtomicData +from ..transforms import TypeMapper +from ._base_datasets import AtomicInMemoryDataset + + +def _ase_dataset_reader( + rank: int, + world_size: int, + tmpdir: str, + ase_kwargs: dict, + atomicdata_kwargs: dict, + include_frames, + global_options: dict, +) -> Union[str, List[AtomicData]]: + """Parallel reader for all frames in file.""" + if world_size > 1: + from nequip.utils._global_options import _set_global_options + + # ^ avoid import loop + # we only `multiprocessing` if world_size > 1 + _set_global_options(global_options) + # interleave--- in theory it is better for performance for the ranks + # to read consecutive blocks, but the way ASE is written the whole + # file gets streamed through all ranks anyway, so just trust the OS + # to cache things sanely, which it will. + # ASE handles correctly the case where there are no frames in index + # and just gives an empty list, so that will succeed: + index = slice(rank, None, world_size) + if include_frames is None: + # count includes 0, 1, ..., inf + include_frames = itertools.count() + + datas = [] + # stream them from ase too using iread + for i, atoms in enumerate(ase.io.iread(**ase_kwargs, index=index, parallel=False)): + global_index = rank + (world_size * i) + datas.append( + ( + global_index, + AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs) + if global_index in include_frames + # in-memory dataset will ignore this later, but needed for indexing to work out + else None, + ) + ) + # Save to a tempfile--- + # there can be a _lot_ of tensors here, and rather than dealing with + # the complications of running out of file descriptors and setting + # sharing methods, since this is a one time thing, just make it simple + # and avoid shared memory entirely. 
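+    # In the parallel case each worker dumps its shard to `tmpdir` and returns the file path;
+    # the caller (ASEDataset.get_data) then loads and merges the shards in order.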
+ if world_size > 1: + path = f"{tmpdir}/rank{rank}.pth" + torch.save(datas, path) + return path + else: + return datas + + +class ASEDataset(AtomicInMemoryDataset): + """ + + Args: + ase_args (dict): arguments for ase.io.read + include_keys (list): in addition to forces and energy, the keys that needs to + be parsed into dataset + The data stored in ase.atoms.Atoms.array has the lowest priority, + and it will be overrided by data in ase.atoms.Atoms.info + and ase.atoms.Atoms.calc.results. Optional + key_mapping (dict): rename some of the keys to the value str. Optional + + Example: Given an atomic data stored in "H2.extxyz" that looks like below: + + ```H2.extxyz + 2 + Properties=species:S:1:pos:R:3 energy=-10 user_label=2.0 pbc="F F F" + H 0.00000000 0.00000000 0.00000000 + H 0.00000000 0.00000000 1.02000000 + ``` + + The yaml input should be + + ``` + dataset: ase + dataset_file_name: H2.extxyz + ase_args: + format: extxyz + include_keys: + - user_label + key_mapping: + user_label: label0 + chemical_symbols: + - H + ``` + + for VASP parser, the yaml input should be + ``` + dataset: ase + dataset_file_name: OUTCAR + ase_args: + format: vasp-out + key_mapping: + free_energy: total_energy + chemical_symbols: + - H + ``` + + """ + + def __init__( + self, + root: str, + ase_args: dict = {}, + file_name: Optional[str] = None, + url: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + include_frames: Optional[List[int]] = None, + type_mapper: TypeMapper = None, + key_mapping: Optional[dict] = None, + include_keys: Optional[List[str]] = None, + ): + self.ase_args = {} + self.ase_args.update(getattr(type(self), "ASE_ARGS", dict())) + self.ase_args.update(ase_args) + assert "index" not in self.ase_args + assert "filename" not in self.ase_args + + self.include_keys = include_keys + self.key_mapping = key_mapping + + super().__init__( + file_name=file_name, + url=url, + root=root, + AtomicData_options=AtomicData_options, + include_frames=include_frames, + type_mapper=type_mapper, + ) + + @classmethod + def from_atoms_list(cls, atoms: Sequence[ase.Atoms], **kwargs): + """Make an ``ASEDataset`` from a list of ``ase.Atoms`` objects. + + If `root` is not provided, a temporary directory will be used. + + Please note that this is a convinience method that does NOT avoid a round-trip to disk; the provided ``atoms`` will be written out to a file. + + Ignores ``kwargs["file_name"]`` if it is provided. + + Args: + atoms + **kwargs: passed through to the constructor + Returns: + The constructed ``ASEDataset``. + """ + if "root" not in kwargs: + tmpdir = tempfile.TemporaryDirectory() + kwargs["root"] = tmpdir.name + else: + tmpdir = None + kwargs["file_name"] = tmpdir.name + "/atoms.xyz" + atoms = list(atoms) + # Write them out + ase.io.write(kwargs["file_name"], atoms, format="extxyz") + # Read them in + obj = cls(**kwargs) + if tmpdir is not None: + # Make it keep a reference to the tmpdir to keep it alive + # When the dataset is garbage collected, the tmpdir will + # be too, and will (hopefully) get deleted eventually. + # Or at least by end of program... 
+ obj._tmpdir_ref = tmpdir + return obj + + @property + def raw_file_names(self): + return [basename(self.file_name)] + + @property + def raw_dir(self): + return dirname(abspath(self.file_name)) + + def get_data(self): + ase_args = {"filename": self.raw_dir + "/" + self.raw_file_names[0]} + ase_args.update(self.ase_args) + + # skip the None arguments + kwargs = dict( + include_keys=self.include_keys, + key_mapping=self.key_mapping, + ) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + kwargs.update(self.AtomicData_options) + n_proc = num_tasks() + with tempfile.TemporaryDirectory() as tmpdir: + from nequip.utils._global_options import _get_latest_global_options + + # ^ avoid import loop + reader = functools.partial( + _ase_dataset_reader, + world_size=n_proc, + tmpdir=tmpdir, + ase_kwargs=ase_args, + atomicdata_kwargs=kwargs, + include_frames=self.include_frames, + # get the global options of the parent to initialize the worker correctly + global_options=_get_latest_global_options(), + ) + if n_proc > 1: + # things hang for some obscure OpenMP reason on some systems when using `fork` method + ctx = mp.get_context("forkserver") + with ctx.Pool(processes=n_proc) as p: + # map it over the `rank` argument + datas = p.map(reader, list(range(n_proc))) + # clean up the pool before loading the data + datas = [torch.load(d) for d in datas] + datas = sum(datas, []) + # un-interleave the datas + datas = sorted(datas, key=lambda e: e[0]) + else: + datas = reader(rank=0) + # datas here is already in order, stride 1 start 0 + # no need to un-interleave + # return list of AtomicData: + return [e[1] for e in datas] diff --git a/dptb/data/dataset/_base_datasets.py b/dptb/data/dataset/_base_datasets.py new file mode 100644 index 00000000..8c43c1c0 --- /dev/null +++ b/dptb/data/dataset/_base_datasets.py @@ -0,0 +1,665 @@ +import numpy as np +import logging +import inspect +import itertools +import yaml +import hashlib +import math +from typing import Tuple, Dict, Any, List, Callable, Union, Optional + +import torch + +from torch_runstats.scatter import scatter_std, scatter_mean + +from dptb.utils.torch_geometric import Batch, Dataset +from dptb.utils.tools import download_url, extract_zip + +import dptb +from dptb.data import ( + AtomicData, + AtomicDataDict, + _NODE_FIELDS, + _EDGE_FIELDS, + _GRAPH_FIELDS, +) +from dptb.utils.batch_ops import bincount +from dptb.utils.regressor import solver +from dptb.utils.savenload import atomic_write +from ..transforms import TypeMapper + + +class AtomicDataset(Dataset): + """The base class for all NequIP datasets.""" + + root: str + dtype: torch.dtype + + def __init__( + self, + root: str, + type_mapper: Optional[TypeMapper] = None, + ): + self.dtype = torch.get_default_dtype() + super().__init__(root=root, transform=type_mapper) + + def statistics( + self, + fields: List[Union[str, Callable]], + modes: List[str], + stride: int = 1, + unbiased: bool = True, + kwargs: Optional[Dict[str, dict]] = {}, + ) -> List[tuple]: + # TODO: If needed, this can eventually be implimented for general AtomicDataset by computing an online running mean and using Welford's method for a stable running standard deviation: https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/ + # That would be needed if we have lazy loading datasets. + # TODO: When lazy-loading datasets are implimented, how to deal with statistics, sampling, and subsets? 
+ raise NotImplementedError("not implimented for general AtomicDataset yet") + + @property + def type_mapper(self) -> Optional[TypeMapper]: + # self.transform is always a TypeMapper + return self.transform + + def _get_parameters(self) -> Dict[str, Any]: + """Get a dict of the parameters used to build this dataset.""" + pnames = list(inspect.signature(self.__init__).parameters) + IGNORE_KEYS = { + # the type mapper is applied after saving, not before, so doesn't matter to cache validity + "type_mapper" + } + params = { + k: getattr(self, k) + for k in pnames + if k not in IGNORE_KEYS and hasattr(self, k) + } + # Add other relevant metadata: + params["dtype"] = str(self.dtype) + params["nequip_version"] = dptb.__version__ + + return params + + @property + def processed_dir(self) -> str: + # We want the file name to change when the parameters change + # So, first we get all parameters: + params = self._get_parameters() + # Make some kind of string of them: + # we don't care about this possibly changing between python versions, + # since a change in python version almost certainly means a change in + # versions of other things too, and is a good reason to recompute + buffer = yaml.dump(params).encode("ascii") + # And hash it: + param_hash = hashlib.sha1(buffer).hexdigest() + return f"{self.root}/processed_dataset_{param_hash}" + + +class AtomicInMemoryDataset(AtomicDataset): + r"""Base class for all datasets that fit in memory. + + Please note that, as a ``pytorch_geometric`` dataset, it must be backed by some kind of disk storage. + By default, the raw file will be stored at root/raw and the processed torch + file will be at root/process. + + Subclasses must implement: + - ``raw_file_names`` + - ``get_data()`` + + Subclasses may implement: + - ``download()`` or ``self.url`` or ``ClassName.URL`` + + Args: + root (str, optional): Root directory where the dataset should be saved. Defaults to current working directory. + file_name (str, optional): file name of data source. only used in children class + url (str, optional): url to download data source + AtomicData_options (dict, optional): extra key that are not stored in data but needed for AtomicData initialization + include_frames (list, optional): the frames to process with the constructor. + type_mapper (TypeMapper): the transformation to map atomic information to species index. Optional + """ + + def __init__( + self, + root: str, + file_name: Optional[str] = None, + url: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + include_frames: Optional[List[int]] = None, + type_mapper: Optional[TypeMapper] = None, + ): + # TO DO, this may be simplified + # See if a subclass defines some inputs + self.file_name = ( + getattr(type(self), "FILE_NAME", None) if file_name is None else file_name + ) + self.url = getattr(type(self), "URL", url) + + self.AtomicData_options = AtomicData_options + self.include_frames = include_frames + + self.data = None + + # !!! don't delete this block. 
+ # otherwise the inherent children class + # will ignore the download function here + class_type = type(self) + if class_type != AtomicInMemoryDataset: + if "download" not in self.__class__.__dict__: + class_type.download = AtomicInMemoryDataset.download + if "process" not in self.__class__.__dict__: + class_type.process = AtomicInMemoryDataset.process + + # Initialize the InMemoryDataset, which runs download and process + # See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets + # Then pre-process the data if disk files are not found + super().__init__(root=root, type_mapper=type_mapper) + if self.data is None: + self.data, include_frames = torch.load(self.processed_paths[0]) + if not np.all(include_frames == self.include_frames): + raise ValueError( + f"the include_frames is changed. " + f"please delete the processed folder and rerun {self.processed_paths[0]}" + ) + + def len(self): + if self.data is None: + return 0 + return self.data.num_graphs + + @property + def raw_file_names(self): + raise NotImplementedError() + + @property + def processed_file_names(self) -> List[str]: + return ["data.pth", "params.yaml"] + + def get_data( + self, + ) -> Union[Tuple[Dict[str, Any], Dict[str, Any]], List[AtomicData]]: + """Get the data --- called from ``process()``, can assume that ``raw_file_names()`` exist. + + Note that parameters for graph construction such as ``pbc`` and ``r_max`` should be included here as (likely, but not necessarily, fixed) fields. + + Returns: + A dict: + fields: dict + mapping a field name ('pos', 'cell') to a list-like sequence of tensor-like objects giving that field's value for each example. + Or: + data_list: List[AtomicData] + """ + raise NotImplementedError + + def download(self): + if (not hasattr(self, "url")) or (self.url is None): + # Don't download, assume present. Later could have FileNotFound if the files don't actually exist + pass + else: + download_path = download_url(self.url, self.raw_dir) + if download_path.endswith(".zip"): + extract_zip(download_path, self.raw_dir) + + def process(self): + data = self.get_data() ## get data returns either a list of AtomicData class or a data dict + if isinstance(data, list): + + # It's a data list + data_list = data + if not (self.include_frames is None or data_list is None): + data_list = [data_list[i] for i in self.include_frames] # 可以选择数据集中加载的序号 + assert all(isinstance(e, AtomicData) for e in data_list) + assert all(AtomicDataDict.BATCH_KEY not in e for e in data_list) + + fields = {} + + elif isinstance(data, dict): + # It's fields + # Get our data + fields = data + + # check keys + all_keys = set(fields.keys()) + assert AtomicDataDict.BATCH_KEY not in all_keys + # Check bad key combinations, but don't require that this be a graph yet. 
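+            # For orientation (keys shown are illustrative): `fields` is expected to look like
+            #   {"pos": [...], "cell": [...], "atomic_numbers": [...]}
+            # with one entry per example in each sequence, as described in `get_data()`.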
+ AtomicDataDict.validate_keys(all_keys, graph_required=False) + + # check dimesionality + num_examples = set([len(a) for a in fields.values()]) + if not len(num_examples) == 1: + raise ValueError( + f"This dataset is invalid: expected all fields to have same length (same number of examples), but they had shapes { {f: v.shape for f, v in fields.items() } }" + ) + num_examples = next(iter(num_examples)) + + include_frames = self.include_frames + if include_frames is None: + include_frames = range(num_examples) + + # Make AtomicData from it: + if AtomicDataDict.EDGE_INDEX_KEY in all_keys: + # This is already a graph, just build it + constructor = AtomicData + else: + # do neighborlist from points + constructor = AtomicData.from_points + assert "r_max" in self.AtomicData_options + assert AtomicDataDict.POSITIONS_KEY in all_keys + + data_list = [ + constructor( + **{ + **{f: v[i] for f, v in fields.items()}, + **self.AtomicData_options, + } + ) + for i in include_frames + ] + + + else: + raise ValueError("Invalid return from `self.get_data()`") + + # Batch it for efficient saving + # This limits an AtomicInMemoryDataset to a maximum of LONG_MAX atoms _overall_, but that is a very big number and any dataset that large is probably not "InMemory" anyway + data = Batch.from_data_list(data_list) + del data_list + del fields + + total_MBs = sum(item.numel() * item.element_size() for _, item in data) / ( + 1024 * 1024 + ) + logging.info( + f"Loaded data: {data}\n processed data size: ~{total_MBs:.2f} MB" + ) + del total_MBs + + # use atomic writes to avoid race conditions between + # different trainings that use the same dataset + # since those separate trainings should all produce the same results, + # it doesn't matter if they overwrite each others cached' + # datasets. It only matters that they don't simultaneously try + # to write the _same_ file, corrupting it. + with atomic_write(self.processed_paths[0], binary=True) as f: + torch.save((data, self.include_frames), f) + with atomic_write(self.processed_paths[1], binary=False) as f: + yaml.dump(self._get_parameters(), f) + + logging.info("Cached processed data to disk") + + self.data = data + + def get(self, idx): + return self.data.get_example(idx) + + def _selectors( + self, + stride: int = 1, + ): + if self._indices is not None: + graph_selector = torch.as_tensor(self._indices)[::stride] + # note that self._indices is _not_ necessarily in order, + # while self.data --- which we take our arrays from --- + # is always in the original order. + # In particular, the values of `self.data.batch` + # are indexes in the ORIGINAL order + # thus we need graph level properties to also be in the original order + # so that batch values index into them correctly + # since self.data.batch is always sorted & contiguous + # (because of Batch.from_data_list) + # we sort it: + graph_selector, _ = torch.sort(graph_selector) + else: + graph_selector = torch.arange(0, self.len(), stride) + + node_selector = torch.as_tensor( + np.in1d(self.data.batch.numpy(), graph_selector.numpy()) + ) + + edge_index = self.data[AtomicDataDict.EDGE_INDEX_KEY] + edge_selector = node_selector[edge_index[0]] & node_selector[edge_index[1]] + + return (graph_selector, node_selector, edge_selector) + + def statistics( + self, + fields: List[Union[str, Callable]], + modes: List[str], + stride: int = 1, + unbiased: bool = True, + kwargs: Optional[Dict[str, dict]] = {}, + ) -> List[tuple]: + """Compute the statistics of ``fields`` in the dataset. 
+ + If the values at the fields are vectors/multidimensional, they must be of fixed shape and elementwise statistics will be computed. + + Args: + fields: the names of the fields to compute statistics for. + Instead of a field name, a callable can also be given that reuturns a quantity to compute the statisics for. + + If a callable is given, it will be called with a (possibly batched) ``Data``-like object and must return a sequence of points to add to the set over which the statistics will be computed. + The callable must also return a string, one of ``"node"`` or ``"graph"``, indicating whether the value it returns is a per-node or per-graph quantity. + PLEASE NOTE: the argument to the callable may be "batched", and it may not be batched "contiguously": ``batch`` and ``edge_index`` may have "gaps" in their values. + + For example, to compute the overall statistics of the x,y, and z components of a per-node vector ``force`` field: + + data.statistics([lambda data: (data.force.flatten(), "node")]) + + The above computes the statistics over a set of size 3N, where N is the total number of nodes in the dataset. + + modes: the statistic to compute for each field. Valid options are: + - ``count`` + - ``rms`` + - ``mean_std`` + - ``per_atom_*`` + - ``per_species_*`` + + stride: the stride over the dataset while computing statistcs. + + unbiased: whether to use unbiased for standard deviations. + + kwargs: other options for individual statistics modes. + + Returns: + List of statistics. For fields of floating dtype the statistics are the two-tuple (mean, std); for fields of integer dtype the statistics are a one-tuple (bincounts,) + """ + + # Short circut: + assert len(modes) == len(fields) + if len(fields) == 0: + return [] + + graph_selector, node_selector, edge_selector = self._selectors(stride=stride) + + num_graphs = len(graph_selector) + num_nodes = node_selector.sum() + num_edges = edge_selector.sum() + + if self.transform is not None: + # pre-transform the data so that statistics process transformed data + data_transformed = self.transform(self.data.to_dict(), types_required=False) + else: + data_transformed = self.data.to_dict() + # pre-select arrays + # this ensures that all following computations use the right data + all_keys = set() + selectors = {} + for k in data_transformed.keys(): + all_keys.add(k) + if k in _NODE_FIELDS: + selectors[k] = node_selector + elif k in _GRAPH_FIELDS: + selectors[k] = graph_selector + elif k == AtomicDataDict.EDGE_INDEX_KEY: + selectors[k] = (slice(None, None, None), edge_selector) + elif k in _EDGE_FIELDS: + selectors[k] = edge_selector + # TODO: do the batch indexes, edge_indexes, etc. after selection need to be + # "compacted" to subtract out their offsets? For now, we just punt this + # onto the writer of the callable field. + # apply selector to actual data + data_transformed = { + k: data_transformed[k][selectors[k]] + for k in data_transformed.keys() + if k in selectors + } + + atom_types: Optional[torch.Tensor] = None + out: list = [] + for ifield, field in enumerate(fields): + if callable(field): + # make a joined thing? 
so it includes fixed fields + arr, arr_is_per = field(data_transformed) + arr = arr.to(self.dtype) # all statistics must be on floating + assert arr_is_per in ("node", "graph", "edge") + else: + if field not in all_keys: + raise RuntimeError( + f"The field key `{field}` is not present in this dataset" + ) + if field not in selectors: + # this means field is not selected and so not available + raise RuntimeError( + f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. If it is per-node or per-graph, please register it as such using `nequip.data.register_fields`" + ) + arr = data_transformed[field] + if field in _NODE_FIELDS: + arr_is_per = "node" + elif field in _GRAPH_FIELDS: + arr_is_per = "graph" + elif field in _EDGE_FIELDS: + arr_is_per = "edge" + else: + raise RuntimeError + + # Check arr + if arr is None: + raise ValueError( + f"Cannot compute statistics over field `{field}` whose value is None!" + ) + if not isinstance(arr, torch.Tensor): + if np.issubdtype(arr.dtype, np.floating): + arr = torch.as_tensor(arr, dtype=self.dtype) + else: + arr = torch.as_tensor(arr) + if arr_is_per == "node": + arr = arr.view(num_nodes, -1) + elif arr_is_per == "graph": + arr = arr.view(num_graphs, -1) + elif arr_is_per == "edge": + arr = arr.view(num_edges, -1) + + ana_mode = modes[ifield] + # compute statistics + if ana_mode == "count": + # count integers + uniq, counts = torch.unique( + torch.flatten(arr), return_counts=True, sorted=True + ) + out.append((uniq, counts)) + elif ana_mode == "rms": + # root-mean-square + out.append((torch.sqrt(torch.mean(arr * arr)),)) + + elif ana_mode == "mean_std": + # mean and std + if len(arr) < 2: + raise ValueError( + "Can't do per species standard deviation without at least two samples" + ) + mean = torch.mean(arr, dim=0) + std = torch.std(arr, dim=0, unbiased=unbiased) + out.append((mean, std)) + + elif ana_mode == "absmax": + out.append((arr.abs().max(),)) + + elif ana_mode.startswith("per_species_"): + # per-species + algorithm_kwargs = kwargs.pop(field + ana_mode, {}) + + ana_mode = ana_mode[len("per_species_") :] + + if atom_types is None: + atom_types = data_transformed[AtomicDataDict.ATOM_TYPE_KEY] + + results = self._per_species_statistics( + ana_mode, + arr, + arr_is_per=arr_is_per, + batch=data_transformed[AtomicDataDict.BATCH_KEY], + atom_types=atom_types, + unbiased=unbiased, + algorithm_kwargs=algorithm_kwargs, + ) + out.append(results) + + elif ana_mode.startswith("per_atom_"): + # per-atom + # only makes sense for a per-graph quantity + if arr_is_per != "graph": + raise ValueError( + f"It doesn't make sense to ask for `{ana_mode}` since `{field}` is not per-graph" + ) + ana_mode = ana_mode[len("per_atom_") :] + results = self._per_atom_statistics( + ana_mode=ana_mode, + arr=arr, + batch=data_transformed[AtomicDataDict.BATCH_KEY], + unbiased=unbiased, + ) + out.append(results) + + else: + raise NotImplementedError(f"Cannot handle statistics mode {ana_mode}") + + return out + + @staticmethod + def _per_atom_statistics( + ana_mode: str, + arr: torch.Tensor, + batch: torch.Tensor, + unbiased: bool = True, + ): + """Compute "per-atom" statistics that are normalized by the number of atoms in the system. + + Only makes sense for a graph-level quantity (checked by .statistics). 
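+
+        For example (illustrative numbers only): a per-graph total energy of -20.0 for a
+        10-atom frame enters these statistics as -2.0 per atom.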
+ """ + # using unique_consecutive handles the non-contiguous selected batch index + _, N = torch.unique_consecutive(batch, return_counts=True) + N = N.unsqueeze(-1) + assert N.ndim == 2 + assert N.shape == (len(arr), 1) + assert arr.ndim >= 2 + data_dim = arr.shape[1:] + arr = arr / N + assert arr.shape == (len(N),) + data_dim + if ana_mode == "mean_std": + if len(arr) < 2: + raise ValueError( + "Can't do standard deviation without at least two samples" + ) + mean = torch.mean(arr, dim=0) + std = torch.std(arr, unbiased=unbiased, dim=0) + return mean, std + elif ana_mode == "rms": + return (torch.sqrt(torch.mean(arr.square())),) + elif ana_mode == "absmax": + return (torch.max(arr.abs()),) + else: + raise NotImplementedError( + f"{ana_mode} for per-atom analysis is not implemented" + ) + + def _per_species_statistics( + self, + ana_mode: str, + arr: torch.Tensor, + arr_is_per: str, + atom_types: torch.Tensor, + batch: torch.Tensor, + unbiased: bool = True, + algorithm_kwargs: Optional[dict] = {}, + ): + """Compute "per-species" statistics. + + For a graph-level quantity, models it as a linear combintation of the number of atoms of different types in the graph. + + For a per-node quantity, computes the expected statistic but for each type instead of over all nodes. + """ + N = bincount(atom_types.squeeze(-1), batch) + assert N.ndim == 2 # [batch, n_type] + N = N[(N > 0).any(dim=1)] # deal with non-contiguous batch indexes + assert arr.ndim >= 2 + if arr_is_per == "graph": + + if ana_mode != "mean_std": + raise NotImplementedError( + f"{ana_mode} for per species analysis is not implemented for shape {arr.shape}" + ) + + N = N.type(self.dtype) + + return solver(N, arr, **algorithm_kwargs) + + elif arr_is_per == "node": + arr = arr.type(self.dtype) + + if ana_mode == "mean_std": + # There need to be at least two occurances of each atom type in the + # WHOLE dataset, not in any given frame: + if torch.any(N.sum(dim=0) < 2): + raise ValueError( + "Can't do per species standard deviation without at least two samples per species" + ) + mean = scatter_mean(arr, atom_types, dim=0) + assert mean.shape[1:] == arr.shape[1:] # [N, dims] -> [type, dims] + assert len(mean) == N.shape[1] + std = scatter_std(arr, atom_types, dim=0, unbiased=unbiased) + assert std.shape == mean.shape + return mean, std + elif ana_mode == "rms": + square = scatter_mean(arr.square(), atom_types, dim=0) + assert square.shape[1:] == arr.shape[1:] # [N, dims] -> [type, dims] + assert len(square) == N.shape[1] + dims = len(square.shape) - 1 + for i in range(dims): + square = square.mean(axis=-1) + return (torch.sqrt(square),) + else: + raise NotImplementedError( + f"Statistics mode {ana_mode} isn't yet implemented for per_species_" + ) + + else: + raise NotImplementedError + + def rdf( + self, bin_width: float, stride: int = 1 + ) -> Dict[Tuple[int, int], Tuple[np.ndarray, np.ndarray]]: + """Compute the pairwise RDFs of the dataset. + + Args: + bin_width: width of the histogram bin in distance units + stride: stride of data to include + + Returns: + dictionary mapping `(type1, type2)` to tuples of `(hist, bin_edges)` in the style of `np.histogram`. 
+ """ + graph_selector, node_selector, edge_selector = self._selectors(stride=stride) + + data = AtomicData.to_AtomicDataDict(self.data) + data = AtomicDataDict.with_edge_vectors(data, with_lengths=True) + + results = {} + + types = self.type_mapper(data)[AtomicDataDict.ATOM_TYPE_KEY] + + edge_types = torch.index_select( + types, 0, data[AtomicDataDict.EDGE_INDEX_KEY].reshape(-1) + ).view(2, -1) + types_center = edge_types[0].numpy() + types_neigh = edge_types[1].numpy() + + r_max: float = self.AtomicData_options["r_max"] + # + 1 to always have a zero bin at the end + n_bins: int = int(math.ceil(r_max / bin_width)) + 1 + # +1 since these are bin_edges including rightmost + bins = bin_width * np.arange(n_bins + 1) + + for type1, type2 in itertools.combinations_with_replacement( + range(self.type_mapper.num_types), 2 + ): + # Try to do as much of this as possible in-place + mask = types_center == type1 + np.logical_and(mask, types_neigh == type2, out=mask) + np.logical_and(mask, edge_selector, out=mask) + mask = mask.astype(np.int32) + results[(type1, type2)] = np.histogram( + data[AtomicDataDict.EDGE_LENGTH_KEY], + weights=mask, + bins=bins, + density=True, + ) + # RDF is symmetric + results[(type2, type1)] = results[(type1, type2)] + + return results diff --git a/dptb/data/dataset/_deeph_dataset.py b/dptb/data/dataset/_deeph_dataset.py new file mode 100644 index 00000000..2485ce1b --- /dev/null +++ b/dptb/data/dataset/_deeph_dataset.py @@ -0,0 +1,79 @@ +from typing import Dict, Any, List, Callable, Union, Optional +import os +import numpy as np +import h5py + +import torch + +from .. import ( + AtomicData, + AtomicDataDict, +) +from ..transforms import TypeMapper, OrbitalMapper +from ._base_datasets import AtomicDataset +from dptb.nn.hamiltonian import E3Hamiltonian +from dptb.data.interfaces.ham_to_feature import openmx_to_deeptb + +orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"} + +class DeePHE3Dataset(AtomicDataset): + + def __init__( + self, + root: str, + key_mapping: Dict[str, str] = { + "pos": AtomicDataDict.POSITIONS_KEY, + "energy": AtomicDataDict.TOTAL_ENERGY_KEY, + "atomic_numbers": AtomicDataDict.ATOMIC_NUMBERS_KEY, + "kpoints": AtomicDataDict.KPOINT_KEY, + "eigenvalues": AtomicDataDict.ENERGY_EIGENVALUE_KEY, + }, + preprocess_path: str = None, + subdir_names: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + type_mapper: Optional[TypeMapper] = None, + ): + super().__init__(root=root, type_mapper=type_mapper) + self.key_mapping = key_mapping + self.key_list = list(key_mapping.keys()) + self.value_list = list(key_mapping.values()) + self.subdir_names = subdir_names + self.preprocess_path = preprocess_path + + self.AtomicData_options = AtomicData_options + # self.r_max = AtomicData_options["r_max"] + # self.er_max = AtomicData_options["er_max"] + # self.oer_max = AtomicData_options["oer_max"] + # self.pbc = AtomicData_options["pbc"] + + self.index = None + self.num_examples = len(subdir_names) + + def get(self, idx): + file_name = self.subdir_names[idx] + file = os.path.join(self.preprocess_path, file_name) + + if os.path.exists(os.path.join(file, "AtomicData.pth")): + atomic_data = torch.load(os.path.join(file, "AtomicData.pth")) + else: + atomic_data = AtomicData.from_points( + pos = np.loadtxt(os.path.join(file, "site_positions.dat")).T, + cell = np.loadtxt(os.path.join(file, "lat.dat")).T, + atomic_numbers = np.loadtxt(os.path.join(file, "element.dat")), + **self.AtomicData_options, + ) + + idp = self.type_mapper + # e3 = E3Hamiltonian(idp=idp, decompose=True) + + 
openmx_to_deeptb(atomic_data, idp, os.path.join(file, "./hamiltonians.h5")) + # with torch.no_grad(): + # atomic_data = e3(atomic_data.to_dict()) + # atomic_data = AtomicData.from_dict(atomic_data) + + torch.save(atomic_data, os.path.join(file, "AtomicData.pth")) + + return atomic_data + + def len(self) -> int: + return self.num_examples \ No newline at end of file diff --git a/dptb/data/dataset/_default_dataset.py b/dptb/data/dataset/_default_dataset.py new file mode 100644 index 00000000..4a0d7f30 --- /dev/null +++ b/dptb/data/dataset/_default_dataset.py @@ -0,0 +1,425 @@ +from typing import Dict, Any, List, Callable, Union, Optional +import os +import glob + +import numpy as np +import h5py +from ase import Atoms +from ase.io import Trajectory + +import torch + +from .. import ( + AtomicData, + AtomicDataDict, +) +from ..transforms import TypeMapper, OrbitalMapper +from ._base_datasets import AtomicDataset, AtomicInMemoryDataset +#from dptb.nn.hamiltonian import E3Hamiltonian +from dptb.data.interfaces.ham_to_feature import ham_block_to_feature +from dptb.utils.tools import j_loader +from dptb.data.AtomicDataDict import with_edge_vectors +from dptb.nn.hamiltonian import E3Hamiltonian + +class _TrajData(object): + ''' + Input files format in a trajectory (shape): + "info.json": optional, includes infomation in the data files. + can be provided in the base (upper level) folder, or assign in each trajectory. + "cell.dat": fixed cell (3, 3) or variable cells (nframes, 3, 3). Unit: Angstrom + "atomic_numbers.dat": (natoms) or (nframes, natoms) + "positions.dat": concentrate all positions in one file, (nframes * natoms, 3). Can be cart or frac. + + Optional data files: + "eigenvalues.npy": concentrate all engenvalues in one file, (nframes, nkpoints, nbands) + "kpoints.npy": MUST be provided when loading `eigenvalues.npy`, (nkpoints, 3) or (nframes, nkpints, 3) + "hamiltonians.h5": h5 file storing atom-wise hamiltonian blocks labeled by frames id and `i0_jR_Rx_Ry_Rz`. + "overlaps.h5": the same format of overlap blocks as `hamiltonians.h5` + ''' + + def __init__(self, + root: str, + AtomicData_options: Dict[str, Any] = {}, + get_Hamiltonian = False, + get_eigenvalues = False, + info = None, + _clear = False): + self.root = root + self.AtomicData_options = AtomicData_options + self.info = info + + self.data = {} + # load cell + cell = np.loadtxt(os.path.join(root, "cell.dat")) + if cell.shape[0] == 3: + # same cell size, then copy it to all frames. + cell = np.expand_dims(cell, axis=0) + self.data["cell"] = np.broadcast_to(cell, (self.info["nframes"], 3, 3)) + elif cell.shape[0] == self.info["nframes"] * 3: + self.data["cell"] = cell.reshape(self.info["nframes"], 3, 3) + else: + raise ValueError("Wrong cell dimensions.") + + # load atomic numbers + atomic_numbers = np.loadtxt(os.path.join(root, "atomic_numbers.dat")) + if atomic_numbers.shape[0] == self.info["natoms"]: + # same atomic_numbers, copy it to all frames. + atomic_numbers = np.expand_dims(atomic_numbers, axis=0) + self.data["atomic_numbers"] = np.broadcast_to(atomic_numbers, (self.info["nframes"], + self.info["natoms"])) + elif atomic_numbers.shape[0] == self.info["natoms"] * self.info["nframes"]: + self.data["atomic_numbers"] = atomic_numbers.reshape(self.info["nframes"], + self.info["natoms"]) + else: + raise ValueError("Wrong atomic_number dimensions.") + + # load positions, stored as cartesion no matter what provided. 
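+        # e.g. with nframes=2 and natoms=4 (illustrative numbers), positions.dat holds
+        # 8 rows of "x y z", which are reshaped to (2, 4, 3) below.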
+ pos = np.loadtxt(os.path.join(root, "positions.dat")) + assert pos.shape[0] == self.info["nframes"] * self.info["natoms"] + pos = pos.reshape(self.info["nframes"], self.info["natoms"], 3) + # ase use cartesian by default. + if self.info["pos_type"] == "cart" or self.info["pos_type"] == "ase": + self.data["pos"] = pos + elif self.info["pos_type"] == "frac": + self.data["pos"] = pos @ self.data["cell"] + else: + raise NameError("Position type must be cart / frac.") + + # load optional data files + if os.path.exists(os.path.join(self.root, "eigenvalues.npy")) and get_eigenvalues==True: + assert "bandinfo" in self.info, "`bandinfo` must be provided in `info.json` for loading eigenvalues." + assert os.path.exists(os.path.join(self.root, "kpoints.npy")) + kpoints = np.load(os.path.join(self.root, "kpoints.npy")) + if kpoints.ndim == 2: + # only one frame or same kpoints, then copy it to all frames. + # shape: (nkpoints, 3) + kpoints = np.expand_dims(kpoints, axis=0) + self.data["kpoints"] = np.broadcast_to(kpoints, (self.info["nframes"], + kpoints.shape[1], 3)) + elif kpoints.ndim == 3 and kpoints.shape[0] == self.info["nframes"]: + # array of kpoints, (nframes, nkpoints, 3) + self.data["kpoints"] = kpoints + else: + raise ValueError("Wrong kpoint dimensions.") + eigenvalues = np.load(os.path.join(self.root, "eigenvalues.npy")) + # special case: trajectory contains only one frame + if eigenvalues.ndim == 2: + eigenvalues = np.expand_dims(eigenvalues, axis=0) + assert eigenvalues.shape[0] == self.info["nframes"] + assert eigenvalues.shape[1] == self.data["kpoints"].shape[1] + self.data["eigenvalues"] = eigenvalues + if os.path.exists(os.path.join(self.root, "hamiltonians.h5")) and get_Hamiltonian==True: + self.data["hamiltonian_blocks"] = h5py.File(os.path.join(self.root, "hamiltonians.h5"), "r") + if os.path.exists(os.path.join(self.root, "overlaps.h5")): + self.data["overlap_blocks"] = h5py.File(os.path.join(self.root, "overlaps.h5"), "r") + + # this is used to clear the tmp files to load ase trajectory only. 
+ if _clear: + os.remove(os.path.join(root, "positions.dat")) + os.remove(os.path.join(root, "cell.dat")) + os.remove(os.path.join(root, "atomic_numbers.dat")) + + @classmethod + def from_ase_traj(cls, + root: str, + AtomicData_options: Dict[str, Any] = {}, + get_Hamiltonian = False, + get_eigenvalues = False, + info = None): + + traj_file = glob.glob(f"{root}/*.traj") + assert len(traj_file) == 1, print("only one ase trajectory file can be provided.") + traj = Trajectory(traj_file[0], 'r') + positions = [] + cell = [] + atomic_numbers = [] + for atoms in traj: + positions.append(atoms.get_positions()) + cell.append(atoms.get_cell()) + atomic_numbers.append(atoms.get_atomic_numbers()) + positions = np.array(positions) + positions = positions.reshape(-1, 3) + cell = np.array(cell) + cell = cell.reshape(-1, 3) + atomic_numbers = np.array(atomic_numbers) + atomic_numbers = atomic_numbers.reshape(-1, 1) + np.savetxt(os.path.join(root, "positions.dat"), positions) + np.savetxt(os.path.join(root, "cell.dat"), cell) + np.savetxt(os.path.join(root, "atomic_numbers.dat"), atomic_numbers, fmt='%d') + + return cls(root=root, + AtomicData_options=AtomicData_options, + get_Hamiltonian=get_Hamiltonian, + get_eigenvalues=get_eigenvalues, + info=info, + _clear=True) + + def toAtomicDataList(self, idp: TypeMapper = None): + data_list = [] + for frame in range(self.info["nframes"]): + atomic_data = AtomicData.from_points( + pos = self.data["pos"][frame][:], + cell = self.data["cell"][frame][:], + atomic_numbers = self.data["atomic_numbers"][frame], + # pbc is stored in AtomicData_options now. + #pbc = self.info["pbc"], + **self.AtomicData_options) + if "hamiltonian_blocks" in self.data: + assert idp is not None, "LCAO Basis must be provided in `common_option` for loading Hamiltonian." + if "overlap_blocks" not in self.data: + self.data["overlap_blocks"] = [False] * self.info["nframes"] + # e3 = E3Hamiltonian(idp=idp, decompose=True) + ham_block_to_feature(atomic_data, idp, + self.data["hamiltonian_blocks"][str(frame+1)], + self.data["overlap_blocks"][str(frame+1)]) + + # TODO: initialize the edge and node feature tempretely, there should be a better way. + else: + # just temporarily initialize the edge and node feature to zeros, to let the batch collate work. + atomic_data[AtomicDataDict.EDGE_FEATURES_KEY] = torch.zeros(atomic_data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], 1) + atomic_data[AtomicDataDict.NODE_FEATURES_KEY] = torch.zeros(atomic_data[AtomicDataDict.POSITIONS_KEY].shape[0], 1) + atomic_data[AtomicDataDict.EDGE_OVERLAP_KEY] = torch.zeros(atomic_data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], 1) + # with torch.no_grad(): + # atomic_data = e3(atomic_data.to_dict()) + # atomic_data = AtomicData.from_dict(atomic_data) + if "eigenvalues" in self.data and "kpoints" in self.data: + assert "bandinfo" in self.info, "`bandinfo` must be provided in `info.json` for loading eigenvalues." 
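+                # In `info.json` the entry might look like (values are hypothetical):
+                #   "bandinfo": {"band_min": 0, "band_max": 6, "emin": -10.0, "emax": 10.0}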
+ bandinfo = self.info["bandinfo"] + atomic_data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(self.data["kpoints"][frame][:], + dtype=torch.get_default_dtype()) + if bandinfo["emin"] is not None and bandinfo["emax"] is not None: + atomic_data[AtomicDataDict.ENERGY_WINDOWS_KEY] = torch.as_tensor([bandinfo["emin"], bandinfo["emax"]], + dtype=torch.get_default_dtype()) + if bandinfo["band_min"] is not None and bandinfo["band_max"] is not None: + atomic_data[AtomicDataDict.BAND_WINDOW_KEY] = torch.as_tensor([bandinfo["band_min"], bandinfo["band_max"]], + dtype=torch.long) + # atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(self.data["eigenvalues"][frame][:, bandinfo["band_min"]:bandinfo["band_max"]], + # dtype=torch.get_default_dtype()) + atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(self.data["eigenvalues"][frame], + dtype=torch.get_default_dtype()) + data_list.append(atomic_data) + return data_list + + +class DefaultDataset(AtomicInMemoryDataset): + + def __init__( + self, + root: str, + info_files: Dict[str, Dict], + url: Optional[str] = None, # seems useless but can't be remove + include_frames: Optional[List[int]] = None, # maybe support in future + type_mapper: TypeMapper = None, + get_Hamiltonian: bool = False, + get_eigenvalues: bool = False, + ): + self.root = root + self.url = url + self.info_files = info_files + # The following flags are stored to label dataset. + self.get_Hamiltonian = get_Hamiltonian + self.get_eigenvalues = get_eigenvalues + + # load all data files + self.raw_data = [] + for file in self.info_files.keys(): + # get the info here + info = info_files[file] + assert "AtomicData_options" in info + AtomicData_options = info["AtomicData_options"] + assert "r_max" in AtomicData_options + assert "pbc" in AtomicData_options + if info["pos_type"] == "ase": + subdata = _TrajData.from_ase_traj(os.path.join(self.root, file), + AtomicData_options, + get_Hamiltonian, + get_eigenvalues, + info=info) + else: + subdata = _TrajData(os.path.join(self.root, file), + AtomicData_options, + get_Hamiltonian, + get_eigenvalues, + info=info) + self.raw_data.append(subdata) + + # The AtomicData_options is never used here. + # Because we always return a list of AtomicData object in `get_data()`. + # That is, AtomicInMemoryDataset will not use AtomicData_options to build any AtomicData here. + super().__init__( + file_name=None, # this seems not important too. + url=url, + root=root, + AtomicData_options={}, # we do not pass anything here. + include_frames=include_frames, + type_mapper=type_mapper, + ) + + def get_data(self): + all_data = [] + for subdata in self.raw_data: + # the type_mapper here is loaded in PyG `dataset` type as `transform` attritube + # so the OrbitalMapper can be accessed by self.transform here + subdata_list = subdata.toAtomicDataList(self.transform) + all_data += subdata_list + return all_data + + @property + def raw_file_names(self): + # TODO: this is not implemented. + return "Null" + + @property + def raw_dir(self): + # TODO: this is not implemented. 
+ return self.root + + def E3statistics(self, decay=False): + assert self.transform is not None + idp = self.transform + + if self.data[AtomicDataDict.EDGE_FEATURES_KEY].abs().sum() < 1e-7: + return None + + typed_dataset = idp(self.data.clone().to_dict()) + e3h = E3Hamiltonian(basis=idp.basis, decompose=True) + with torch.no_grad(): + typed_dataset = e3h(typed_dataset) + + stats = {} + stats["node"] = self._E3nodespecies_stat(typed_dataset=typed_dataset) + stats["edge"] = self._E3edgespecies_stat(typed_dataset=typed_dataset, decay=decay) + + return stats + + def _E3edgespecies_stat(self, typed_dataset, decay): + # we get the bond type marked dataset first + idp = self.transform + typed_dataset = typed_dataset + + idp.get_irreps(no_parity=False) + irrep_slices = idp.orbpair_irreps.slices() + + features = typed_dataset["edge_features"] + hopping_block_mask = idp.mask_to_erme[typed_dataset["edge_type"].flatten()] + typed_hopping = {} + for bt, tp in idp.bond_to_type.items(): + hopping_tp_mask = hopping_block_mask[typed_dataset["edge_type"].flatten().eq(tp)] + hopping_tp = features[typed_dataset["edge_type"].flatten().eq(tp)] + filtered_vecs = torch.where(hopping_tp_mask, hopping_tp, torch.tensor(float('nan'))) + typed_hopping[bt] = filtered_vecs + + sorted_irreps = idp.orbpair_irreps.sort()[0].simplify() + n_scalar = sorted_irreps[0].mul if sorted_irreps[0].ir.l == 0 else 0 + + # calculate norm & mean + typed_norm = {} + typed_norm_ave = torch.ones(len(idp.bond_to_type), idp.orbpair_irreps.num_irreps) + typed_norm_std = torch.zeros(len(idp.bond_to_type), idp.orbpair_irreps.num_irreps) + typed_scalar_ave = torch.ones(len(idp.bond_to_type), n_scalar) + typed_scalar_std = torch.zeros(len(idp.bond_to_type), n_scalar) + for bt, tp in idp.bond_to_type.items(): + norms_per_irrep = [] + count_scalar = 0 + for ir, s in enumerate(irrep_slices): + sub_tensor = typed_hopping[bt][:, s] + # dump the nan blocks here + if sub_tensor.shape[-1] == 1: + count_scalar += 1 + if not torch.isnan(sub_tensor).all(): + # update the mean and ave + norms = torch.norm(sub_tensor, p=2, dim=1) # shape: [n_edge] + if sub_tensor.shape[-1] == 1: + # it's a scalar + typed_scalar_ave[tp][count_scalar-1] = sub_tensor.mean() + typed_scalar_std[tp][count_scalar-1] = sub_tensor.std() + typed_norm_ave[tp][ir] = norms.mean() + typed_norm_std[tp][ir] = norms.std() + else: + norms = torch.ones_like(sub_tensor[:, 0]) + + if decay: + norms_per_irrep.append(norms) + + assert count_scalar <= n_scalar + # shape of typed_norm: (n_irreps, n_edges) + + if decay: + typed_norm[bt] = torch.stack(norms_per_irrep) + + edge_stats = { + "norm_ave": typed_norm_ave, + "norm_std": typed_norm_std, + "scalar_ave": typed_scalar_ave, + "scalar_std": typed_scalar_std, + } + + if decay: + typed_dataset = with_edge_vectors(typed_dataset) + decay = {} + for bt, tp in idp.bond_to_type.items(): + decay_bt = {} + lengths_bt = typed_dataset["edge_lengths"][typed_dataset["edge_type"].flatten().eq(tp)] + sorted_lengths, indices = lengths_bt.sort() # from small to large + # sort the norms by irrep l + sorted_norms = typed_norm[bt][idp.orbpair_irreps.sort().inv, :] + # sort the norms by edge length + sorted_norms = sorted_norms[:, indices] + decay_bt["edge_length"] = sorted_lengths + decay_bt["norm_decay"] = sorted_norms + decay[bt] = decay_bt + + edge_stats["decay"] = decay + + return edge_stats + + def _E3nodespecies_stat(self, typed_dataset): + # we get the type marked dataset first + idp = self.transform + typed_dataset = typed_dataset + + 
idp.get_irreps(no_parity=False) + irrep_slices = idp.orbpair_irreps.slices() + + sorted_irreps = idp.orbpair_irreps.sort()[0].simplify() + n_scalar = sorted_irreps[0].mul if sorted_irreps[0].ir.l == 0 else 0 + + features = typed_dataset["node_features"] + onsite_block_mask = idp.mask_to_nrme[typed_dataset["atom_types"].flatten()] + typed_onsite = {} + for at, tp in idp.chemical_symbol_to_type.items(): + onsite_tp_mask = onsite_block_mask[typed_dataset["atom_types"].flatten().eq(tp)] + onsite_tp = features[typed_dataset["atom_types"].flatten().eq(tp)] + filtered_vecs = torch.where(onsite_tp_mask, onsite_tp, torch.tensor(float('nan'))) + typed_onsite[at] = filtered_vecs + + # calculate norm & mean + typed_norm_ave = torch.ones(len(idp.chemical_symbol_to_type), idp.orbpair_irreps.num_irreps) + typed_norm_std = torch.zeros(len(idp.chemical_symbol_to_type), idp.orbpair_irreps.num_irreps) + typed_scalar_ave = torch.ones(len(idp.chemical_symbol_to_type), n_scalar) + typed_scalar_std = torch.zeros(len(idp.chemical_symbol_to_type), n_scalar) + for at, tp in idp.chemical_symbol_to_type.items(): + count_scalar = 0 + for ir, s in enumerate(irrep_slices): + sub_tensor = typed_onsite[at][:, s] + # dump the nan blocks here + if sub_tensor.shape[-1] == 1: + count_scalar += 1 + if not torch.isnan(sub_tensor).all(): + + norms = torch.norm(sub_tensor, p=2, dim=1) + typed_norm_ave[tp][ir] = norms.mean() + typed_norm_std[tp][ir] = norms.std() + if s.stop - s.start == 1: + # it's a scalar + typed_scalar_ave[tp][count_scalar-1] = sub_tensor.mean() + typed_scalar_std[tp][count_scalar-1] = sub_tensor.std() + + edge_stats = { + "norm_ave": typed_norm_ave, + "norm_std": typed_norm_std, + "scalar_ave": typed_scalar_ave, + "scalar_std": typed_scalar_std, + } + + return edge_stats \ No newline at end of file diff --git a/dptb/data/dataset/_hdf5_dataset.py b/dptb/data/dataset/_hdf5_dataset.py new file mode 100644 index 00000000..5fce39e2 --- /dev/null +++ b/dptb/data/dataset/_hdf5_dataset.py @@ -0,0 +1,171 @@ +from typing import Dict, Any, List, Callable, Union, Optional +from collections import defaultdict +import numpy as np + +import torch + +from .. import ( + AtomicData, + AtomicDataDict, +) +from ..transforms import TypeMapper +from ._base_datasets import AtomicDataset + + +class HDF5Dataset(AtomicDataset): + """A dataset that loads data from a HDF5 file. + + This class is useful for very large datasets that cannot fit in memory. It + efficiently loads data from disk as needed without everything needing to be + in memory at once. + + To use this, ``file_name`` should point to the HDF5 file, or alternatively a + semicolon separated list of multiple files. Each group in the file contains + samples that all have the same number of atoms. Typically there is one + group for each unique number of atoms, but that is not required. Each group + should contain arrays whose length equals the number of samples, one for each + type of data. The names of the arrays can be specified with ``key_mapping``. + + Args: + key_mapping (Dict[str, str]): mapping of array names in the HDF5 file to ``AtomicData`` keys + file_name (string): a semicolon separated list of HDF5 files. 
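+
+    Example (a minimal sketch; the file names and the ``r_max`` value below are
+    placeholders, not values shipped with this code):
+
+        dataset = HDF5Dataset(
+            root="./data",
+            file_name="first.h5;second.h5",
+            AtomicData_options={"r_max": 5.0},
+        )
+        frame = dataset.get(0)  # builds an AtomicData neighbor graph for frame 0 on demand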
+ """ + + def __init__( + self, + root: str, + key_mapping: Dict[str, str] = { + "pos": AtomicDataDict.POSITIONS_KEY, + "energy": AtomicDataDict.TOTAL_ENERGY_KEY, + "forces": AtomicDataDict.FORCE_KEY, + "atomic_numbers": AtomicDataDict.ATOMIC_NUMBERS_KEY, + "types": AtomicDataDict.ATOM_TYPE_KEY, + }, + file_name: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + type_mapper: Optional[TypeMapper] = None, + ): + super().__init__(root=root, type_mapper=type_mapper) + self.key_mapping = key_mapping + self.key_list = list(key_mapping.keys()) + self.value_list = list(key_mapping.values()) + self.file_name = file_name + self.r_max = AtomicData_options["r_max"] + self.index = None + self.num_frames = 0 + import h5py + + files = [h5py.File(f, "r") for f in self.file_name.split(";")] + for file in files: + for group_name in file: + for key in self.key_list: + if key in file[group_name]: + self.num_frames += len(file[group_name][key]) + break + file.close() + + def setup_index(self): + import h5py + + files = [h5py.File(f, "r") for f in self.file_name.split(";")] + self.has_forces = False + self.index = [] + for file in files: + for group_name in file: + group = file[group_name] + values = [None] * len(self.key_list) + samples = 0 + for i, key in enumerate(self.key_list): + if key in group: + values[i] = group[key] + samples = len(values[i]) + for i in range(samples): + self.index.append(tuple(values + [i])) + + def len(self) -> int: + return self.num_frames + + def get(self, idx: int) -> AtomicData: + if self.index is None: + self.setup_index() + data = self.index[idx] + i = data[-1] + args = {"r_max": self.r_max} + for j, value in enumerate(self.value_list): + if data[j] is not None: + args[value] = data[j][i] + return AtomicData.from_points(**args) + + def statistics( + self, + fields: List[Union[str, Callable]], + modes: List[str], + stride: int = 1, + unbiased: bool = True, + kwargs: Optional[Dict[str, dict]] = {}, + ) -> List[tuple]: + assert len(modes) == len(fields) + # TODO: use RunningStats + if len(fields) == 0: + return [] + if self.index is None: + self.setup_index() + results = [] + indices = self.indices() + if stride != 1: + indices = list(indices)[::stride] + for field, mode in zip(fields, modes): + count = 0 + if mode == "rms": + total = 0.0 + elif mode in ("mean_std", "per_atom_mean_std"): + total = [0.0, 0.0] + elif mode == "count": + counts = defaultdict(int) + else: + raise NotImplementedError(f"Analysis mode '{mode}' is not implemented") + for index in indices: + data = self.index[index] + i = data[-1] + if field in self.value_list: + values = data[self.value_list.index(field)][i] + elif callable(field): + values, _ = field(self.get(index)) + values = np.asarray(values) + else: + raise RuntimeError( + f"The field key `{field}` is not present in this dataset" + ) + length = len(values.flatten()) + if length == 1: + values = np.array([values]) + if mode == "rms": + total += np.sum(values * values) + count += length + elif mode == "count": + for v in values: + counts[v] += 1 + else: + if mode == "per_atom_mean_std": + values /= len(data[0][i]) + for v in values: + count += 1 + delta1 = v - total[0] + total[0] += delta1 / count + delta2 = v - total[0] + total[1] += delta1 * delta2 + if mode == "rms": + results.append(torch.tensor((np.sqrt(total / count),))) + elif mode == "count": + values = sorted(counts.keys()) + results.append( + (torch.tensor(values), torch.tensor([counts[v] for v in values])) + ) + else: + results.append( + ( + torch.tensor(total[0]), + 
torch.tensor(np.sqrt(total[1] / (count - 1))), + ) + ) + return results diff --git a/dptb/data/dataset/_npz_dataset.py b/dptb/data/dataset/_npz_dataset.py new file mode 100644 index 00000000..84080604 --- /dev/null +++ b/dptb/data/dataset/_npz_dataset.py @@ -0,0 +1,143 @@ +import numpy as np +from os.path import dirname, basename, abspath +from typing import Dict, Any, List, Optional + + +from .. import AtomicDataDict, _LONG_FIELDS, _NODE_FIELDS, _GRAPH_FIELDS +from ..transforms import TypeMapper +from ._base_datasets import AtomicInMemoryDataset + + +class NpzDataset(AtomicInMemoryDataset): + """Load data from an npz file. + + To avoid loading unneeded data, keys are ignored by default unless they are in ``key_mapping``, ``include_keys``, + or ``npz_fixed_fields_keys``. + + Args: + key_mapping (Dict[str, str]): mapping of npz keys to ``AtomicData`` keys. Optional + include_keys (list): the attributes to be processed and stored. Optional + npz_fixed_field_keys: the attributes that only have one instance but apply to all frames. Optional + Note that the mapped keys (as determined by the _values_ in ``key_mapping``) should be used in + ``npz_fixed_field_keys``, not the original npz keys from before mapping. If an npz key is not + present in ``key_mapping``, it is mapped to itself, and this point is not relevant. + + Example: Given a npz file with 10 configurations, each with 14 atoms. + + position: (10, 14, 3) + force: (10, 14, 3) + energy: (10,) + Z: (14) + user_label1: (10) # per config + user_label2: (10, 14, 3) # per atom + + The input yaml should be + + ```yaml + dataset: npz + dataset_file_name: example.npz + include_keys: + - user_label1 + - user_label2 + npz_fixed_field_keys: + - cell + - atomic_numbers + key_mapping: + position: pos + force: forces + energy: total_energy + Z: atomic_numbers + graph_fields: + - user_label1 + node_fields: + - user_label2 + ``` + + """ + + def __init__( + self, + root: str, + key_mapping: Dict[str, str] = { + "positions": AtomicDataDict.POSITIONS_KEY, + "energy": AtomicDataDict.TOTAL_ENERGY_KEY, + "force": AtomicDataDict.FORCE_KEY, + "forces": AtomicDataDict.FORCE_KEY, + "Z": AtomicDataDict.ATOMIC_NUMBERS_KEY, + "atomic_number": AtomicDataDict.ATOMIC_NUMBERS_KEY, + }, + include_keys: List[str] = [], + npz_fixed_field_keys: List[str] = [], + file_name: Optional[str] = None, + url: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + include_frames: Optional[List[int]] = None, + type_mapper: TypeMapper = None, + ): + self.key_mapping = key_mapping + self.npz_fixed_field_keys = npz_fixed_field_keys + self.include_keys = include_keys + + super().__init__( + file_name=file_name, + url=url, + root=root, + AtomicData_options=AtomicData_options, + include_frames=include_frames, + type_mapper=type_mapper, + ) + + @property + def raw_file_names(self): + return [basename(self.file_name)] + + @property + def raw_dir(self): + return dirname(abspath(self.file_name)) + + def get_data(self): + # get data returns either a list of AtomicData class or a data dict + + data = np.load(self.raw_dir + "/" + self.raw_file_names[0], allow_pickle=True) + + # only the keys explicitly mentioned in the yaml file will be parsed + keys = set(list(self.key_mapping.keys())) + + keys.update(self.npz_fixed_field_keys) + keys.update(self.include_keys) + keys = keys.intersection(set(list(data.keys()))) + + mapped = {self.key_mapping.get(k, k): data[k] for k in keys} + + for intkey in _LONG_FIELDS: + if intkey in mapped: + mapped[intkey] = mapped[intkey].astype(np.int64) + + 
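+        # split the per-frame arrays from the fixed fields; the fixed ones are tiled over all frames below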
fields = {k: v for k, v in mapped.items() if k not in self.npz_fixed_field_keys} + num_examples, num_atoms, n_dim = fields[AtomicDataDict.POSITIONS_KEY].shape + assert n_dim == 3 + + # now we replicate and add the fixed fields: + for fixed_field in self.npz_fixed_field_keys: + orig = mapped[fixed_field] + if fixed_field in _NODE_FIELDS: + assert orig.ndim >= 1 # [n_atom, feature_dims] + assert orig.shape[0] == num_atoms + replicated = np.expand_dims(orig, 0) + replicated = np.tile( + replicated, + (num_examples,) + (1,) * len(replicated.shape[1:]), + ) # [n_example, n_atom, feature_dims] + elif fixed_field in _GRAPH_FIELDS: + # orig is [feature_dims] + replicated = np.expand_dims(orig, 0) + replicated = np.tile( + replicated, + (num_examples,) + (1,) * len(replicated.shape[1:]), + ) # [n_example, feature_dims] + else: + raise KeyError( + f"npz_fixed_field_keys contains `{fixed_field}`, but it isn't registered as a node or graph field" + ) + fields[fixed_field] = replicated + return fields diff --git a/dptb/nnet/__init__.py b/dptb/data/interfaces/__init__.py similarity index 100% rename from dptb/nnet/__init__.py rename to dptb/data/interfaces/__init__.py diff --git a/dptb/data/interfaces/abacus.py b/dptb/data/interfaces/abacus.py new file mode 100644 index 00000000..6123ee5d --- /dev/null +++ b/dptb/data/interfaces/abacus.py @@ -0,0 +1,378 @@ +# Modified from script 'abasus_get_data.py' for interface from ABACUS to DeepH-pack +# To use this script, please add 'out_mat_hs2 1' in ABACUS INPUT File +# Current version is capable of coping with f-orbitals + +import os +import glob +import json +import re +from collections import Counter +from tqdm import tqdm + +import numpy as np +from scipy.sparse import csr_matrix +from scipy.linalg import block_diag +import h5py +import ase + +orbitalId = {0:'s',1:'p',2:'d',3:'f'} +Bohr2Ang = 0.529177249 + + +class OrbAbacus2DeepTB: + def __init__(self): + self.Us_abacus2deeptb = {} + self.Us_abacus2deeptb[0] = np.eye(1) + self.Us_abacus2deeptb[1] = np.eye(3)[[2, 0, 1]] # 0, 1, -1 -> -1, 0, 1 + self.Us_abacus2deeptb[2] = np.eye(5)[[4, 2, 0, 1, 3]] # 0, 1, -1, 2, -2 -> -2, -1, 0, 1, 2 + self.Us_abacus2deeptb[3] = np.eye(7)[[6, 4, 2, 0, 1, 3, 5]] # -3,-2,-1,0,1,2,3 + + minus_dict = { + 1: [0, 2], + 2: [1, 3], + 3: [0, 2, 4, 6], + } + + for k, v in minus_dict.items(): + self.Us_abacus2deeptb[k][v] *= -1 # add phase (-1)^m + + def get_U(self, l): + if l > 3: + raise NotImplementedError("Only support l = s, p, d, f") + return self.Us_abacus2deeptb[l] + + def transform(self, mat, l_lefts, l_rights): + block_lefts = block_diag(*[self.get_U(l_left) for l_left in l_lefts]) + block_rights = block_diag(*[self.get_U(l_right) for l_right in l_rights]) + return block_lefts @ mat @ block_rights.T + +def recursive_parse(input_path, + preprocess_dir, + data_name="OUT.ABACUS", + only_overlap=False, + parse_Hamiltonian=False, + add_overlap=False, + parse_eigenvalues=False, + prefix="data"): + """ + Parse ABACUS single point SCF calculation outputs. + Input: + `input_dir`: target dictionary(ies) containing "OUT.ABACUS" folder. + can be wildcard characters or a string list. + `preprocess_dir`: output dictionary of all processed data files. + `data_name`: output dictionary name of ABACUS, by default "OUT.ABACUS". + `only_overlap`: usually `False`. + set to `True` if the calculation job is getting overlap matrix ONLY. + `parse_Hamiltonian`: determine whether parsing the Hamiltonian `.csr` file or not. + `add_overlap`: determine whether parsing the overlap `.csr` file or not. 
+ `parse_Hamiltonian` must be true to add overlap. + `parse_eigenvalues`: determine whether parsing `kpoints.dat` and `BAND_1.dat` or not. + that is, the k-points will always be loaded with bands. + `prefix`: prefix of the processed data folders' names. + """ + if isinstance(input_path, list) and all(isinstance(item, str) for item in input_path): + input_path = input_path + else: + input_path = glob.glob(input_path) + preprocess_dir = os.path.abspath(preprocess_dir) + os.makedirs(preprocess_dir, exist_ok=True) + # h5file_names = [] + + folders = [item for item in input_path if os.path.isdir(item)] + + with tqdm(total=len(folders)) as pbar: + for index, folder in enumerate(folders): + datafiles = os.listdir(folder) + if data_name in datafiles: + # The follwing `if` block is used by us only. + if os.path.exists(os.path.join(folder, data_name, "hscsr.tgz")): + os.system("cd "+os.path.join(folder, data_name) + " && tar -zxvf hscsr.tgz && mv OUT.ABACUS/* ./") + try: + _abacus_parse(folder, + os.path.join(preprocess_dir, f"{prefix}.{index}"), + data_name, + only_S=only_overlap, + get_Ham=parse_Hamiltonian, + add_overlap=add_overlap, + get_eigenvalues=parse_eigenvalues) + #h5file_names.append(os.path.join(file, "AtomicData.h5")) + pbar.update(1) + except Exception as e: + print(f"Error in {folder}/{data_name}: {e}") + continue + #return h5file_names + +def _abacus_parse(input_path, + output_path, + data_name, + only_S=False, + get_Ham=False, + add_overlap=False, + get_eigenvalues=False): + + input_path = os.path.abspath(input_path) + output_path = os.path.abspath(output_path) + os.makedirs(output_path, exist_ok=True) + + def find_target_line(f, target): + line = f.readline() + while line: + if target in line: + return line + line = f.readline() + return None + if only_S: + log_file_name = "running_get_S.log" + else: + log_file_name = "running_scf.log" + + with open(os.path.join(input_path, data_name, log_file_name), 'r') as f_chk: + lines = f_chk.readlines() + if not lines or " Total Time :" not in lines[-1]: + raise ValueError(f"Job is not normal ending!") + + with open(os.path.join(input_path, data_name, log_file_name), 'r') as f: + f.readline() + line = f.readline() + # assert "WELCOME TO ABACUS" in line + assert find_target_line(f, "READING UNITCELL INFORMATION") is not None, 'Cannot find "READING UNITCELL INFORMATION" in log file' + num_atom_type = int(f.readline().split()[-1]) + + assert find_target_line(f, "lattice constant (Bohr)") is not None + lattice_constant = float(f.readline().split()[-1]) # unit is Angstrom, didn't read (Bohr) here. 
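+    # read, for each atom type, its element label and the number of zeta functions per angular momentum;
+    # these define the per-site orbital count and the orbital_types used later for block reordering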
+ + site_norbits_dict = {} + orbital_types_dict = {} + for index_type in range(num_atom_type): + tmp = find_target_line(f, "READING ATOM TYPE") + assert tmp is not None, 'Cannot find "ATOM TYPE" in log file' + assert tmp.split()[-1] == str(index_type + 1) + if tmp is None: + raise Exception(f"Cannot find ATOM {index_type} in {log_file_name}") + + line = f.readline() + assert "atom label =" in line + atom_label = line.split()[-1] + assert atom_label in ase.data.atomic_numbers, "Atom label should be in periodic table" + atom_type = ase.data.atomic_numbers[atom_label] + + current_site_norbits = 0 + current_orbital_types = [] + while True: + line = f.readline() + if "number of zeta" in line: + tmp = line.split() + L = int(tmp[0][2:-1]) + num_L = int(tmp[-1]) + current_site_norbits += (2 * L + 1) * num_L + current_orbital_types.extend([L] * num_L) + else: + break + site_norbits_dict[atom_type] = current_site_norbits + orbital_types_dict[atom_type] = current_orbital_types + + #print(orbital_types_dict) + + line = find_target_line(f, "TOTAL ATOM NUMBER") + assert line is not None, 'Cannot find "TOTAL ATOM NUMBER" in log file' + nsites = int(line.split()[-1]) + + line = find_target_line(f, " COORDINATES") + assert line is not None, 'Cannot find "DIRECT COORDINATES" or "CARTESIAN COORDINATES" in log file' + if "DIRECT" in line: + coords_type = "direct" + elif "CARTESIAN" in line: + coords_type = "cartesian" + else: + raise ValueError('Cannot find "DIRECT COORDINATES" or "CARTESIAN COORDINATES" in log file') + + assert "atom" in f.readline() + frac_coords = np.zeros((nsites, 3)) + site_norbits = np.zeros(nsites, dtype=int) + element = np.zeros(nsites, dtype=int) + for index_site in range(nsites): + line = f.readline() + tmp = line.split() + assert "tau" in tmp[0] + atom_label = ''.join(re.findall(r'[A-Za-z]', tmp[0][5:])) + assert atom_label in ase.data.atomic_numbers, "Atom label should be in periodic table" + element[index_site] = ase.data.atomic_numbers[atom_label] + site_norbits[index_site] = site_norbits_dict[element[index_site]] + frac_coords[index_site, :] = np.array(tmp[1:4]) + norbits = int(np.sum(site_norbits)) + site_norbits_cumsum = np.cumsum(site_norbits) + + assert find_target_line(f, "Lattice vectors: (Cartesian coordinate: in unit of a_0)") is not None + lattice = np.zeros((3, 3)) + for index_lat in range(3): + lattice[index_lat, :] = np.array(f.readline().split()) + if coords_type == "cartesian": + frac_coords = frac_coords @ np.matrix(lattice).I # get frac_coords anyway + lattice = lattice * lattice_constant + + if only_S: + spinful = False + else: + line = find_target_line(f, "NSPIN") + assert line is not None, 'Cannot find "NSPIN" in log file' + if "NSPIN == 1" in line: + spinful = False + elif "NSPIN == 4" in line: + spinful = True + else: + raise ValueError(f'{line} is not supported') + + np.savetxt(os.path.join(output_path, "cell.dat"), lattice) + np.savetxt(os.path.join(output_path, "rcell.dat"), np.linalg.inv(lattice) * 2 * np.pi) + cart_coords = frac_coords @ lattice + np.savetxt(os.path.join(output_path, "positions.dat").format(output_path), cart_coords) + np.savetxt(os.path.join(output_path, "atomic_numbers.dat"), element, fmt='%d') + #info = {'nsites' : nsites, 'isorthogonal': False, 'isspinful': spinful, 'norbits': norbits} + #with open('{}/info.json'.format(output_path), 'w') as info_f: + # json.dump(info, info_f) + with open(os.path.join(output_path, "basis.dat"), 'w') as f: + for atomic_number in element: + counter = Counter(orbital_types_dict[atomic_number]) + 
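+            # count how many shells of each angular momentum (index 0..3 -> s, p, d, f) this element carries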
num_orbs = [counter[i] for i in range(4)] # s, p, d, f + for index_l, l in enumerate(num_orbs): + if l == 0: # no this orbit + continue + if index_l == 0: + f.write(f"{l}{orbitalId[index_l]}") + else: + f.write(f"{l}{orbitalId[index_l]}") + f.write('\n') + atomic_basis = {} + for atomic_number, orbitals in orbital_types_dict.items(): + atomic_basis[ase.data.chemical_symbols[atomic_number]] = orbitals + + if get_Ham: + U_orbital = OrbAbacus2DeepTB() + def parse_matrix(matrix_path, factor, spinful=False): + matrix_dict = dict() + with open(matrix_path, 'r') as f: + line = f.readline() # read "Matrix Dimension of ..." + if not "Matrix Dimension of" in line: + line = f.readline() # ABACUS >= 3.0 + assert "Matrix Dimension of" in line + f.readline() # read "Matrix number of ..." + norbits = int(line.split()[-1]) + for line in f: + line1 = line.split() + if len(line1) == 0: + break + num_element = int(line1[3]) + if num_element != 0: + R_cur = np.array(line1[:3]).astype(int) + line2 = f.readline().split() + line3 = f.readline().split() + line4 = f.readline().split() + if not spinful: + hamiltonian_cur = csr_matrix((np.array(line2).astype(float), np.array(line3).astype(int), + np.array(line4).astype(int)), shape=(norbits, norbits)).toarray() + else: + line2 = np.char.replace(line2, '(', '') + line2 = np.char.replace(line2, ')', 'j') + line2 = np.char.replace(line2, ',', '+') + line2 = np.char.replace(line2, '+-', '-') + hamiltonian_cur = csr_matrix((np.array(line2).astype(np.complex128), np.array(line3).astype(int), + np.array(line4).astype(int)), shape=(norbits, norbits)).toarray() + for index_site_i in range(nsites): + for index_site_j in range(nsites): + key_str = f"{index_site_i + 1}_{index_site_j + 1}_{R_cur[0]}_{R_cur[1]}_{R_cur[2]}" + mat = hamiltonian_cur[(site_norbits_cumsum[index_site_i] + - site_norbits[index_site_i]) * (1 + spinful): + site_norbits_cumsum[index_site_i] * (1 + spinful), + (site_norbits_cumsum[index_site_j] - site_norbits[index_site_j]) * (1 + spinful): + site_norbits_cumsum[index_site_j] * (1 + spinful)] + if abs(mat).max() < 1e-10: + continue + if not spinful: + mat = U_orbital.transform(mat, orbital_types_dict[element[index_site_i]], + orbital_types_dict[element[index_site_j]]) + else: + mat = mat.reshape((site_norbits[index_site_i], 2, site_norbits[index_site_j], 2)) + mat = mat.transpose((1, 0, 3, 2)).reshape((2 * site_norbits[index_site_i], + 2 * site_norbits[index_site_j])) + mat = U_orbital.transform(mat, orbital_types_dict[element[index_site_i]] * 2, + orbital_types_dict[element[index_site_j]] * 2) + matrix_dict[key_str] = mat * factor + return matrix_dict, norbits + + if only_S: + overlap_dict, tmp = parse_matrix(os.path.join(input_path, "SR.csr"), 1) + assert tmp == norbits + else: + hamiltonian_dict, tmp = parse_matrix( + os.path.join(input_path, data_name, "data-HR-sparse_SPIN0.csr"), 13.605698, # Ryd2eV + spinful=spinful) + assert tmp == norbits * (1 + spinful) + overlap_dict, tmp = parse_matrix(os.path.join(input_path, data_name, "data-SR-sparse_SPIN0.csr"), 1, + spinful=spinful) + assert tmp == norbits * (1 + spinful) + if spinful: + overlap_dict_spinless = {} + for k, v in overlap_dict.items(): + overlap_dict_spinless[k] = v[:v.shape[0] // 2, :v.shape[1] // 2].real + overlap_dict_spinless, overlap_dict = overlap_dict, overlap_dict_spinless + + if not only_S: + with h5py.File(os.path.join(output_path, "hamiltonians.h5"), 'w') as fid: + # creating a default group here adapting to the format used in DefaultDataset. 
+ # by the way DefaultDataset loading h5 file, the index should be "1" here. + default_group = fid.create_group("1") + for key_str, value in hamiltonian_dict.items(): + default_group[key_str] = value + if add_overlap: + with h5py.File(os.path.join(output_path, "overlaps.h5"), 'w') as fid: + default_group = fid.create_group("1") + for key_str, value in overlap_dict.items(): + default_group[key_str] = value + + if get_eigenvalues: + kpts = [] + with open(os.path.join(input_path, data_name, "kpoints"), "r") as f: + nkstot = f.readline().strip().split()[-1] + f.readline() + for _ in range(int(nkstot)): + line = f.readline() + kpt = [] + line = line.strip().split() + kpt.extend([float(line[1]), float(line[2]), float(line[3])]) + kpts.append(kpt) + kpts = np.array(kpts) + + with open(os.path.join(input_path, data_name, "BANDS_1.dat"), "r") as file: + band_lines = file.readlines() + band = [] + for line in band_lines: + values = line.strip().split() + eigs = [float(value) for value in values[1:]] + band.append(eigs) + band = np.array(band) + + assert len(band) == len(kpts) + np.save(os.path.join(output_path, "kpoints.npy"), kpts) + np.save(os.path.join(output_path, "eigenvalues.npy"), band) + + #with h5py.File(os.path.join(output_path, "AtomicData.h5"), "w") as f: + # f["cell"] = lattice + # f["pos"] = cart_coords + # f["atomic_numbers"] = element + # basis = f.create_group("basis") + # for key, value in atomic_basis.items(): + # basis[key] = value + # if get_Ham: + # f["hamiltonian_blocks"] = h5py.ExternalLink("hamiltonians.h5", "/") + # if add_overlap: + # f["overlap_blocks"] = h5py.ExternalLink("overlaps.h5", "/") + # # else: + # # f["overlap_blocks"] = False + # # else: + # # f["hamiltonian_blocks"] = False + # if get_eigenvalues: + # f["kpoints"] = kpts + # f["eigenvalues"] = band + # # else: + # # f["kpoint"] = False + # # f["eigenvalue"] = False diff --git a/dptb/data/interfaces/ham_to_feature.py b/dptb/data/interfaces/ham_to_feature.py new file mode 100644 index 00000000..710f8b49 --- /dev/null +++ b/dptb/data/interfaces/ham_to_feature.py @@ -0,0 +1,209 @@ +from .. 
import _keys +import ase +import numpy as np +import torch +import re +import e3nn.o3 as o3 +import h5py +import logging +from dptb.utils.constants import anglrMId + +log = logging.getLogger(__name__) + +def ham_block_to_feature(data, idp, Hamiltonian_blocks, overlap_blocks=False): + # Hamiltonian_blocks should be a h5 group in the current version + onsite_ham = [] + edge_ham = [] + if overlap_blocks: + edge_overlap = [] + + idp.get_orbital_maps() + idp.get_orbpair_maps() + + atomic_numbers = data[_keys.ATOMIC_NUMBERS_KEY] + + # onsite features + for atom in range(len(atomic_numbers)): + block_index = '_'.join(map(str, map(int, [atom+1, atom+1] + list([0, 0, 0])))) + try: + block = Hamiltonian_blocks[block_index] + except: + raise IndexError("Hamiltonian block for onsite not found, check Hamiltonian file.") + + symbol = ase.data.chemical_symbols[atomic_numbers[atom]] + basis_list = idp.basis[symbol] + onsite_out = np.zeros(idp.reduced_matrix_element) + + for index, basis_i in enumerate(basis_list): + slice_i = idp.orbital_maps[symbol][basis_i] + for basis_j in basis_list[index:]: + slice_j = idp.orbital_maps[symbol][basis_j] + block_ij = block[slice_i, slice_j] + full_basis_i = idp.basis_to_full_basis[symbol][basis_i] + full_basis_j = idp.basis_to_full_basis[symbol][basis_j] + + # fill onsite vector + pair_ij = full_basis_i + "-" + full_basis_j + feature_slice = idp.orbpair_maps[pair_ij] + onsite_out[feature_slice] = block_ij.flatten() + + onsite_ham.append(onsite_out) + #onsite_ham = np.array(onsite_ham) + + # edge features + edge_index = data[_keys.EDGE_INDEX_KEY] + edge_cell_shift = data[_keys.EDGE_CELL_SHIFT_KEY] + + for atom_i, atom_j, R_shift in zip(edge_index[0], edge_index[1], edge_cell_shift): + block_index = '_'.join(map(str, map(int, [atom_i+1, atom_j+1] + list(R_shift)))) + symbol_i = ase.data.chemical_symbols[atomic_numbers[atom_i]] + symbol_j = ase.data.chemical_symbols[atomic_numbers[atom_j]] + + # try: + # block = Hamiltonian_blocks[block_index] + # if overlap_blocks: + # block_s = overlap_blocks[block_index] + # except: + # raise IndexError("Hamiltonian block for hopping not found, r_cut may be too big for input R.") + + block = Hamiltonian_blocks.get(block_index, 0) + if overlap_blocks: + block_s = overlap_blocks.get(block_index, 0) + if block == 0: + block = torch.zeros(idp.norbs[symbol_i], idp.norbs[symbol_j]) + log.warning("Hamiltonian block for hopping {} not found, r_cut may be too big for input R.".format(block_index)) + if overlap_blocks: + if block_s == 0: + block_s = torch.zeros(idp.norbs[symbol_i], idp.norbs[symbol_j]) + log.warning("Overlap block for hopping {} not found, r_cut may be too big for input R.".format(block_index)) + + assert block.shape == (idp.norbs[symbol_i], idp.norbs[symbol_j]) + + + basis_i_list = idp.basis[symbol_i] + basis_j_list = idp.basis[symbol_j] + hopping_out = np.zeros(idp.reduced_matrix_element) + if overlap_blocks: + overlap_out = np.zeros(idp.reduced_matrix_element) + + for basis_i in basis_i_list: + slice_i = idp.orbital_maps[symbol_i][basis_i] + for basis_j in basis_j_list: + slice_j = idp.orbital_maps[symbol_j][basis_j] + full_basis_i = idp.basis_to_full_basis[symbol_i][basis_i] + full_basis_j = idp.basis_to_full_basis[symbol_j][basis_j] + if idp.full_basis.index(full_basis_i) <= idp.full_basis.index(full_basis_j): + block_ij = block[slice_i, slice_j] + if overlap_blocks: + block_s_ij = block_s[slice_i, slice_j] + + # fill hopping vector + pair_ij = full_basis_i + "-" + full_basis_j + feature_slice = idp.orbpair_maps[pair_ij] + 
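+                    # flatten the (2l_i+1) x (2l_j+1) hopping sub-block into its slot of the reduced matrix element vector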
+ hopping_out[feature_slice] = block_ij.flatten() + if overlap_blocks: + overlap_out[feature_slice] = block_s_ij.flatten() + + edge_ham.append(hopping_out) + if overlap_blocks: + edge_overlap.append(overlap_out) + + data[_keys.NODE_FEATURES_KEY] = torch.as_tensor(np.array(onsite_ham), dtype=torch.get_default_dtype()) + data[_keys.EDGE_FEATURES_KEY] = torch.as_tensor(np.array(edge_ham), dtype=torch.get_default_dtype()) + if overlap_blocks: + data[_keys.EDGE_OVERLAP_KEY] = torch.as_tensor(np.array(edge_overlap), dtype=torch.get_default_dtype()) + + +def openmx_to_deeptb(data, idp, openmx_hpath): + # Hamiltonian_blocks should be a h5 group in the current version + Us_openmx2wiki = { + "s": torch.eye(1).double(), + "p": torch.eye(3)[[1, 2, 0]].double(), + "d": torch.eye(5)[[2, 4, 0, 3, 1]].double(), + "f": torch.eye(7)[[6, 4, 2, 0, 1, 3, 5]].double() + } + # init_rot_mat + rot_blocks = {} + for asym, orbs in idp.basis.items(): + b = [Us_openmx2wiki[re.findall(r"[A-Za-z]", orb)[0]] for orb in orbs] + rot_blocks[asym] = torch.block_diag(*b) + + + Hamiltonian_blocks = h5py.File(openmx_hpath, 'r') + + onsite_ham = [] + edge_ham = [] + + idp.get_orbital_maps() + idp.get_orbpair_maps() + + atomic_numbers = data[_keys.ATOMIC_NUMBERS_KEY] + + # onsite features + for atom in range(len(atomic_numbers)): + block_index = str([0, 0, 0, atom+1, atom+1]) + try: + block = Hamiltonian_blocks[block_index][:] + except: + raise IndexError("Hamiltonian block for onsite not found, check Hamiltonian file.") + + + symbol = ase.data.chemical_symbols[atomic_numbers[atom]] + block = rot_blocks[symbol] @ block @ rot_blocks[symbol].T + basis_list = idp.basis[symbol] + onsite_out = np.zeros(idp.reduced_matrix_element) + + for index, basis_i in enumerate(basis_list): + slice_i = idp.orbital_maps[symbol][basis_i] + for basis_j in basis_list[index:]: + slice_j = idp.orbital_maps[symbol][basis_j] + block_ij = block[slice_i, slice_j] + full_basis_i = idp.basis_to_full_basis[symbol][basis_i] + full_basis_j = idp.basis_to_full_basis[symbol][basis_j] + + # fill onsite vector + pair_ij = full_basis_i + "-" + full_basis_j + feature_slice = idp.orbpair_maps[pair_ij] + onsite_out[feature_slice] = block_ij.flatten() + + onsite_ham.append(onsite_out) + #onsite_ham = np.array(onsite_ham) + + # edge features + edge_index = data[_keys.EDGE_INDEX_KEY] + edge_cell_shift = data[_keys.EDGE_CELL_SHIFT_KEY] + + for atom_i, atom_j, R_shift in zip(edge_index[0], edge_index[1], edge_cell_shift): + block_index = str(list(R_shift.int().numpy())+[int(atom_i)+1, int(atom_j)+1]) + try: + block = Hamiltonian_blocks[block_index][:] + except: + raise IndexError("Hamiltonian block for hopping not found, r_cut may be too big for input R.") + + symbol_i = ase.data.chemical_symbols[atomic_numbers[atom_i]] + symbol_j = ase.data.chemical_symbols[atomic_numbers[atom_j]] + block = rot_blocks[symbol_i] @ block @ rot_blocks[symbol_j].T + basis_i_list = idp.basis[symbol_i] + basis_j_list = idp.basis[symbol_j] + hopping_out = np.zeros(idp.reduced_matrix_element) + + for basis_i in basis_i_list: + slice_i = idp.orbital_maps[symbol_i][basis_i] + for basis_j in basis_j_list: + slice_j = idp.orbital_maps[symbol_j][basis_j] + block_ij = block[slice_i, slice_j] + full_basis_i = idp.basis_to_full_basis[symbol_i][basis_i] + full_basis_j = idp.basis_to_full_basis[symbol_j][basis_j] + + if idp.full_basis.index(full_basis_i) <= idp.full_basis.index(full_basis_j): + # fill hopping vector + pair_ij = full_basis_i + "-" + full_basis_j + feature_slice = idp.orbpair_maps[pair_ij] + 
hopping_out[feature_slice] = block_ij.flatten() + + edge_ham.append(hopping_out) + + data[_keys.NODE_FEATURES_KEY] = torch.as_tensor(np.array(onsite_ham), dtype=torch.get_default_dtype()) + data[_keys.EDGE_FEATURES_KEY] = torch.as_tensor(np.array(edge_ham), dtype=torch.get_default_dtype()) + Hamiltonian_blocks.close() \ No newline at end of file diff --git a/dptb/data/test_data.py b/dptb/data/test_data.py new file mode 100644 index 00000000..7427cc4b --- /dev/null +++ b/dptb/data/test_data.py @@ -0,0 +1,83 @@ +from typing import Optional, List, Dict, Any, Tuple +import copy + +import numpy as np + +import ase +import ase.build +from ase.calculators.emt import EMT + +from dptb.data import AtomicInMemoryDataset, AtomicData +from .transforms import TypeMapper + + +class EMTTestDataset(AtomicInMemoryDataset): + """Test dataset with PBC based on the toy EMT potential included in ASE. + + Randomly generates (in a reproducable manner) a basic bulk with added + Gaussian noise around equilibrium positions. + + In ASE units (eV/Å). + """ + + def __init__( + self, + root: str, + supercell: Tuple[int, int, int] = (4, 4, 4), + sigma: float = 0.1, + element: str = "Cu", + num_frames: int = 10, + dataset_seed: int = 123456, + file_name: Optional[str] = None, + url: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + include_frames: Optional[List[int]] = None, + type_mapper: TypeMapper = None, + ): + # Set properties for hashing + assert element in ("Cu", "Pd", "Au", "Pt", "Al", "Ni", "Ag") + self.element = element + self.sigma = sigma + self.supercell = tuple(supercell) + self.num_frames = num_frames + self.dataset_seed = dataset_seed + + super().__init__( + file_name=file_name, + url=url, + root=root, + AtomicData_options=AtomicData_options, + include_frames=include_frames, + type_mapper=type_mapper, + ) + + @property + def raw_file_names(self): + return [] + + @property + def raw_dir(self): + return "raw" + + def get_data(self): + rng = np.random.default_rng(self.dataset_seed) + base_atoms = ase.build.bulk(self.element, "fcc").repeat(self.supercell) + base_atoms.calc = EMT() + orig_pos = copy.deepcopy(base_atoms.positions) + datas = [] + for _ in range(self.num_frames): + base_atoms.positions[:] = orig_pos + base_atoms.positions += rng.normal( + loc=0.0, scale=self.sigma, size=base_atoms.positions.shape + ) + + datas.append( + AtomicData.from_ase( + base_atoms.copy(), + forces=base_atoms.get_forces(), + total_energy=base_atoms.get_potential_energy(), + stress=base_atoms.get_stress(voigt=False), + **self.AtomicData_options + ) + ) + return datas diff --git a/dptb/data/transforms.py b/dptb/data/transforms.py new file mode 100644 index 00000000..4af4c4da --- /dev/null +++ b/dptb/data/transforms.py @@ -0,0 +1,705 @@ +from typing import Dict, Optional, Union, List +from dptb.data.AtomicDataDict import Type +from dptb.utils.tools import get_uniq_symbol +from dptb.utils.constants import anglrMId +import re +import warnings + +import torch + +import ase.data +import e3nn.o3 as o3 + +from dptb.data import AtomicData, AtomicDataDict + + +class TypeMapper: + """Based on a configuration, map atomic numbers to types.""" + + num_types: int + chemical_symbol_to_type: Optional[Dict[str, int]] + type_to_chemical_symbol: Optional[Dict[int, str]] + type_names: List[str] + _min_Z: int + + def __init__( + self, + type_names: Optional[List[str]] = None, + chemical_symbol_to_type: Optional[Dict[str, int]] = None, + type_to_chemical_symbol: Optional[Dict[int, str]] = None, + chemical_symbols: 
Optional[List[str]] = None, + device=torch.device("cpu"), + ): + self.device = device + if chemical_symbols is not None: + if chemical_symbol_to_type is not None: + raise ValueError( + "Cannot provide both `chemical_symbols` and `chemical_symbol_to_type`" + ) + # repro old, sane NequIP behaviour + # checks also for validity of keys + atomic_nums = [ase.data.atomic_numbers[sym] for sym in chemical_symbols] + # https://stackoverflow.com/questions/29876580/how-to-sort-a-list-according-to-another-list-python + chemical_symbols = [ + e[1] for e in sorted(zip(atomic_nums, chemical_symbols)) # low to high + ] + chemical_symbol_to_type = {k: i for i, k in enumerate(chemical_symbols)} + del chemical_symbols + + if type_to_chemical_symbol is not None: + type_to_chemical_symbol = { + int(k): v for k, v in type_to_chemical_symbol.items() + } + assert all( + v in ase.data.chemical_symbols for v in type_to_chemical_symbol.values() + ) + + # Build from chem->type mapping, if provided + self.chemical_symbol_to_type = chemical_symbol_to_type + if self.chemical_symbol_to_type is not None: + # Validate + for sym, type in self.chemical_symbol_to_type.items(): + assert sym in ase.data.atomic_numbers, f"Invalid chemical symbol {sym}" + assert 0 <= type, f"Invalid type number {type}" + assert set(self.chemical_symbol_to_type.values()) == set( + range(len(self.chemical_symbol_to_type)) + ) + if type_names is None: + # Make type_names + type_names = [None] * len(self.chemical_symbol_to_type) + for sym, type in self.chemical_symbol_to_type.items(): + type_names[type] = sym + else: + # Make sure they agree on types + # We already checked that chem->type is contiguous, + # so enough to check length since type_names is a list + assert len(type_names) == len(self.chemical_symbol_to_type) + # Make mapper array + valid_atomic_numbers = [ + ase.data.atomic_numbers[sym] for sym in self.chemical_symbol_to_type + ] + self._min_Z = min(valid_atomic_numbers) + self._max_Z = max(valid_atomic_numbers) + Z_to_index = torch.full( + size=(1 + self._max_Z - self._min_Z,), fill_value=-1, dtype=torch.long, device=device + ) + for sym, type in self.chemical_symbol_to_type.items(): + Z_to_index[ase.data.atomic_numbers[sym] - self._min_Z] = type + self._Z_to_index = Z_to_index + self._index_to_Z = torch.zeros( + size=(len(self.chemical_symbol_to_type),), dtype=torch.long, device=device + ) + for sym, type_idx in self.chemical_symbol_to_type.items(): + self._index_to_Z[type_idx] = ase.data.atomic_numbers[sym] + self._valid_set = set(valid_atomic_numbers) + true_type_to_chemical_symbol = { + type_id: sym for sym, type_id in self.chemical_symbol_to_type.items() + } + if type_to_chemical_symbol is not None: + assert type_to_chemical_symbol == true_type_to_chemical_symbol + else: + type_to_chemical_symbol = true_type_to_chemical_symbol + + # check + if type_names is None: + raise ValueError( + "None of chemical_symbols, chemical_symbol_to_type, nor type_names was provided; exactly one is required" + ) + # validate type names + assert all( + n.isalnum() for n in type_names + ), "Type names must contain only alphanumeric characters" + # Set to however many maps specified -- we already checked contiguous + self.num_types = len(type_names) + # Check type_names + self.type_names = type_names + self.type_to_chemical_symbol = type_to_chemical_symbol + if self.type_to_chemical_symbol is not None: + assert set(type_to_chemical_symbol.keys()) == set(range(self.num_types)) + + def __call__( + self, data: Union[AtomicDataDict.Type, AtomicData], 
types_required: bool = True + ) -> Union[AtomicDataDict.Type, AtomicData]: + if AtomicDataDict.ATOM_TYPE_KEY in data: + if AtomicDataDict.ATOMIC_NUMBERS_KEY in data: + warnings.warn( + "Data contained both ATOM_TYPE_KEY and ATOMIC_NUMBERS_KEY; ignoring ATOMIC_NUMBERS_KEY" + ) + elif AtomicDataDict.ATOMIC_NUMBERS_KEY in data: + assert ( + self.chemical_symbol_to_type is not None + ), "Atomic numbers provided but there is no chemical_symbols/chemical_symbol_to_type mapping!" + atomic_numbers = data[AtomicDataDict.ATOMIC_NUMBERS_KEY] + del data[AtomicDataDict.ATOMIC_NUMBERS_KEY] + + data[AtomicDataDict.ATOM_TYPE_KEY] = self.transform(atomic_numbers) + else: + if types_required: + raise KeyError( + "Data doesn't contain any atom type information (ATOM_TYPE_KEY or ATOMIC_NUMBERS_KEY)" + ) + return data + + def transform(self, atomic_numbers): + """core function to transform an array to specie index list""" + + if atomic_numbers.min() < self._min_Z or atomic_numbers.max() > self._max_Z: + bad_set = set(torch.unique(atomic_numbers).cpu().tolist()) - self._valid_set + raise ValueError( + f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!" + ) + + return self._Z_to_index.to(device=atomic_numbers.device)[ + atomic_numbers - self._min_Z + ] + + def untransform(self, atom_types): + """Transform atom types back into atomic numbers""" + return self._index_to_Z[atom_types].to(device=atom_types.device) + + @property + def has_chemical_symbols(self) -> bool: + return self.chemical_symbol_to_type is not None + + @staticmethod + def format( + data: list, type_names: List[str], element_formatter: str = ".6f" + ) -> str: + """ + Formats a list of data elements along with their type names. + + Parameters: + data (list): The data elements to be formatted. This should be a list of numbers. + type_names (List[str]): The type names corresponding to the data elements. This should be a list of strings. + element_formatter (str, optional): The format in which the data elements should be displayed. Defaults to ".6f". + + Returns: + str: A string representation of the data elements along with their type names. + + Raises: + ValueError: If `data` is not None, not 0-dimensional, or not 1-dimensional with length equal to the length of `type_names`. 
+ + Example: + >>> data = [1.123456789, 2.987654321] + >>> type_names = ['Type1', 'Type2'] + >>> print(TypeMapper.format(data, type_names)) + [Type1: 1.123457, Type2: 2.987654] + """ + data = torch.as_tensor(data) if data is not None else None + if data is None: + return f"[{', '.join(type_names)}: None]" + elif data.ndim == 0: + return (f"[{', '.join(type_names)}: {{:{element_formatter}}}]").format(data) + elif data.ndim == 1 and len(data) == len(type_names): + return ( + "[" + + ", ".join( + f"{{{i}[0]}}: {{{i}[1]:{element_formatter}}}" + for i in range(len(data)) + ) + + "]" + ).format(*zip(type_names, data)) + else: + raise ValueError( + f"Don't know how to format data=`{data}` for types {type_names} with element_formatter=`{element_formatter}`" + ) + + +class BondMapper(TypeMapper): + def __init__( + self, + chemical_symbols: Optional[List[str]] = None, + chemical_symbols_to_type:Union[Dict[str, int], None]=None, + device=torch.device("cpu"), + ): + super(BondMapper, self).__init__(chemical_symbol_to_type=chemical_symbols_to_type, chemical_symbols=chemical_symbols, device=device) + + self.bond_types = [None] * self.num_types ** 2 + self.reduced_bond_types = [None] * ((self.num_types * (self.num_types + 1)) // 2) + self.bond_to_type = {} + self.type_to_bond = {} + self.reduced_bond_to_type = {} + self.type_to_reduced_bond = {} + for asym, ai in self.chemical_symbol_to_type.items(): + for bsym, bi in self.chemical_symbol_to_type.items(): + self.bond_types[ai * self.num_types + bi] = asym + "-" + bsym + if ai <= bi: + self.reduced_bond_types[(2*self.num_types-ai+1) * ai // 2 + bi-ai] = asym + "-" + bsym + for i, bt in enumerate(self.bond_types): + self.bond_to_type[bt] = i + self.type_to_bond[i] = bt + for i, bt in enumerate(self.reduced_bond_types): + self.reduced_bond_to_type[bt] = i + self.type_to_reduced_bond[i] = bt + + ZZ_to_index = torch.full( + size=(len(self._Z_to_index), len(self._Z_to_index)), fill_value=-1, device=device, dtype=torch.long + ) + ZZ_to_reduced_index = torch.full( + size=(len(self._Z_to_index), len(self._Z_to_index)), fill_value=-1, device=device, dtype=torch.long + ) + + + for abond, aidx in self.bond_to_type.items(): # type_names has a ascending order according to atomic number + asym, bsym = abond.split("-") + ZZ_to_index[ase.data.atomic_numbers[asym]-self._min_Z, ase.data.atomic_numbers[bsym]-self._min_Z] = aidx + + for abond, aidx in self.reduced_bond_to_type.items(): # type_names has a ascending order according to atomic number + asym, bsym = abond.split("-") + ZZ_to_reduced_index[ase.data.atomic_numbers[asym]-self._min_Z, ase.data.atomic_numbers[bsym]-self._min_Z] = aidx + + + self._ZZ_to_index = ZZ_to_index + self._ZZ_to_reduced_index = ZZ_to_reduced_index + + self._index_to_ZZ = torch.zeros( + size=(len(self.bond_to_type),2), dtype=torch.long, device=device + ) + self._reduced_index_to_ZZ = torch.zeros( + size=(len(self.reduced_bond_to_type),2), dtype=torch.long, device=device + ) + + for abond, aidx in self.bond_to_type.items(): + asym, bsym = abond.split("-") + self._index_to_ZZ[aidx] = torch.tensor([ase.data.atomic_numbers[asym], ase.data.atomic_numbers[bsym]], dtype=torch.long, device=device) + + for abond, aidx in self.reduced_bond_to_type.items(): + asym, bsym = abond.split("-") + self._reduced_index_to_ZZ[aidx] = torch.tensor([ase.data.atomic_numbers[asym], ase.data.atomic_numbers[bsym]], dtype=torch.long, device=device) + + + def transform_atom(self, atomic_numbers): + return self.transform(atomic_numbers) + + def transform_bond(self, 
iatomic_numbers, jatomic_numbers): + + if iatomic_numbers.device != jatomic_numbers.device: + raise ValueError("iatomic_numbers and jatomic_numbers should be on the same device!") + + if iatomic_numbers.min() < self._min_Z or iatomic_numbers.max() > self._max_Z: + bad_set = set(torch.unique(iatomic_numbers).cpu().tolist()) - self._valid_set + raise ValueError( + f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!" + ) + + if jatomic_numbers.min() < self._min_Z or jatomic_numbers.max() > self._max_Z: + bad_set = set(torch.unique(jatomic_numbers).cpu().tolist()) - self._valid_set + raise ValueError( + f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!" + ) + + return self._ZZ_to_index.to(device=iatomic_numbers.device)[ + iatomic_numbers - self._min_Z, jatomic_numbers - self._min_Z + ] + + def transform_reduced_bond(self, iatomic_numbers, jatomic_numbers): + + if iatomic_numbers.device != jatomic_numbers.device: + raise ValueError("iatomic_numbers and jatomic_numbers should be on the same device!") + + if iatomic_numbers.min() < self._min_Z or iatomic_numbers.max() > self._max_Z: + bad_set = set(torch.unique(iatomic_numbers).cpu().tolist()) - self._valid_set + raise ValueError( + f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!" + ) + + if jatomic_numbers.min() < self._min_Z or jatomic_numbers.max() > self._max_Z: + bad_set = set(torch.unique(jatomic_numbers).cpu().tolist()) - self._valid_set + raise ValueError( + f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!" + ) + + return self._ZZ_to_reduced_index.to(device=iatomic_numbers.device)[ + iatomic_numbers - self._min_Z, jatomic_numbers - self._min_Z + ] + + def untransform_atom(self, atom_types): + """Transform atom types back into atomic numbers""" + return self.untransform(atom_types) + + def untransform_bond(self, bond_types): + """Transform bond types back into atomic numbers""" + return self._index_to_ZZ[bond_types].to(device=bond_types.device) + + def untransform_reduced_bond(self, bond_types): + """Transform reduced bond types back into atomic numbers""" + return self._reduced_index_to_ZZ[bond_types].to(device=bond_types.device) + + @property + def has_bond(self) -> bool: + return self.bond_to_type is not None + + def __call__( + self, data: Union[AtomicDataDict.Type, AtomicData], types_required: bool = True + ) -> Union[AtomicDataDict.Type, AtomicData]: + if AtomicDataDict.EDGE_TYPE_KEY in data: + if AtomicDataDict.ATOMIC_NUMBERS_KEY in data: + warnings.warn( + "Data contained both EDGE_TYPE_KEY and ATOMIC_NUMBERS_KEY; ignoring ATOMIC_NUMBERS_KEY" + ) + elif AtomicDataDict.ATOMIC_NUMBERS_KEY in data: + assert ( + self.bond_to_type is not None + ), "Atomic numbers provided but there is no chemical_symbols/chemical_symbol_to_type mapping!" + atomic_numbers = data[AtomicDataDict.ATOMIC_NUMBERS_KEY] + + assert ( + AtomicDataDict.EDGE_INDEX_KEY in data + ), "The bond type mapper need a EDGE index as input." 
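+            # look up the bond type of every edge from the atomic numbers of its two endpoint atoms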
+ + data[AtomicDataDict.EDGE_TYPE_KEY] = \ + self.transform_bond( + atomic_numbers[data[AtomicDataDict.EDGE_INDEX_KEY][0]], + atomic_numbers[data[AtomicDataDict.EDGE_INDEX_KEY][1]] + ) + else: + if types_required: + raise KeyError( + "Data doesn't contain any atom type information (EDGE_TYPE_KEY or ATOMIC_NUMBERS_KEY)" + ) + return super().__call__(data=data, types_required=types_required) + + + + + +class OrbitalMapper(BondMapper): + def __init__( + self, + basis: Dict[str, Union[List[str], str]], + chemical_symbol_to_type: Optional[Dict[str, int]] = None, + method: str ="e3tb", + device: Union[str, torch.device] = torch.device("cpu") + ): + + """ + This class is used to map the orbital pair index to the index of the reduced matrix element (or sk integrals when method is sktb). To construct a reduced matrix element features in each edge/node with equal sizes as well as their mappings, the following steps will be conducted: + + 1. The basis of each atom will be sorted according to their names. For example, The basis ["2s", "1s", "s*", "2p"] of atom A will be sorted as ["s*", "1s", "2s", "2p"]. + + 2. The sorted basis will be transformed into a general basis, dubbed as full_basis. It is the least required set covering all the basis number and types of each atom. The basis will be renamed according to their angular momentum and the order after sorting. Take s orbital as a example, the first s* will be named as "1s", the second s* will be named as "2s", and so on. Same for p, d, f orbitals. + + Then the mappings and masks used to guide the construction of hamiltonian will be constructed. The mappings includes: + + Mappings: + fullbasis_to_basis, basis_to_fullbasis: which function as their names + orbpair_maps: the mapping from orbital pairs of full basis to the reduced matrix element (or sk integrals) index. + orbpairtype_maps: the mapping from the types of orbital pair (e.g. "s-s", "s-p", "p-p") to the reduced matrix element (or sk integrals) index. + skonsite_maps: the mapping from the orbital to the sk onsite energies index. + skonsitetype_maps: the mapping from the orbital type (e.g. "s", "p", "d", "f") to the sk onsite energies index. + orbital_maps: the mapping from the orbital to the index of the corresponding lines/column in hamiltonian blocks. + orbpair_irreps: the e3nn irreducible representations of the full reduced matrix element edge/node features. + + Masks: + mask_to_basis: the mask used to map the (line/column of) hamiltonian of full basis to the (line/column of) block of original basis of each atom. + mask_to_erme: the mask used to map the hopping block's flattened reduced matrix element (up tri-diagonal block of hamiltonian) of full basis to it of the original basis. + mask_to_nrme: the mask used to map the onsite block's flattened reduced matrix element (diagonal block of hamiltonian) of full basis to it of the original basis. + + Parameters + ---------- + basis : dict + the definition of the basis set, should be like: + {"A":"2s2p3d1f", "B":"1s2f3d1f"} or + {"A":["2s", "2p"], "B":["2s", "2p"]} + when list, "2s" indicate a "s" orbital in the second shell. 
+ when str, "2s" indicates two s orbitals, + "2s2p3d4f" is equivilent to ["1s","2s", "1p", "2p", "1d", "2d", "3d", "1f"] + """ + + #TODO: use OrderedDict to fix the order of the dict used as index map + if chemical_symbol_to_type is not None: + assert set(basis.keys()) == set(chemical_symbol_to_type.keys()) + super(OrbitalMapper, self).__init__(chemical_symbol_to_type=chemical_symbol_to_type, device=device) + else: + super(OrbitalMapper, self).__init__(chemical_symbols=list(basis.keys()), device=device) + + self.basis = basis + self.method = method + self.device = device + + if self.method not in ["e3tb", "sktb"]: + raise ValueError + + if isinstance(self.basis[self.type_names[0]], str): + orbtype_count = {"s":0, "p":0, "d":0, "f":0} + orbs = map(lambda bs: re.findall(r'[1-9]+[A-Za-z]', bs), self.basis.values()) + for ib in orbs: + for io in ib: + if int(io[0]) > orbtype_count[io[1]]: + orbtype_count[io[1]] = int(io[0]) + # split into list basis + basis = {k:[] for k in self.type_names} + for ib in self.basis.keys(): + for io in ["s", "p", "d", "f"]: + if io in self.basis[ib]: + basis[ib].extend([str(i)+io for i in range(1, int(re.findall(r'[1-9]+'+io, self.basis[ib])[0][0])+1)]) + self.basis = basis + + elif isinstance(self.basis[self.type_names[0]], list): + nb = len(self.type_names) + orbtype_count = {"s":[0]*nb, "p":[0]*nb, "d":[0]*nb, "f":[0]*nb} + for ib, bt in enumerate(self.type_names): + for io in self.basis[bt]: + orb = re.findall(r'[A-Za-z]', io)[0] + orbtype_count[orb][ib] += 1 + + for ko in orbtype_count.keys(): + orbtype_count[ko] = max(orbtype_count[ko]) + + self.orbtype_count = orbtype_count + self.full_basis_norb = 1 * orbtype_count["s"] + 3 * orbtype_count["p"] + 5 * orbtype_count["d"] + 7 * orbtype_count["f"] + + + if self.method == "e3tb": + self.reduced_matrix_element = int(((orbtype_count["s"] + 9 * orbtype_count["p"] + 25 * orbtype_count["d"] + 49 * orbtype_count["f"]) + \ + self.full_basis_norb ** 2)/2) # reduce onsite elements by blocks. we cannot reduce it by element since the rme will pass into CG basis to form the whole block + else: + # two factor: this outside one is the number of min(l,l')+1, ie. the number of sk integrals for each orbital pair. + # the inside one the type of bond considering the interaction between different orbitals. s-p -> p-s. there are 2 types of bond. and 1 type of s-s. 
+ self.reduced_matrix_element = ( + 1 * orbtype_count["s"] * orbtype_count["s"] + \ + 2 * orbtype_count["s"] * orbtype_count["p"] + \ + 2 * orbtype_count["s"] * orbtype_count["d"] + \ + 2 * orbtype_count["s"] * orbtype_count["f"] + ) + \ + 2 * ( + 1 * orbtype_count["p"] * orbtype_count["p"] + \ + 2 * orbtype_count["p"] * orbtype_count["d"] + \ + 2 * orbtype_count["p"] * orbtype_count["f"] + ) + \ + 3 * ( + 1 * orbtype_count["d"] * orbtype_count["d"] + \ + 2 * orbtype_count["d"] * orbtype_count["f"] + ) + \ + 4 * (orbtype_count["f"] * orbtype_count["f"]) + + self.reduced_matrix_element = self.reduced_matrix_element + orbtype_count["s"] + 2*orbtype_count["p"] + 3*orbtype_count["d"] + 4*orbtype_count["f"] + self.reduced_matrix_element = int(self.reduced_matrix_element / 2) + self.n_onsite_Es = orbtype_count["s"] + orbtype_count["p"] + orbtype_count["d"] + orbtype_count["f"] + + # sort the basis + for ib in self.basis.keys(): + self.basis[ib] = sorted( + self.basis[ib], + key=lambda s: (anglrMId[re.findall(r"[a-z]",s)[0]], re.findall(r"[1-9*]",s)[0]) + ) + + # TODO: get full basis set + full_basis = [] + for io in ["s", "p", "d", "f"]: + full_basis = full_basis + [str(i)+io for i in range(1, orbtype_count[io]+1)] + self.full_basis = full_basis + + # TODO: get the mapping from list basis to full basis + self.basis_to_full_basis = {} + self.atom_norb = torch.zeros(len(self.type_names), dtype=torch.long, device=self.device) + for ib in self.basis.keys(): + count_dict = {"s":0, "p":0, "d":0, "f":0} + self.basis_to_full_basis.setdefault(ib, {}) + for o in self.basis[ib]: + io = re.findall(r"[a-z]", o)[0] + l = anglrMId[io] + count_dict[io] += 1 + self.atom_norb[self.chemical_symbol_to_type[ib]] += 2*l+1 + + self.basis_to_full_basis[ib][o] = str(count_dict[io])+io + + # get the mapping from full basis to list basis + self.full_basis_to_basis = {} + for at, maps in self.basis_to_full_basis.items(): + self.full_basis_to_basis[at] = {} + for k,v in maps.items(): + self.full_basis_to_basis[at].update({v:k}) + + # Get the mask for mapping from full basis to atom specific basis + self.mask_to_basis = torch.zeros(len(self.type_names), self.full_basis_norb, device=self.device, dtype=torch.bool) + + for ib in self.basis.keys(): + ibasis = list(self.basis_to_full_basis[ib].values()) + ist = 0 + for io in self.full_basis: + l = anglrMId[io[1]] + if io in ibasis: + self.mask_to_basis[self.chemical_symbol_to_type[ib]][ist:ist+2*l+1] = True + + ist += 2*l+1 + + assert (self.mask_to_basis.sum(dim=1).int()-self.atom_norb).abs().sum() <= 1e-6 + + self.get_orbpair_maps() + # the mask to map the full basis reduced matrix element to the original basis reduced matrix element + self.mask_to_erme = torch.zeros(len(self.bond_types), self.reduced_matrix_element, dtype=torch.bool, device=self.device) + self.mask_to_nrme = torch.zeros(len(self.type_names), self.reduced_matrix_element, dtype=torch.bool, device=self.device) + for ib, bb in self.basis.items(): + for io in bb: + iof = self.basis_to_full_basis[ib][io] + for jo in bb: + jof = self.basis_to_full_basis[ib][jo] + if self.orbpair_maps.get(iof+"-"+jof) is not None: + self.mask_to_nrme[self.chemical_symbol_to_type[ib]][self.orbpair_maps[iof+"-"+jof]] = True + + for ib in self.bond_to_type.keys(): + a,b = ib.split("-") + for io in self.basis[a]: + iof = self.basis_to_full_basis[a][io] + for jo in self.basis[b]: + jof = self.basis_to_full_basis[b][jo] + if self.orbpair_maps.get(iof+"-"+jof) is not None: + 
self.mask_to_erme[self.bond_to_type[ib]][self.orbpair_maps[iof+"-"+jof]] = True + elif self.orbpair_maps.get(jof+"-"+iof) is not None: + self.mask_to_erme[self.bond_to_type[ib]][self.orbpair_maps[jof+"-"+iof]] = True + + def get_orbpairtype_maps(self): + """ + The function `get_orbpairtype_maps` creates a mapping of orbital pair types, such as s-s, "s-p", + to slices based on the number of hops between them. + :return: a dictionary called `pairtype_map`. + """ + + self.orbpairtype_maps = {} + ist = 0 + for i, io in enumerate(["s", "p", "d", "f"]): + if self.orbtype_count[io] != 0: + for jo in ["s", "p", "d", "f"][i:]: + if self.orbtype_count[jo] != 0: + orb_pair = io+"-"+jo + il, jl = anglrMId[io], anglrMId[jo] + if self.method == "e3tb": + n_rme = (2*il+1) * (2*jl+1) + else: + n_rme = min(il, jl)+1 + numhops = self.orbtype_count[io] * self.orbtype_count[jo] * n_rme + if io == jo: + numhops += self.orbtype_count[jo] * n_rme + numhops = int(numhops / 2) + self.orbpairtype_maps[orb_pair] = slice(ist, ist+numhops) + + ist += numhops + + return self.orbpairtype_maps + + def get_orbpair_maps(self): + + if hasattr(self, "orbpair_maps"): + return self.orbpair_maps + + if not hasattr(self, "orbpairtype_maps"): + self.get_orbpairtype_maps() + + self.orbpair_maps = {} + for i, io in enumerate(self.full_basis): + for jo in self.full_basis[i:]: + full_basis_pair = io+"-"+jo + ir, jr = int(full_basis_pair[0]), int(full_basis_pair[3]) + iio, jjo = full_basis_pair[1], full_basis_pair[4] + il, jl = anglrMId[iio], anglrMId[jjo] + + if self.method == "e3tb": + n_feature = (2*il+1) * (2*jl+1) + else: + n_feature = min(il, jl)+1 + if iio == jjo: + start = self.orbpairtype_maps[iio+"-"+jjo].start + \ + n_feature * ((2*self.orbtype_count[jjo]+2-ir) * (ir-1) / 2 + (jr - ir)) + else: + start = self.orbpairtype_maps[iio+"-"+jjo].start + \ + n_feature * ((ir-1)*self.orbtype_count[jjo]+(jr-1)) + start = int(start) + self.orbpair_maps[io+"-"+jo] = slice(start, start+n_feature) + + return self.orbpair_maps + + def get_skonsite_maps(self): + + assert self.method == "sktb", "Only sktb orbitalmapper have skonsite maps." + + if hasattr(self, "skonsite_maps"): + return self.skonsite_maps + + if not hasattr(self, "skonsitetype_maps"): + self.get_skonsitetype_maps() + + self.skonsite_maps = {} + for i, io in enumerate(self.full_basis): + ir= int(io[0]) + iio = io[1] + + start = int(self.skonsitetype_maps[iio].start + (ir-1)) + self.skonsite_maps[io] = slice(start, start+1) + + return self.skonsite_maps + + def get_skonsitetype_maps(self): + self.skonsitetype_maps = {} + ist = 0 + + assert self.method == "sktb", "Only sktb orbitalmapper have skonsite maps." + for i, io in enumerate(["s", "p", "d", "f"]): + if self.orbtype_count[io] != 0: + il = anglrMId[io] + numonsites = self.orbtype_count[io] + + self.skonsitetype_maps[io] = slice(ist, ist+numonsites) + + ist += numonsites + + return self.skonsitetype_maps + + # also need to think if we modify as this, how can we add extra basis when fitting. + + def get_orbital_maps(self): + # simply get a 1-d slice for each atom species. 
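+ # e.g. for an atom type ib with basis ["1s", "1p", "1d"] (as produced from the string "1s1p1d"), this returns + # orbital_maps[ib] = {"1s": slice(0, 1), "1p": slice(1, 4), "1d": slice(4, 9)} with norbs[ib] = 9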
+ + self.orbital_maps = {} + self.norbs = {} + + for ib in self.basis.keys(): + orbital_list = self.basis[ib] + slices = {} + start_index = 0 + + self.norbs.setdefault(ib, 0) + for orb in orbital_list: + orb_l = re.findall(r'[A-Za-z]', orb)[0] + increment = (2*anglrMId[orb_l]+1) + self.norbs[ib] += increment + end_index = start_index + increment + + slices[orb] = slice(start_index, end_index) + start_index = end_index + + self.orbital_maps[ib] = slices + + return self.orbital_maps + + def get_irreps(self, no_parity=False): + assert self.method == "e3tb", "Only support e3tb method for now." + + if hasattr(self, "orbpair_irreps"): + if self.no_parity == no_parity: + return self.orbpair_irreps + + self.no_parity = no_parity + + if not hasattr(self, "orbpairtype_maps"): + self.get_orbpairtype_maps() + + irs = [] + if no_parity: + factor = 1 + else: + factor = -1 + + irs = [] + for pair, sli in self.orbpairtype_maps.items(): + l1, l2 = anglrMId[pair[0]], anglrMId[pair[2]] + p = factor**(l1+l2) + required_ls = range(abs(l1 - l2), l1 + l2 + 1) + required_irreps = [(1,(l, p)) for l in required_ls] + irs += required_irreps*int((sli.stop-sli.start)/(2*l1+1)/(2*l2+1)) + + self.orbpair_irreps = o3.Irreps(irs) + return self.orbpair_irreps + + def __eq__(self, other): + return self.basis == other.basis and self.method == other.method \ No newline at end of file diff --git a/dptb/data/use_data.ipynb b/dptb/data/use_data.ipynb new file mode 100644 index 00000000..1d9c6584 --- /dev/null +++ b/dptb/data/use_data.ipynb @@ -0,0 +1,499 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from build import dataset_from_config\n", + "from dptb.utils.config import Config" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "config = {\n", + " \"root\": \"/root/nequip_data/\",\n", + " \"dataset\": \"npz\",\n", + " \"dataset_file_name\": \"/root/nequip_data/Si8-100K.npz\",\n", + " \"key_mapping\":{\n", + " \"pos\":\"pos\",\n", + " \"atomic_numbers\":\"atomic_numbers\",\n", + " \"kpoints\": \"kpoint\",\n", + " \"pbc\": \"pbc\",\n", + " \"cell\": \"cell\",\n", + " \"eigenvalues\": \"eigenvalue\"\n", + " },\n", + " \"npz_fixed_field_keys\": [\"kpoint\", \"pbc\"],\n", + " \"graph_field\":[\"eigenvalues\"],\n", + " \"chemical_symbols\": [\"Si\", \"C\"],\n", + " \"r_max\": 6.0\n", + "}\n", + "\n", + "config = Config(config=config)\n", + "# dataset: npz # type of data set, can be npz or ase\n", + "# dataset_url: http://quantum-machine.org/gdml/data/npz/toluene_ccsd_t.zip # url to download the npz. 
optional\n", + "# dataset_file_name: ./benchmark_data/toluene_ccsd_t-train.npz # path to data set file\n", + "# key_mapping:\n", + "# z: atomic_numbers # atomic species, integers\n", + "# E: total_energy # total potential eneriges to train to\n", + "# F: forces # atomic forces to train to\n", + "# R: pos # raw atomic positions\n", + "# npz_fixed_field_keys: # fields that are repeated across different examples\n", + "# - atomic_numbers\n", + "\n", + "# chemical_symbols:\n", + "# - H\n", + "# - C" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing dataset...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Done!\n" + ] + } + ], + "source": [ + "dataset = dataset_from_config(config=config, prefix=\"dataset\")\n", + "\n", + "from dptb.data.dataloader import DataLoader\n", + "\n", + "dl = DataLoader(dataset, 3)\n", + "\n", + "data = next(iter(dl))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 1., 1., -1.],\n", + " [ 1., 1., 1.],\n", + " [ 0., 1., -1.],\n", + " [ 0., 1., 1.],\n", + " [ 1., 0., -1.],\n", + " [ 0., 0., -1.],\n", + " [ 1., 0., 1.],\n", + " [ 0., 0., 1.],\n", + " [ 0., 1., 0.],\n", + " [ 1., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 1., 1., 0.]])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "dataset[0].edge_cell_shift[dataset[0].edge_index[0].eq(1)&dataset[0].edge_index[1].eq(2)], dataset[0].edge_cell_shift[dataset[0].edge_index[0].eq(1)&dataset[0].edge_index[1].eq(2)]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([False, False, False, False, False, False, False, False, False, False,\n", + " True, False, False, False, False, False, False, True, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, True, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, False, True, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, True, False, True,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " True, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, True, False, 
False,\n", + " False, True, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, True, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, True, False, False, False, True, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " True, False, False, True, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, True, False, False,\n", + " False, False, False, True, False, False, False, False, False, True,\n", + " False, True, True, True])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[0].edge_index[0].eq(dataset[0].edge_index[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'C-C': 0, 'C-Si': 1, 'Si-C': 2, 'Si-Si': 3}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.type_mapper.bond_to_type" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from dptb.nn._sktb import SKTB\n", + "sktb = SKTB(\n", + " basis={\"Si\":[\"3s\", \"3p\", \"p*\", \"s*\"], \"C\":[\"2s\",\"2p\"]},\n", + " onsite=\"uniform\",\n", + " hopping=\"powerlaw\",\n", + " overlap=True\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from dptb.data.AtomicDataDict import with_edge_vectors, with_onsitenv_vectors\n", + "\n", + "data = with_edge_vectors(data.to_dict())\n", + "data = with_onsitenv_vectors(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "data[\"atomic_numbers\"] = dataset.type_mapper.untransform(data[\"atom_types\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "data = sktb(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "20" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sktb.idp.edge_reduced_matrix_element" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([24, 4])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"node_features\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from dptb.nn._hamiltonian import SKHamiltonian\n", + "\n", + "skh = SKHamiltonian(basis={\"Si\":[\"3s\", \"3p\", \"p*\", \"s*\"], \"C\":[\"2s\",\"2p\"]})" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "data = skh(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([24, 42])" + 
] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"node_features\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "from dptb.nn._hamiltonian import E3Hamiltonian\n", + "e3h = E3Hamiltonian(basis={\"Si\":[\"3s\", \"3p\", \"p*\", \"s*\"], \"C\":[\"2s\",\"2p\"]}, decompose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "data = e3h(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ True, True, True, True, False, True, False, False, True, False,\n", + " False, True, False, False, True, False, False, True, False, False,\n", + " True, False, False, True, False, False, True, False, True, False,\n", + " False, False, False, False, True, False, False, True, False, False,\n", + " False, False, False, True, False, False, True, False, False, False,\n", + " False, False, True, False, False, True, False, False, False, False,\n", + " False, True, False, False])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"edge_features\"][0].abs().gt(1e-5)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "from dptb.data.AtomicData import AtomicData\n", + "from dptb.utils.torch_geometric import Batch\n", + "\n", + "bdata = Batch.from_dict(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "Cannot reconstruct data list from batch because the batch object was not created using `Batch.from_data_list()`.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/root/deeptb/dptb/data/use_data.ipynb Cell 13\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m bdata\u001b[39m.\u001b[39;49mget_example(\u001b[39m0\u001b[39;49m)\n", + "File \u001b[0;32m/opt/miniconda/envs/deeptb/lib/python3.8/site-packages/dptb/utils/torch_geometric/batch.py:176\u001b[0m, in \u001b[0;36mBatch.get_example\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"Reconstructs the :class:`torch_geometric.data.Data` object at index\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[39m:obj:`idx` from the batch object.\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[39mThe batch object must have been created via :meth:`from_data_list` in\u001b[39;00m\n\u001b[1;32m 173\u001b[0m \u001b[39morder to be able to reconstruct the initial objects.\"\"\"\u001b[39;00m\n\u001b[1;32m 175\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__slices__ \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 176\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m 177\u001b[0m (\n\u001b[1;32m 178\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mCannot reconstruct data list from batch because the batch \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 179\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mobject was not created using `Batch.from_data_list()`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 180\u001b[0m )\n\u001b[1;32m 181\u001b[0m )\n\u001b[1;32m 183\u001b[0m 
data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__data_class__()\n\u001b[1;32m 184\u001b[0m idx \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_graphs \u001b[39m+\u001b[39m idx \u001b[39mif\u001b[39;00m idx \u001b[39m<\u001b[39m \u001b[39m0\u001b[39m \u001b[39melse\u001b[39;00m idx\n", + "\u001b[0;31mRuntimeError\u001b[0m: Cannot reconstruct data list from batch because the batch object was not created using `Batch.from_data_list()`." + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from dptb.data.transforms import OrbitalMapper\n", + "\n", + "idp = OrbitalMapper(basis={\"Si\": \"2s2p1d\", \"C\":\"1s1p1d\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'1s-1s': slice(0, 1, None),\n", + " '1s-2s': slice(1, 2, None),\n", + " '1s-1p': slice(3, 6, None),\n", + " '1s-2p': slice(6, 9, None),\n", + " '1s-1d': slice(15, 20, None),\n", + " '2s-2s': slice(2, 3, None),\n", + " '2s-1p': slice(9, 12, None),\n", + " '2s-2p': slice(12, 15, None),\n", + " '2s-1d': slice(20, 25, None),\n", + " '1p-1p': slice(25, 34, None),\n", + " '1p-2p': slice(34, 43, None),\n", + " '1p-1d': slice(52, 67, None),\n", + " '2p-2p': slice(43, 52, None),\n", + " '2p-1d': slice(67, 82, None),\n", + " '1d-1d': slice(82, 107, None)}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "idp.get_node_maps()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'1s-1s': slice(0, 1, None),\n", + " '1s-2s': slice(1, 2, None),\n", + " '1s-1p': slice(3, 6, None),\n", + " '1s-2p': slice(6, 9, None),\n", + " '1s-1d': slice(15, 20, None),\n", + " '2s-2s': slice(2, 3, None),\n", + " '2s-1p': slice(9, 12, None),\n", + " '2s-2p': slice(12, 15, None),\n", + " '2s-1d': slice(20, 25, None),\n", + " '1p-1p': slice(25, 34, None),\n", + " '1p-2p': slice(34, 43, None),\n", + " '1p-1d': slice(52, 67, None),\n", + " '2p-2p': slice(43, 52, None),\n", + " '2p-1d': slice(67, 82, None),\n", + " '1d-1d': slice(82, 107, None)}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "idp.node_maps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deeptb", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/dptb/data/util.py b/dptb/data/util.py new file mode 100644 index 00000000..494c2f8f --- /dev/null +++ b/dptb/data/util.py @@ -0,0 +1,4 @@ +import torch + +# There is no built-in way to check if a Tensor is of an integer type +_TORCH_INTEGER_DTYPES = (torch.int, torch.long) diff --git a/dptb/entrypoints/__init__.py b/dptb/entrypoints/__init__.py index 4e3c5d99..b921c542 100644 --- a/dptb/entrypoints/__init__.py +++ b/dptb/entrypoints/__init__.py @@ -1,5 +1,5 @@ from dptb.entrypoints.train import train from dptb.entrypoints.config import config -from dptb.entrypoints.run import run +# from dptb.entrypoints.run import run from dptb.entrypoints.test import 
_test as test from dptb.entrypoints.bond import bond \ No newline at end of file diff --git a/dptb/entrypoints/data.py b/dptb/entrypoints/data.py new file mode 100644 index 00000000..d19f25fb --- /dev/null +++ b/dptb/entrypoints/data.py @@ -0,0 +1,39 @@ +import os +from typing import Dict, List, Optional, Any +from dptb.utils.tools import j_loader +from dptb.utils.argcheck import normalize +from dptb.data.interfaces.abacus import recursive_parse + +def data( + INPUT: str, + log_level: int, + log_path: Optional[str], + **kwargs +): + jdata = j_loader(INPUT) + + # ABACUS parsing input like: + # { "type": "ABACUS", + # "parse_arguments": { + # "input_path": "alice_*/*_bob/system_No_*", + # "preprocess_dir": "charlie/david", + # "only_overlap": false, + # "get_Hamiltonian": true, + # "add_overlap": true, + # "get_eigenvalues": true } } + if jdata["type"] == "ABACUS": + abacus_args = jdata["parse_arguments"] + assert abacus_args.get("input_path") is not None, "ABACUS calculation results MUST be provided." + assert abacus_args.get("preprocess_dir") is not None, "Please assign a directory to store the preprocessed files." + + print("Begin parsing ABACUS output...") + recursive_parse(**abacus_args) + print("Finished parsing ABACUS output.") + + ## write all h5 files to be used in building AtomicData + #with open(os.path.join(abacus_args["preprocess_dir"], "AtomicData_file.txt"), "w") as f: + # for filename in h5_filenames: + # f.write(filename + "\n") + + else: + raise Exception("Unsupported software output.") \ No newline at end of file diff --git a/dptb/entrypoints/main.py b/dptb/entrypoints/main.py index 226bf685..20def5a4 100644 --- a/dptb/entrypoints/main.py +++ b/dptb/entrypoints/main.py @@ -8,6 +8,7 @@ from dptb.entrypoints.run import run from dptb.entrypoints.bond import bond from dptb.entrypoints.nrl2json import nrl2json +from dptb.entrypoints.data import data from dptb.utils.loggers import set_log_handles def get_ll(log_level: str) -> int: @@ -168,13 +169,6 @@ def main_parser() -> argparse.ArgumentParser: help="Restart the training from the provided checkpoint.", ) - parser_train.add_argument( - "-sk", - "--train-sk", - action="store_true", - help="Trainging NNSKTB parameters.", - ) - parser_train.add_argument( "-crt", "--use-correction", @@ -190,13 +184,6 @@ help="Initialize the training from the frozen model.", ) - parser_train.add_argument( - "-f", - "--freeze", - action="store_true", - help="Initialize the training from the frozen model.", - ) - parser_train.add_argument( "-o", "--output", @@ -226,13 +213,6 @@ help="Initialize the model by the provided checkpoint.", ) - parser_test.add_argument( - "-sk", - "--test-sk", - action="store_true", - help="Test NNSKTB parameters.", - ) - parser_test.add_argument( "-crt", "--use-correction", @@ -287,13 +267,6 @@ help="The output files in postprocess run." ) - parser_run.add_argument( - "-sk", - "--run_sk", - action="store_true", - help="using NNSKTB parameters TB models for post-run."
- ) - parser_run.add_argument( "-crt", "--use-correction", @@ -302,6 +275,20 @@ def main_parser() -> argparse.ArgumentParser: help="Use nnsktb correction when training dptb", ) + # preprocess data + parser_data = subparsers.add_parser( + "data", + parents=[parser_log], + help="preprocess software output", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser_data.add_argument( + "INPUT", help="the input parameter file in json or yaml format", + type=str, + default=None + ) + return parser def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: @@ -352,3 +339,5 @@ def main(): elif args.command == 'n2j': nrl2json(**dict_args) + elif args.command == 'data': + data(**dict_args) diff --git a/dptb/entrypoints/run.py b/dptb/entrypoints/run.py index afc7ccf5..f09d3ee2 100644 --- a/dptb/entrypoints/run.py +++ b/dptb/entrypoints/run.py @@ -1,232 +1,156 @@ -import logging -import json -import os -import struct -import time -import torch -from pathlib import Path -from typing import Dict, List, Optional, Any -from dptb.plugins.train_logger import Logger -from dptb.plugins.init_nnsk import InitSKModel -from dptb.plugins.init_dptb import InitDPTBModel -from dptb.utils.argcheck import normalize, normalize_run -from dptb.utils.tools import j_loader -from dptb.utils.loggers import set_log_handles -from dptb.utils.tools import j_must_have -from dptb.utils.constants import dtype_dict -from dptb.nnops.apihost import NNSKHost, DPTBHost -from dptb.nnops.NN2HRK import NN2HRK -from ase.io import read,write -from dptb.postprocess.bandstructure.band import bandcalc -from dptb.postprocess.bandstructure.dos import doscalc, pdoscalc -from dptb.postprocess.bandstructure.fermisurface import fs2dcalc, fs3dcalc -from dptb.postprocess.bandstructure.ifermi_api import ifermiapi, ifermi_installed, pymatgen_installed -from dptb.postprocess.write_skparam import WriteNNSKParam -from dptb.postprocess.NEGF import NEGF -from dptb.postprocess.tbtrans_init import TBTransInputSet,sisl_installed - -__all__ = ["run"] - -log = logging.getLogger(__name__) - -def run( - INPUT: str, - init_model: str, - output: str, - run_sk: bool, - structure: str, - log_level: int, - log_path: Optional[str], - use_correction: Optional[str], - **kwargs - ): +# import logging +# import json +# import os +# import time +# import torch +# from pathlib import Path +# from typing import Optional +# from dptb.plugins.train_logger import Logger +# from dptb.utils.argcheck import normalize_run +# from dptb.utils.tools import j_loader +# from dptb.utils.loggers import set_log_handles +# from dptb.utils.tools import j_must_have +# from dptb.nn.build import build_model +# from dptb.postprocess.bandstructure.band import bandcalc +# from dptb.postprocess.bandstructure.dos import doscalc, pdoscalc +# from dptb.postprocess.bandstructure.fermisurface import fs2dcalc, fs3dcalc +# from dptb.postprocess.bandstructure.ifermi_api import ifermiapi, ifermi_installed, pymatgen_installed +# from dptb.postprocess.write_skparam import WriteNNSKParam +# from dptb.postprocess.NEGF import NEGF +# from dptb.postprocess.tbtrans_init import TBTransInputSet,sisl_installed + +# __all__ = ["run"] + +# log = logging.getLogger(__name__) + +# def run( +# INPUT: str, +# init_model: str, +# output: str, +# structure: str, +# log_level: int, +# log_path: Optional[str], +# **kwargs +# ): - run_opt = { - "run_sk": run_sk, - "init_model":init_model, - "structure":structure, - "log_path": log_path, - "log_level": log_level, - "use_correction":use_correction - } - - 
if all((use_correction, run_sk)): - log.error(msg="--use-correction and --train_sk should not be set at the same time") - raise RuntimeError +# run_opt = { +# "init_model":init_model, +# "structure":structure, +# "log_path": log_path, +# "log_level": log_level, +# } + +# # output folder. +# if output: + +# Path(output).parent.mkdir(exist_ok=True, parents=True) +# Path(output).mkdir(exist_ok=True, parents=True) +# results_path = os.path.join(str(output), "results") +# Path(results_path).mkdir(exist_ok=True, parents=True) +# if not log_path: +# log_path = os.path.join(str(output), "log/log.txt") +# Path(log_path).parent.mkdir(exist_ok=True, parents=True) + +# run_opt.update({ +# "output": str(Path(output).absolute()), +# "results_path": str(Path(results_path).absolute()), +# "log_path": str(Path(log_path).absolute()) +# }) - jdata = j_loader(INPUT) - jdata = normalize_run(jdata) +# jdata = j_loader(INPUT) +# jdata = normalize_run(jdata) - if all((jdata["init_model"]["path"], run_opt["init_model"])): - raise RuntimeError( - "init-model in config and command line is in conflict, turn off one of then to avoid this error !" - ) - - if run_opt["init_model"] is None: - log.info(msg="model_ckpt is not set in command line, read from input config file.") - - if run_sk: - if jdata["init_model"]["path"] is not None: - run_opt["init_model"] = jdata["init_model"] - else: - log.error(msg="Error! init_model is not set in config file and command line.") - raise RuntimeError - if isinstance(run_opt["init_model"]["path"], list): - if len(run_opt["init_model"]["path"])==0: - log.error("Error! list mode init_model in config file cannot be empty!") - raise RuntimeError - else: - if jdata["init_model"]["path"] is not None: - run_opt["init_model"] = jdata["init_model"] - else: - log.error(msg="Error! init_model is not set in config file and command line.") - raise RuntimeError - if isinstance(run_opt["init_model"]["path"], list): - raise RuntimeError( - "loading lists of checkpoints is only supported in init_nnsk!" - ) - if isinstance(run_opt["init_model"]["path"], list): - if len(run_opt["init_model"]["path"]) == 1: - run_opt["init_model"]["path"] = run_opt["init_model"]["path"][0] - else: - path = run_opt["init_model"] - run_opt["init_model"] = jdata["init_model"] - run_opt["init_model"]["path"] = path +# set_log_handles(log_level, Path(log_path) if log_path else None) + +# f = torch.load(run_opt["init_model"]) +# jdata["model_options"] = f["config"]["model_options"] +# del f - task_options = j_must_have(jdata, "task_options") - task = task_options["task"] - use_gui = jdata.get("use_gui", False) - task_options.update({"use_gui": use_gui}) - - model_ckpt = run_opt["init_model"]["path"] - # init_type = model_ckpt.split(".")[-1] - # if init_type not in ["json", "pth"]: - # log.error(msg="Error! the model file should be a json or pth file.") - # raise RuntimeError - - # if init_type == "json": - # jdata = host_normalize(jdata) - # if run_sk: - # jdata.update({"init_model": {"path": model_ckpt,"interpolate": False}}) - # else: - # jdata.update({"init_model": model_ckpt}) - - if run_opt['structure'] is None: - log.warning(msg="Warning! 
structure is not set in run option, read from input config file.") - structure = j_must_have(jdata, "structure") - run_opt.update({"structure":structure}) - - print(run_opt["structure"]) - - if not run_sk: - if run_opt['use_correction'] is None and jdata.get('use_correction',None) != None: - use_correction = jdata['use_correction'] - run_opt.update({"use_correction":use_correction}) - log.info(msg="use_correction is not set in run option, read from input config file.") - - # output folder. - if output: - - Path(output).parent.mkdir(exist_ok=True, parents=True) - Path(output).mkdir(exist_ok=True, parents=True) - results_path = os.path.join(str(output), "results") - Path(results_path).mkdir(exist_ok=True, parents=True) - if not log_path: - log_path = os.path.join(str(output), "log/log.txt") - Path(log_path).parent.mkdir(exist_ok=True, parents=True) - - run_opt.update({ - "output": str(Path(output).absolute()), - "results_path": str(Path(results_path).absolute()), - "log_path": str(Path(log_path).absolute()) - }) - - set_log_handles(log_level, Path(log_path) if log_path else None) - - if jdata.get("common_options", None): - # in this case jdata must have common options - str_dtype = jdata["common_options"]["dtype"] - # jdata["common_options"]["dtype"] = dtype_dict[jdata["common_options"]["dtype"]] - - if run_sk: - apihost = NNSKHost(checkpoint=model_ckpt, config=jdata) - apihost.register_plugin(InitSKModel()) - apihost.build() - apiHrk = NN2HRK(apihost=apihost, mode='nnsk') - else: - apihost = DPTBHost(dptbmodel=model_ckpt,use_correction=use_correction) - apihost.register_plugin(InitDPTBModel()) - apihost.build() - apiHrk = NN2HRK(apihost=apihost, mode='dptb') +# task_options = j_must_have(jdata, "task_options") +# task = task_options["task"] +# use_gui = jdata.get("use_gui", False) +# task_options.update({"use_gui": use_gui}) + +# if run_opt['structure'] is None: +# log.warning(msg="Warning! structure is not set in run option, read from input config file.") +# structure = j_must_have(jdata, "structure") +# run_opt.update({"structure":structure}) +# else: +# structure = run_opt["structure"] + +# print(run_opt["structure"]) + +# if jdata.get("common_options", None): +# # in this case jdata must have common options +# str_dtype = jdata["common_options"]["dtype"] +# # jdata["common_options"]["dtype"] = dtype_dict[jdata["common_options"]["dtype"]] + +# model = build_model(run_options=run_opt, model_options=jdata["model_options"], common_options=jdata["common_options"]) - - # one can just add his own function to calculate properties by add a task, and its code to calculate. - - if task=='band': - # TODO: add argcheck for bandstructure, with different options. see, kline_mode: ase, vasp, abacus, etc. - bcal = bandcalc(apiHrk, run_opt, task_options) - bcal.get_bands() - bcal.band_plot() - log.info(msg='band calculation successfully completed.') - - if task=='dos': - bcal = doscalc(apiHrk, run_opt, task_options) - bcal.get_dos() - bcal.dos_plot() - log.info(msg='dos calculation successfully completed.') - - if task=='pdos': - bcal = pdoscalc(apiHrk, run_opt, task_options) - bcal.get_pdos() - bcal.pdos_plot() - log.info(msg='pdos calculation successfully completed.') +# # one can just add his own function to calculate properties by add a task, and its code to calculate. +# if task=='band': +# # TODO: add argcheck for bandstructure, with different options. see, kline_mode: ase, vasp, abacus, etc. 
+# bcal = bandcalc(model, structure, task_options) +# bcal.get_bands() +# bcal.band_plot() +# log.info(msg='band calculation successfully completed.') + +# if task=='dos': +# bcal = doscalc(model, structure, task_options) +# bcal.get_dos() +# bcal.dos_plot() +# log.info(msg='dos calculation successfully completed.') + +# if task=='pdos': +# bcal = pdoscalc(model, structure, task_options) +# bcal.get_pdos() +# bcal.pdos_plot() +# log.info(msg='pdos calculation successfully completed.') - if task=='FS2D': - fs2dcal = fs2dcalc(apiHrk, run_opt, task_options) - fs2dcal.get_fs() - fs2dcal.fs2d_plot() - log.info(msg='2dFS calculation successfully completed.') +# if task=='FS2D': +# fs2dcal = fs2dcalc(model, structure, task_options) +# fs2dcal.get_fs() +# fs2dcal.fs2d_plot() +# log.info(msg='2dFS calculation successfully completed.') - if task == 'FS3D': - fs3dcal = fs3dcalc(apiHrk, run_opt, task_options) - fs3dcal.get_fs() - fs3dcal.fs_plot() - log.info(msg='3dFS calculation successfully completed.') +# if task == 'FS3D': +# fs3dcal = fs3dcalc(model, structure, task_options) +# fs3dcal.get_fs() +# fs3dcal.fs_plot() +# log.info(msg='3dFS calculation successfully completed.') - if task == 'ifermi': - if not(ifermi_installed and pymatgen_installed): - log.error(msg="ifermi and pymatgen are required to perform ifermi calculation !") - raise RuntimeError - - ifermi = ifermiapi(apiHrk, run_opt, task_options) - bs = ifermi.get_band_structure() - fs = ifermi.get_fs(bs) - ifermi.fs_plot(fs) - log.info(msg='Ifermi calculation successfully completed.') - if task == 'write_sk': - if not run_sk: - raise RuntimeError("write_sk can only perform on nnsk model !") - write_sk = WriteNNSKParam(apiHrk, run_opt, task_options) - write_sk.write() - log.info(msg='write_sk calculation successfully completed.') - - if task == 'negf': - negf = NEGF(apiHrk, run_opt, task_options) - negf.compute() - log.info(msg='NEGF calculation successfully completed.') - - if task == 'tbtrans_negf': - if not(sisl_installed): - log.error(msg="sisl is required to perform tbtrans calculation !") - raise RuntimeError - - tbtrans_init = TBTransInputSet(apiHrk, run_opt, task_options) - tbtrans_init.hamil_get_write(write_nc=True) - log.info(msg='TBtrans input files are successfully generated.') - - if output: - with open(os.path.join(output, "run_config.json"), "w") as fp: - if jdata.get("common_options", None): - jdata["common_options"]["dtype"] = str_dtype - json.dump(jdata, fp, indent=4) +# if task == 'ifermi': +# if not(ifermi_installed and pymatgen_installed): +# log.error(msg="ifermi and pymatgen are required to perform ifermi calculation !") +# raise RuntimeError + +# ifermi = ifermiapi(model, structure, task_options) +# bs = ifermi.get_band_structure() +# fs = ifermi.get_fs(bs) +# ifermi.fs_plot(fs) +# log.info(msg='Ifermi calculation successfully completed.') +# if task == 'write_sk': +# if not jdata["model_options"].keys()[0] == "nnsk" or len(jdata["model_options"].keys()) > 1: +# raise RuntimeError("write_sk can only perform on nnsk model !") +# write_sk = WriteNNSKParam(model, structure, task_options) +# write_sk.write() +# log.info(msg='write_sk calculation successfully completed.') + +# if task == 'negf': +# negf = NEGF(model, structure, task_options) +# negf.compute() +# log.info(msg='NEGF calculation successfully completed.') + +# if task == 'tbtrans_negf': +# if not(sisl_installed): +# log.error(msg="sisl is required to perform tbtrans calculation !") +# raise RuntimeError + +# tbtrans_init = TBTransInputSet(apiHrk, run_opt, 
task_options) +# tbtrans_init.hamil_get_write(write_nc=True) +# log.info(msg='TBtrans input files are successfully generated.') + +# if output: +# with open(os.path.join(output, "run_config.json"), "w") as fp: +# json.dump(jdata, fp, indent=4) diff --git a/dptb/entrypoints/test.py b/dptb/entrypoints/test.py index 7380ef8c..9874f7b3 100644 --- a/dptb/entrypoints/test.py +++ b/dptb/entrypoints/test.py @@ -1,21 +1,16 @@ import heapq import logging import torch -import random import json import os import time -import numpy as np from pathlib import Path -from dptb.nnops.tester_dptb import DPTBTester -from dptb.nnops.tester_nnsk import NNSKTester -from typing import Dict, List, Optional, Any +from dptb.nn.build import build_model +from dptb.data.build import build_dataset +from typing import Optional from dptb.utils.loggers import set_log_handles from dptb.utils.tools import j_loader, setup_seed -from dptb.utils.constants import dtype_dict -from dptb.plugins.init_nnsk import InitSKModel -from dptb.plugins.init_dptb import InitDPTBModel -from dptb.plugins.init_data import InitTestData +from dptb.nnops.tester import Tester from dptb.utils.argcheck import normalize_test from dptb.plugins.monitor import TestLossMonitor from dptb.plugins.train_logger import Logger @@ -30,7 +25,6 @@ def _test( output: str, log_level: int, log_path: Optional[str], - test_sk: bool, use_correction: Optional[str], **kwargs ): @@ -39,98 +33,11 @@ def _test( "init_model": init_model, "log_path": log_path, "log_level": log_level, - "test_sk": test_sk, "use_correction": use_correction, "freeze":True, "train_soc":False } - if all((use_correction, test_sk)): - log.error(msg="--use-correction and --train_sk should not be set at the same time") - raise RuntimeError - - # setup INPUT path - if test_sk: - if init_model: - skconfig_path = os.path.join(str(Path(init_model).parent.absolute()), "config_nnsktb.json") - mode = "init_model" - else: - log.error("ValueError: Missing init_model file path.") - raise ValueError - jdata = j_loader(INPUT) - jdata = normalize_test(jdata) - - if all((jdata["init_model"]["path"], run_opt["init_model"])): - raise RuntimeError( - "init-model in config and command line is in conflict, turn off one of then to avoid this error !" - ) - else: - if jdata["init_model"]["path"] is not None: - assert mode == "from_scratch" - run_opt["init_model"] = jdata["init_model"] - mode = "init_model" - if isinstance(run_opt["init_model"]["path"], str): - skconfig_path = os.path.join(str(Path(run_opt["init_model"]["path"]).parent.absolute()), "config_nnsktb.json") - else: # list - skconfig_path = [os.path.join(str(Path(path).parent.absolute()), "config_nnsktb.json") for path in run_opt["init_model"]["path"]] - else: - if run_opt["init_model"] is not None: - assert mode == "init_model" - path = run_opt["init_model"] - run_opt["init_model"] = jdata["init_model"] - run_opt["init_model"]["path"] = path - else: - if init_model: - dptbconfig_path = os.path.join(str(Path(init_model).parent.absolute()), "config_dptbtb.json") - mode = "init_model" - else: - log.error("ValueError: Missing init_model file path.") - raise ValueError - - if use_correction: - skconfig_path = os.path.join(str(Path(use_correction).parent.absolute()), "config_nnsktb.json") - else: - skconfig_path = None - - jdata = j_loader(INPUT) - jdata = normalize_test(jdata) - - if all((jdata["init_model"]["path"], run_opt["init_model"])): - raise RuntimeError( - "init-model in config and command line is in conflict, turn off one of then to avoid this error !" 
- ) - - if jdata["init_model"]["path"] is not None: - assert mode == "from_scratch" - log.info(msg="Init model is read from config rile.") - run_opt["init_model"] = jdata["init_model"] - mode = "init_model" - if isinstance(run_opt["init_model"]["path"], str): - dptbconfig_path = os.path.join(str(Path(run_opt["init_model"]["path"]).parent.absolute()), "config_dptb.json") - else: # list - raise RuntimeError( - "loading lists of checkpoints is only supported in init_nnsk!" - ) - elif run_opt["init_model"] is not None: - assert mode == "init_model" - path = run_opt["init_model"] - run_opt["init_model"] = jdata["init_model"] - run_opt["init_model"]["path"] = path - - if mode == "init_model": - if isinstance(run_opt["init_model"]["path"], list): - if len(run_opt["init_model"]["path"])==0: - log.error(msg="Error, no checkpoint supplied!") - raise RuntimeError - elif len(run_opt["init_model"]["path"])>1: - log.error(msg="Error! list mode init_model in config only support single file in DPTB!") - raise RuntimeError - - if mode == "init_model": - if isinstance(run_opt["init_model"]["path"], list): - if len(run_opt["init_model"]["path"]) == 1: - run_opt["init_model"]["path"] = run_opt["init_model"]["path"][0] - # setup output path if output: Path(output).parent.mkdir(exist_ok=True, parents=True) @@ -146,42 +53,38 @@ def _test( "results_path": str(Path(results_path).absolute()), "log_path": str(Path(log_path).absolute()) }) - run_opt.update({"mode": mode}) - if test_sk: - run_opt.update({ - "skconfig_path": skconfig_path, - }) - else: - if use_correction: - run_opt.update({ - "skconfig_path": skconfig_path - }) - run_opt.update({ - "dptbconfig_path": dptbconfig_path - }) - - set_log_handles(log_level, Path(log_path) if log_path else None) + + jdata = j_loader(INPUT) + jdata = normalize_test(jdata) + # setup seed + setup_seed(seed=jdata["common_options"]["seed"]) - # setup_seed(seed=jdata["train_options"]["seed"]) + f = torch.load(init_model) + # update basis + basis = f["config"]["common_options"]["basis"] + for asym, orb in jdata["common_options"]["basis"].items(): + assert asym in basis.keys(), f"Atom {asym} not found in model's basis" + assert orb == basis[asym], f"Orbital {orb} of Atom {asym} not consistent with the model's basis" + jdata["common_options"]["basis"] = basis # use the old basis, because it will be used to build the orbital mapper for dataset - # with open(os.path.join(output, "test_config.json"), "w") as fp: - # json.dump(jdata, fp, indent=4) - + set_log_handles(log_level, Path(log_path) if log_path else None) - str_dtype = jdata["common_options"]["dtype"] - jdata["common_options"]["dtype"] = dtype_dict[jdata["common_options"]["dtype"]] + f = torch.load(run_opt["init_model"]) + jdata["model_options"] = f["config"]["model_options"] + del f + test_datasets = build_dataset(set_options=jdata["data_options"]["test"], common_options=jdata["common_options"]) + model = build_model(run_options=run_opt, model_options=jdata["model_options"], common_options=jdata["common_options"]) + model.eval() + tester = Tester( + test_options=jdata["test_options"], + common_options=jdata["common_options"], + model = model, + test_datasets=test_datasets, + ) - if test_sk: - tester = NNSKTester(run_opt, jdata) - tester.register_plugin(InitSKModel()) - else: - tester = DPTBTester(run_opt, jdata) - tester.register_plugin(InitDPTBModel()) - # register the plugin in tester, to tract training info - tester.register_plugin(InitTestData()) tester.register_plugin(TestLossMonitor()) 
tester.register_plugin(Logger(["test_loss"], interval=[(1, 'iteration'), (1, 'epoch')])) @@ -194,20 +97,12 @@ def _test( if output: # output training configurations: with open(os.path.join(output, "test_config.json"), "w") as fp: - jdata["common_options"]["dtype"] = str_dtype json.dump(jdata, fp, indent=4) - #tester.register_plugin(Saver( - #interval=[(jdata["train_options"].get("save_freq"), 'epoch'), (1, 'iteration')] if jdata["train_options"].get( - # "save_freq") else None)) - # interval=[(jdata["train_options"].get("save_freq"), 'iteration'), (1, 'epoch')] if jdata["train_options"].get( - # "save_freq") else None)) - # add a plugin to save the training parameters of the model, with model_output as given path - start_time = time.time() - tester.run(epochs=1) + tester.run() end_time = time.time() log.info("finished testing") - log.info(f"wall time: {(end_time - start_time):.3f} s") + log.info(f"wall time: {(end_time - start_time):.3f} s") \ No newline at end of file diff --git a/dptb/entrypoints/train.py b/dptb/entrypoints/train.py index b1be08a3..d58b6710 100644 --- a/dptb/entrypoints/train.py +++ b/dptb/entrypoints/train.py @@ -1,14 +1,12 @@ -from dptb.nnops.train_dptb import DPTBTrainer -from dptb.nnops.train_nnsk import NNSKTrainer +from dptb.nnops.trainer import Trainer +from dptb.nn.build import build_model +from dptb.data.build import build_dataset from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer -from dptb.plugins.init_nnsk import InitSKModel -from dptb.plugins.init_dptb import InitDPTBModel -from dptb.plugins.init_data import InitData from dptb.plugins.train_logger import Logger from dptb.utils.argcheck import normalize from dptb.plugins.plugins import Saver from typing import Dict, List, Optional, Any -from dptb.utils.tools import j_loader, setup_seed +from dptb.utils.tools import j_loader, setup_seed, j_must_have from dptb.utils.constants import dtype_dict from dptb.utils.loggers import set_log_handles import heapq @@ -30,30 +28,24 @@ def train( INPUT: str, init_model: Optional[str], restart: Optional[str], - freeze:bool, train_soc:bool, output: str, log_level: int, log_path: Optional[str], - train_sk: bool, - use_correction: Optional[str], **kwargs ): run_opt = { "init_model": init_model, "restart": restart, - "freeze": freeze, "train_soc": train_soc, "log_path": log_path, - "log_level": log_level, - "train_sk": train_sk, - "use_correction": use_correction + "log_level": log_level } + assert train_soc is False, "train_soc is not supported yet" + ''' -1- set up input and output directories - noticed that, the checkpoint of sktb and dptb should be in different directory, and in train_dptb, - there should be a workflow to load correction model from nnsktb. -2- parse configuration file and start training output directories has following structure: @@ -65,132 +57,18 @@ def train( ... 
- log/ - log.log - - config_nnsktb.json - - config_dptb.json + - config.json ''' # init all paths # if init_model, restart or init_frez, findout the input configure file - - if all((use_correction, train_sk)): - raise RuntimeError( - "--use-correction and --train_sk should not be set at the same time" - ) # setup INPUT path - if train_sk: - if init_model: - skconfig_path = os.path.join(str(Path(init_model).parent.absolute()), "config_nnsktb.json") - mode = "init_model" - elif restart: - skconfig_path = os.path.join(str(Path(restart).parent.absolute()), "config_nnsktb.json") - mode = "restart" - elif INPUT is not None: - log.info(msg="Haven't assign a initializing mode, training from scratch as default.") - mode = "from_scratch" - skconfig_path = INPUT - else: - log.error("ValueError: Missing Input configuration file path.") - raise ValueError - - # switch the init model mode from command line to config file - jdata = j_loader(INPUT) - jdata = normalize(jdata) - - # check if init_model in commandline and input json are in conflict. - - if all((jdata["init_model"]["path"], run_opt["init_model"])) or \ - all((jdata["init_model"]["path"], run_opt["restart"])): - raise RuntimeError( - "init-model in config and command line is in conflict, turn off one of then to avoid this error !" - ) - - if jdata["init_model"]["path"] is not None: - assert mode == "from_scratch" - run_opt["init_model"] = jdata["init_model"] - mode = "init_model" - if isinstance(run_opt["init_model"]["path"], str): - skconfig_path = os.path.join(str(Path(run_opt["init_model"]["path"]).parent.absolute()), "config_nnsktb.json") - else: # list - skconfig_path = [os.path.join(str(Path(path).parent.absolute()), "config_nnsktb.json") for path in run_opt["init_model"]["path"]] - elif run_opt["init_model"] is not None: - # format run_opt's init model to the format of jdata - assert mode == "init_model" - path = run_opt["init_model"] - run_opt["init_model"] = jdata["init_model"] - run_opt["init_model"]["path"] = path - - # handling exceptions when init_model path in config file is [] and [single file] - if mode == "init_model": - if isinstance(run_opt["init_model"]["path"], list): - if len(run_opt["init_model"]["path"])==0: - raise RuntimeError("Error! list mode init_model in config file cannot be empty!") - - else: - if init_model: - dptbconfig_path = os.path.join(str(Path(init_model).parent.absolute()), "config_dptbtb.json") - mode = "init_model" - elif restart: - dptbconfig_path = os.path.join(str(Path(restart).parent.absolute()), "config_dptbtb.json") - mode = "restart" - elif INPUT is not None: - log.info(msg="Haven't assign a initializing mode, training from scratch as default.") - dptbconfig_path = INPUT - mode = "from_scratch" - else: - log.error("ValueError: Missing Input configuration file path.") - raise ValueError - - if use_correction: - skconfig_path = os.path.join(str(Path(use_correction).parent.absolute()), "config_nnsktb.json") - # skcheckpoint_path = str(Path(str(input(f"Enter skcheckpoint_path (default ./checkpoint/best_nnsk.pth): \n"))).absolute()) - else: - skconfig_path = None - - # parse INPUT file - jdata = j_loader(INPUT) - jdata = normalize(jdata) - - if all((jdata["init_model"]["path"], run_opt["init_model"])) or \ - all((jdata["init_model"]["path"], run_opt["restart"])): - raise RuntimeError( - "init-model in config and command line is in conflict, turn off one of then to avoid this error !" 
- ) - - if jdata["init_model"]["path"] is not None: - assert mode == "from_scratch" - log.info(msg="Init model is read from config rile.") - run_opt["init_model"] = jdata["init_model"] - mode = "init_model" - if isinstance(run_opt["init_model"]["path"], str): - dptbconfig_path = os.path.join(str(Path(run_opt["init_model"]["path"]).parent.absolute()), "config_dptb.json") - else: # list - raise RuntimeError( - "loading lists of checkpoints is only supported in init_nnsk!" - ) - elif run_opt["init_model"] is not None: - assert mode == "init_model" - path = run_opt["init_model"] - run_opt["init_model"] = jdata["init_model"] - run_opt["init_model"]["path"] = path - - if mode == "init_model": - if isinstance(run_opt["init_model"]["path"], list): - if len(run_opt["init_model"]["path"])==0: - log.error(msg="Error, no checkpoint supplied!") - raise RuntimeError - elif len(run_opt["init_model"]["path"])>1: - log.error(msg="Error! list mode init_model in config only support single file in DPTB!") - raise RuntimeError if all((run_opt["init_model"], restart)): raise RuntimeError( "--init-model and --restart should not be set at the same time" ) - if mode == "init_model": - if isinstance(run_opt["init_model"]["path"], list): - if len(run_opt["init_model"]["path"]) == 1: - run_opt["init_model"]["path"] = run_opt["init_model"]["path"][0] # setup output path if output: Path(output).parent.mkdir(exist_ok=True, parents=True) @@ -209,71 +87,137 @@ def train( "log_path": str(Path(log_path).absolute()) }) - run_opt.update({"mode": mode}) - if train_sk: - run_opt.update({ - "skconfig_path": skconfig_path, - }) - else: - if use_correction: - run_opt.update({ - "skconfig_path": skconfig_path - }) - run_opt.update({ - "dptbconfig_path": dptbconfig_path - }) - set_log_handles(log_level, Path(log_path) if log_path else None) # parse the config. 
Since if use init, config file may not equals to current + jdata = j_loader(INPUT) + jdata = normalize(jdata) + # update basis if init_model or restart + # update jdata + # this is not necessary, because if we init model from checkpoint, the build_model will load the model_options from checkpoints if not provided + # since here we want to output jdata as a config file to inform the user what model options are used, we need to update the jdata + torch.set_default_dtype(getattr(torch, jdata["common_options"]["dtype"])) + + if restart or init_model: + f = restart if restart else init_model + f = torch.load(f) + + if jdata.get("model_options", None) is None: + jdata["model_options"] = f["config"]["model_options"] + + # update basis + basis = f["config"]["common_options"]["basis"] + # nnsk + if len(f["config"]["model_options"])==1 and f["config"]["model_options"].get("nnsk") != None: + for asym, orb in jdata["common_options"]["basis"].items(): + assert asym in basis.keys(), f"Atom {asym} not found in model's basis" + if orb != basis[asym]: + log.info(f"Initializing Orbital {orb} of Atom {asym} from {basis[asym]}") + # we have the orbitals in jdata basis correct, now we need to make sure all atom in basis are also contained in jdata basis + for asym, orb in basis.items(): + if asym not in jdata["common_options"]["basis"].keys(): + jdata["common_options"]["basis"][asym] = orb # add the atomtype in the checkpoint but not in the jdata basis, because it will be used to build the orbital mapper for dataset + else: # not nnsk + for asym, orb in jdata["common_options"]["basis"].items(): + assert asym in basis.keys(), f"Atom {asym} not found in model's basis" + assert orb == basis[asym], f"Orbital {orb} of Atom {asym} not consistent with the model's basis, which is only allowed in nnsk training" + + jdata["common_options"]["basis"] = basis + + # update model options and train_options + if restart: + # + if jdata.get("train_options", None) is not None: + for obj in Trainer.object_keys: + if jdata["train_options"].get(obj) != f["config"]["train_options"].get(obj): + log.warning(f"{obj} in config file is not consistent with the checkpoint, using the one in checkpoint") + jdata["train_options"][obj] = f["config"]["train_options"][obj] + else: + jdata["train_options"] = f["config"]["train_options"] + + if jdata.get("model_options", None) is None or jdata["model_options"] != f["config"]["model_options"]: + log.warning("model_options in config file is not consistent with the checkpoint, using the one in checkpoint") + jdata["model_options"] = f["config"]["model_options"] # restart does not allow to change model options + else: + # init model mode, allow model_options change + if jdata.get("train_options", None) is None: + jdata["train_options"] = f["config"]["train_options"] + if jdata.get("model_options") is None: + jdata["model_options"] = f["config"]["model_options"] + del f + else: + j_must_have(jdata, "model_options") + j_must_have(jdata, "train_options") + + # setup seed - setup_seed(seed=jdata["train_options"]["seed"]) + setup_seed(seed=jdata["common_options"]["seed"]) # with open(os.path.join(output, "train_config.json"), "w") as fp: # json.dump(jdata, fp, indent=4) - - str_dtype = jdata["common_options"]["dtype"] - jdata["common_options"]["dtype"] = dtype_dict[jdata["common_options"]["dtype"]] - if train_sk: - trainer = NNSKTrainer(run_opt, jdata) - trainer.register_plugin(InitSKModel()) + + # build dataset + train_datasets = build_dataset(set_options=jdata["data_options"]["train"], 
common_options=jdata["common_options"]) + if jdata["data_options"].get("validation"): + validation_datasets = build_dataset(set_options=jdata["data_options"]["validation"], common_options=jdata["common_options"]) else: - trainer = DPTBTrainer(run_opt, jdata) - trainer.register_plugin(InitDPTBModel()) - - + validation_datasets = None + if jdata["data_options"].get("reference"): + reference_datasets = build_dataset(set_options=jdata["data_options"]["reference"], common_options=jdata["common_options"]) + else: + reference_datasets = None + + if restart: + trainer = Trainer.restart( + train_options=jdata["train_options"], + common_options=jdata["common_options"], + checkpoint=restart, + train_datasets=train_datasets, + reference_datasets=reference_datasets, + validation_datasets=validation_datasets, + ) + else: + # include the init model and from scratch + # build model will handle the init model cases where the model options provided is not equals to the ones in checkpoint. + model = build_model(run_options=run_opt, model_options=jdata["model_options"], common_options=jdata["common_options"], statistics=train_datasets.E3statistics()) + trainer = Trainer( + train_options=jdata["train_options"], + common_options=jdata["common_options"], + model = model, + train_datasets=train_datasets, + validation_datasets=validation_datasets, + reference_datasets=reference_datasets, + ) # register the plugin in trainer, to tract training info - trainer.register_plugin(InitData()) - trainer.register_plugin(Validationer()) + log_field = ["train_loss", "lr"] + if validation_datasets: + trainer.register_plugin(Validationer()) + log_field.append("validation_loss") trainer.register_plugin(TrainLossMonitor()) trainer.register_plugin(LearningRateMonitor()) - trainer.register_plugin(Logger(["train_loss", "validation_loss", "lr"], + trainer.register_plugin(Logger(log_field, interval=[(jdata["train_options"]["display_freq"], 'iteration'), (1, 'epoch')])) for q in trainer.plugin_queues.values(): heapq.heapify(q) - - trainer.build() - if output: # output training configurations: with open(os.path.join(output, "train_config.json"), "w") as fp: - jdata["common_options"]["dtype"] = str_dtype json.dump(jdata, fp, indent=4) trainer.register_plugin(Saver( #interval=[(jdata["train_options"].get("save_freq"), 'epoch'), (1, 'iteration')] if jdata["train_options"].get( # "save_freq") else None)) interval=[(jdata["train_options"].get("save_freq"), 'iteration'), (1, 'epoch')] if jdata["train_options"].get( - "save_freq") else None)) + "save_freq") else None), checkpoint_path=checkpoint_path) # add a plugin to save the training parameters of the model, with model_output as given path start_time = time.time() - trainer.run(trainer.num_epoch) + trainer.run(trainer.train_options["num_epoch"]) end_time = time.time() log.info("finished training") diff --git a/dptb/hamiltonian/hamil_eig_sk.py b/dptb/hamiltonian/hamil_eig_sk.py deleted file mode 100644 index d73e7435..00000000 --- a/dptb/hamiltonian/hamil_eig_sk.py +++ /dev/null @@ -1,278 +0,0 @@ -import torch -import torch as th -import numpy as np -import logging -import re -from dptb.hamiltonian.transform_sk import RotationSK -from dptb.utils.constants import anglrMId - -''' Over use of different index system cause the symbols and type and index kind of object need to be recalculated in different -Class, this makes entanglement of classes difficult. 
Need to design an consistent index system to resolve.''' - -log = logging.getLogger(__name__) - -class HamilEig(RotationSK): - """ This module is to build the Hamiltonian from the SK-type bond integral. - """ - def __init__(self, dtype='tensor') -> None: - super().__init__(rot_type=dtype) - self.dtype = dtype - self.use_orthogonal_basis = False - self.hamil_blocks = None - self.overlap_blocks = None - - def update_hs_list(self, struct, hoppings, onsiteEs, overlaps=None, onsiteSs=None, **options): - '''It updates the bond structure, bond type, bond type id, bond hopping, bond onsite, hopping, onsite - energy, overlap, and onsite spin - - Parameters - ---------- - hoppings - a list bond integral for hoppings. - onsiteEs - a list of onsite energy for each atom and each orbital. - overlaps - a list bond integral for overlaps. - onsiteSs - a list of onsite overlaps for each atom and each orbital. - ''' - self.__struct__ = struct - self.hoppings = hoppings - self.onsiteEs = onsiteEs - self.use_orthogonal_basis = False - if overlaps is None: - self.use_orthogonal_basis = True - else: - self.overlaps = overlaps - self.onsiteSs = onsiteSs - self.use_orthogonal_basis = False - - self.num_orbs_per_atom = [] - for itype in self.__struct__.proj_atom_symbols: - norbs = self.__struct__.proj_atomtype_norbs[itype] - self.num_orbs_per_atom.append(norbs) - - def get_hs_blocks(self, bonds_onsite = None, bonds_hoppings=None): - """using the SK type bond integral to build the hamiltonian matrix and overlap matrix in the real space. - - The hamiltonian and overlap matrix block are stored in the order of bond list. for ecah bond ij, with lattice - vecto R, the matrix stored in [norbsi, norbsj]. norsbi and norbsj are the total number of orbtals on i and j sites. - e.g. for C-atom with both s and p orbital on each site. norbi is 4. - """ - if bonds_onsite is None: - bonds_onsite = self.__struct__.__bonds_onsite__ - assert len(bonds_onsite) == len(self.__struct__.__bonds_onsite__) - if bonds_hoppings is None: - bonds_hoppings = self.__struct__.__bonds__ - assert len(bonds_hoppings) == len(self.__struct__.__bonds__) - - hamil_blocks = [] - if not self.use_orthogonal_basis: - overlap_blocks = [] - for ib in range(len(bonds_onsite)): - ibond = bonds_onsite[ib].astype(int) - iatype = self.__struct__.proj_atom_symbols[ibond[1]] - jatype = self.__struct__.proj_atom_symbols[ibond[3]] - assert iatype == jatype, "i type should equal j type." - - if self.dtype == 'tensor': - sub_hamil_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]]) - if not self.use_orthogonal_basis: - sub_over_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]]) - else: - sub_hamil_block = np.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]]) - if not self.use_orthogonal_basis: - sub_over_block = np.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]]) - - # ToDo: adding onsite correction - ist = 0 - # replace sub_hamil_block from now the block diagonal formula to corrected ones. - for ish in self.__struct__.proj_atom_anglr_m[iatype]: # ['s','p',..] - ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish)) - shidi = anglrMId[ishsymbol] # 0,1,2,... - norbi = 2*shidi + 1 - - indx = self.__struct__.onsite_index_map[iatype][ish] # change onsite index map from {N:{s:}} to {N:{ss:, sp:}} - # this already satisfy for uniform onsite or splited onsite energy. - # e.g. 
for p orbital, index may be 1, or 1,2,3, stands for uniform or splited energy for px py pz. - # and then self.onsiteEs[ib][indx] can be scalar or torch.Size([1]) or torch.Size([3]). - # both of them can be transfer into a [3x3] diagonal matrix in this code. - if self.dtype == 'tensor': - sub_hamil_block[ist:ist+norbi, ist:ist+norbi] = th.eye(norbi) * self.onsiteEs[ib][indx] - if not self.use_orthogonal_basis: - sub_over_block[ist:ist+norbi, ist:ist+norbi] = th.eye(norbi) * self.onsiteSs[ib][indx] - else: - sub_hamil_block[ist:ist+norbi, ist:ist+norbi] = np.eye(norbi) * self.onsiteEs[ib][indx] - if not self.use_orthogonal_basis: - sub_over_block[ist:ist+norbi, ist:ist+norbi] = np.eye(norbi) * self.onsiteSs[ib][indx] - ist = ist +norbi - - hamil_blocks.append(sub_hamil_block) - if not self.use_orthogonal_basis: - overlap_blocks.append(sub_over_block) - - for ib in range(len(bonds_hoppings)): - - ibond = bonds_hoppings[ib,0:7].astype(int) - #direction_vec = (self.__struct__.projected_struct.positions[ibond[3]] - # - self.__struct__.projected_struct.positions[ibond[1]] - # + np.dot(ibond[4:], self.__struct__.projected_struct.cell)) - #dist = np.linalg.norm(direction_vec) - #direction_vec = direction_vec/dist - direction_vec = bonds_hoppings[ib,8:11].astype(np.float32) - iatype = self.__struct__.proj_atom_symbols[ibond[1]] - jatype = self.__struct__.proj_atom_symbols[ibond[3]] - - if self.dtype == 'tensor': - sub_hamil_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]]) - if not self.use_orthogonal_basis: - sub_over_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]]) - else: - sub_hamil_block = np.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]]) - if not self.use_orthogonal_basis: - sub_over_block = np.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]]) - - bondatomtype = iatype + '-' + jatype - - ist = 0 - for ish in self.__struct__.proj_atom_anglr_m[iatype]: - ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish)) - shidi = anglrMId[ishsymbol] - norbi = 2*shidi+1 - - jst = 0 - for jsh in self.__struct__.proj_atom_anglr_m[jatype]: - jshsymbol = ''.join(re.findall(r'[A-Za-z]',jsh)) - shidj = anglrMId[jshsymbol] - norbj = 2 * shidj + 1 - - idx = self.__struct__.bond_index_map[bondatomtype][ish+'-'+jsh] - if shidi < shidj: - tmpH = self.rot_HS(Htype=ishsymbol+jshsymbol, Hvalue=self.hoppings[ib][idx], Angvec=direction_vec) - # Hamilblock[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpH,dim0=0,dim1=1) - if self.dtype == 'tensor': - sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * th.transpose(tmpH,dim0=0,dim1=1) - else: - sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * np.transpose(tmpH,(1,0)) - if not self.use_orthogonal_basis: - tmpS = self.rot_HS(Htype=ishsymbol+jshsymbol, Hvalue=self.overlaps[ib][idx], Angvec=direction_vec) - # Soverblock[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpS,dim0=0,dim1=1) - if self.dtype == 'tensor': - sub_over_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * th.transpose(tmpS,dim0=0,dim1=1) - else: - sub_over_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * np.transpose(tmpS,(1,0)) - else: - tmpH = self.rot_HS(Htype=jshsymbol+ishsymbol, Hvalue=self.hoppings[ib][idx], Angvec=direction_vec) - sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = tmpH - if not self.use_orthogonal_basis: - 
tmpS = self.rot_HS(Htype=jshsymbol+ishsymbol, Hvalue = self.overlaps[ib][idx], Angvec = direction_vec) - sub_over_block[ist:ist+norbi, jst:jst+norbj] = tmpS - - jst = jst + norbj - ist = ist + norbi - hamil_blocks.append(sub_hamil_block) - if not self.use_orthogonal_basis: - overlap_blocks.append(sub_over_block) - self.all_bonds = np.concatenate([bonds_onsite[:,0:7],bonds_hoppings[:,0:7]],axis=0) - self.all_bonds = self.all_bonds.astype(int) - self.hamil_blocks = hamil_blocks - if not self.use_orthogonal_basis: - self.overlap_blocks = overlap_blocks - - - def hs_block_R2k(self, kpoints, HorS='H', time_symm=True, dtype='tensor'): - '''The function takes in a list of Hamiltonian matrices for each bond, and a list of k-points, and - returns a list of Hamiltonian matrices for each k-point - - Parameters - ---------- - HorS - string, 'H' or 'S' to indicate for Hk or Sk calculation. - kpoints - the k-points in the path. - time_symm, optional - if True, the Hamiltonian is time-reversal symmetric, defaults to True (optional) - dtype, optional - 'tensor' or 'numpy', defaults to tensor (optional) - - Returns - ------- - A list of Hamiltonian or Overlap matrices for each k-point. - ''' - numOrbs = np.array(self.num_orbs_per_atom) - totalOrbs = np.sum(numOrbs) - if HorS == 'H': - hijAll = self.hamil_blocks - elif HorS == 'S': - hijAll = self.overlap_blocks - else: - print("HorS should be 'H' or 'S' !") - - if dtype == 'tensor': - Hk = th.zeros([len(kpoints), totalOrbs, totalOrbs], dtype = th.complex64) - else: - Hk = np.zeros([len(kpoints), totalOrbs, totalOrbs], dtype = np.complex64) - - for ik in range(len(kpoints)): - k = kpoints[ik] - if dtype == 'tensor': - hk = th.zeros([totalOrbs,totalOrbs],dtype = th.complex64) - else: - hk = np.zeros([totalOrbs,totalOrbs],dtype = np.complex64) - for ib in range(len(self.all_bonds)): - Rlatt = self.all_bonds[ib,4:7].astype(int) - i = self.all_bonds[ib,1].astype(int) - j = self.all_bonds[ib,3].astype(int) - ist = int(np.sum(numOrbs[0:i])) - ied = int(np.sum(numOrbs[0:i+1])) - jst = int(np.sum(numOrbs[0:j])) - jed = int(np.sum(numOrbs[0:j+1])) - if ib < len(numOrbs): - """ - len(numOrbs)= numatoms. the first numatoms are onsite energies. - if turn on timeSymm when generating the bond list . only i>= or <= j are included. - if turn off timeSymm when generating the bond list . all the i j are included. - for case 1, H = H+H^\dagger to get the full matrix, the the onsite one is doubled. - for case 2. no need to do H = H+H^dagger. since the matrix is already full. - """ - if time_symm: - hk[ist:ied,jst:jed] += 0.5 * hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt)) - else: - hk[ist:ied,jst:jed] += hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt)) - else: - hk[ist:ied,jst:jed] += hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt)) - if time_symm: - hk = hk + hk.T.conj() - Hk[ik] = hk - return Hk - - def Eigenvalues(self, kpoints, time_symm=True,dtype='tensor'): - """ using the tight-binding H and S matrix calculate eigenvalues at kpoints. - - Args: - kpoints: the k-kpoints used to calculate the eigenvalues. 
- Note: must have the BondHBlock and BondSBlock - """ - hkmat = self.hs_block_R2k(kpoints=kpoints, HorS='H', time_symm=time_symm, dtype=dtype) - if not self.use_orthogonal_basis: - skmat = self.hs_block_R2k(kpoints=kpoints, HorS='S', time_symm=time_symm, dtype=dtype) - else: - skmat = torch.eye(hkmat.shape[1], dtype=torch.complex64).unsqueeze(0).repeat(hkmat.shape[0], 1, 1) - - if self.dtype == 'tensor': - chklowt = th.linalg.cholesky(skmat) - chklowtinv = th.linalg.inv(chklowt) - Heff = (chklowtinv @ hkmat @ th.transpose(chklowtinv,dim0=1,dim1=2).conj()) - # the factor 13.605662285137 * 2 from Hartree to eV. - # eigks = th.linalg.eigvalsh(Heff) * 13.605662285137 * 2 - eigks, Q = th.linalg.eigh(Heff) - eigks = eigks * 13.605662285137 * 2 - Qres = Q.detach() - else: - chklowt = np.linalg.cholesky(skmat) - chklowtinv = np.linalg.inv(chklowt) - Heff = (chklowtinv @ hkmat @ np.transpose(chklowtinv,(0,2,1)).conj()) - eigks = np.linalg.eigvalsh(Heff) * 13.605662285137 * 2 - Qres = 0 - - return eigks, Qres \ No newline at end of file diff --git a/dptb/hamiltonian/hamil_eig_sk_crt_soc.py b/dptb/hamiltonian/hamil_eig_sk_crt_soc.py deleted file mode 100644 index 19493c18..00000000 --- a/dptb/hamiltonian/hamil_eig_sk_crt_soc.py +++ /dev/null @@ -1,395 +0,0 @@ -import torch -import torch as th -import numpy as np -import logging -import re -from dptb.hamiltonian.transform_sk import RotationSK -from dptb.nnsktb.formula import SKFormula -from dptb.utils.constants import anglrMId -from dptb.hamiltonian.soc import creat_basis_lm, get_soc_matrix_cubic_basis - -''' Over use of different index system cause the symbols and type and index kind of object need to be recalculated in different -Class, this makes entanglement of classes difficult. Need to design an consistent index system to resolve.''' - -log = logging.getLogger(__name__) - -class HamilEig(RotationSK): - """ This module is to build the Hamiltonian from the SK-type bond integral. - """ - def __init__(self, dtype=torch.float32, device='cpu') -> None: - super().__init__(rot_type=dtype, device=device) - self.dtype = dtype - if self.dtype is th.float32: - self.cdtype = th.complex64 - elif self.dtype is th.float64: - self.cdtype = th.complex128 - self.use_orthogonal_basis = False - self.hamil_blocks = None - self.overlap_blocks = None - self.device = device - - def update_hs_list(self, struct, hoppings, onsiteEs, onsiteVs=None, overlaps=None, onsiteSs=None, soc_lambdas=None, **options): - '''It updates the bond structure, bond type, bond type id, bond hopping, bond onsite, hopping, onsite - energy, overlap, and onsite spin - - Parameters - ---------- - hoppings - a list bond integral for hoppings. - onsiteEs - a list of onsite energy for each atom and each orbital. - overlaps - a list bond integral for overlaps. - onsiteSs - a list of onsite overlaps for each atom and each orbital. 
- ''' - self.__struct__ = struct - self.hoppings = hoppings - self.onsiteEs = onsiteEs - self.onsiteVs = onsiteVs - self.soc_lambdas = soc_lambdas - self.use_orthogonal_basis = False - if overlaps is None: - self.use_orthogonal_basis = True - else: - self.overlaps = overlaps - self.onsiteSs = onsiteSs - self.use_orthogonal_basis = False - - if soc_lambdas is None: - self.soc = False - else: - self.soc = True - - self.num_orbs_per_atom = [] - for itype in self.__struct__.proj_atom_symbols: - norbs = self.__struct__.proj_atomtype_norbs[itype] - self.num_orbs_per_atom.append(norbs) - - def get_soc_block(self, bonds_onsite = None): - numOrbs = np.array(self.num_orbs_per_atom) - totalOrbs = np.sum(numOrbs) - if bonds_onsite is None: - _, bonds_onsite = self.__struct__.get_bond() - - soc_upup = torch.zeros_like((totalOrbs, totalOrbs), device=self.device, dtype=self.cdtype) - soc_updown = torch.zeros_like((totalOrbs, totalOrbs), device=self.device, dtype=self.cdtype) - - # compute soc mat for each atom: - soc_atom_upup = self.__struct__.get("soc_atom_diag", {}) - soc_atom_updown = self.__struct__.get("soc_atom_up", {}) - if not soc_atom_upup or not soc_atom_updown: - for iatype in self.__struct__.proj_atomtype: - total_num_orbs_iatom= self.__struct__.proj_atomtype_norbs[iatype] - tmp_upup = torch.zeros([total_num_orbs_iatom, total_num_orbs_iatom], dtype=self.cdtype, device=self.device) - tmp_updown = torch.zeros([total_num_orbs_iatom, total_num_orbs_iatom], dtype=self.cdtype, device=self.device) - - ist = 0 - for ish in self.__struct__.proj_atom_anglr_m[iatype]: - ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish)) - shidi = anglrMId[ishsymbol] # 0,1,2,... - norbi = 2*shidi + 1 - - soc_orb = get_soc_matrix_cubic_basis(orbital=ishsymbol, device=self.device, dtype=self.dtype) - if len(soc_orb) != 2*norbi: - log.error(msg='The dimension of the soc_orb is not correct!') - tmp_upup[ist:ist+norbi, ist:ist+norbi] = soc_orb[:norbi,:norbi] - tmp_updown[ist:ist+norbi, ist:ist+norbi] = soc_orb[:norbi, norbi:] - ist = ist + norbi - - soc_atom_upup.update({iatype:tmp_upup}) - soc_atom_updown.update({iatype:tmp_updown}) - self.__struct__.soc_atom_upup = soc_atom_upup - self.__struct__.soc_atom_updown = soc_atom_updown - - for ib in range(len(bonds_onsite)): - ibond = bonds_onsite[ib].astype(int) - iatom = ibond[1] - ist = int(np.sum(numOrbs[0:iatom])) - ied = int(np.sum(numOrbs[0:iatom+1])) - iatype = self.__struct__.proj_atom_symbols[iatom] - - # get lambdas - ist = 0 - lambdas = torch.zeros((ied-ist,), device=self.device, dtype=self.dtype) - for ish in self.__struct__.proj_atom_anglr_m[iatype]: - indx = self.__struct__.onsite_index_map[iatype][ish] - ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish)) - shidi = anglrMId[ishsymbol] # 0,1,2,... 
- norbi = 2*shidi + 1 - lambdas[ist:ist+norbi] = self.soc_lambdas[ib][indx] - ist = ist + norbi - - soc_upup[ist:ied,ist:ied] = soc_atom_upup[iatype] * torch.diag(lambdas) - soc_updown[ist:ied, ist:ied] = soc_atom_updown[iatype] * torch.diag(lambdas) - - soc_upup.contiguous() - soc_updown.contiguous() - - return soc_upup, soc_updown - - def get_hs_onsite(self, bonds_onsite = None, onsite_envs=None): - if bonds_onsite is None: - _, bonds_onsite = self.__struct__.get_bond() - onsiteH_blocks = [] - if not self.use_orthogonal_basis: - onsiteS_blocks = [] - else: - onsiteS_blocks = None - - iatom_to_onsite_index = {} - for ib in range(len(bonds_onsite)): - ibond = bonds_onsite[ib].astype(int) - iatom = ibond[1] - iatom_to_onsite_index.update({iatom:ib}) - jatom = ibond[3] - iatype = self.__struct__.proj_atom_symbols[iatom] - jatype = self.__struct__.proj_atom_symbols[jatom] - assert iatype == jatype, "i type should equal j type." - - sub_hamil_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]], dtype=self.dtype, device=self.device) - if not self.use_orthogonal_basis: - sub_over_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]], dtype=self.dtype, device=self.device) - - ist = 0 - for ish in self.__struct__.proj_atom_anglr_m[iatype]: # ['s','p',..] - ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish)) - shidi = anglrMId[ishsymbol] # 0,1,2,... - norbi = 2*shidi + 1 - - indx = self.__struct__.onsite_index_map[iatype][ish] # change onsite index map from {N:{s:}} to {N:{ss:, sp:}} - sub_hamil_block[ist:ist+norbi, ist:ist+norbi] = th.eye(norbi, dtype=self.dtype, device=self.device) * self.onsiteEs[ib][indx] - if not self.use_orthogonal_basis: - sub_over_block[ist:ist+norbi, ist:ist+norbi] = th.eye(norbi, dtype=self.dtype, device=self.device) * self.onsiteSs[ib][indx] - ist = ist + norbi - - onsiteH_blocks.append(sub_hamil_block) - if not self.use_orthogonal_basis: - onsiteS_blocks.append(sub_over_block) - - # onsite strain - if onsite_envs is not None: - assert self.onsiteVs is not None - for ib, env in enumerate(onsite_envs): - - iatype, iatom, jatype, jatom = self.__struct__.proj_atom_symbols[int(env[1])], env[1], self.__struct__.atom_symbols[int(env[3])], env[3] - direction_vec = env[8:11].astype(np.float32) - - sub_hamil_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[iatype]], dtype=self.dtype, device=self.device) - - envtype = iatype + '-' + jatype - - ist = 0 - for ish in self.__struct__.proj_atom_anglr_m[iatype]: - ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish)) - shidi = anglrMId[ishsymbol] - norbi = 2*shidi+1 - - jst = 0 - for jsh in self.__struct__.proj_atom_anglr_m[iatype]: - jshsymbol = ''.join(re.findall(r'[A-Za-z]',jsh)) - shidj = anglrMId[jshsymbol] - norbj = 2 * shidj + 1 - - idx = self.__struct__.onsite_strain_index_map[envtype][ish+'-'+jsh] - - if shidi < shidj: - - tmpH = self.rot_HS(Htype=ishsymbol+jshsymbol, Hvalue=self.onsiteVs[ib][idx], Angvec=direction_vec) - # Hamilblock[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpH,dim0=0,dim1=1) - sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpH,dim0=0,dim1=1) - else: - tmpH = self.rot_HS(Htype=jshsymbol+ishsymbol, Hvalue=self.onsiteVs[ib][idx], Angvec=direction_vec) - sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = tmpH - - jst = jst + norbj - ist = ist + norbi - onsiteH_blocks[iatom_to_onsite_index[iatom]] += sub_hamil_block - - return onsiteH_blocks, 
onsiteS_blocks, bonds_onsite - - def get_hs_hopping(self, bonds_hoppings = None): - if bonds_hoppings is None: - bonds_hoppings, _ = self.__struct__.get_bond() - - hoppingH_blocks = [] - if not self.use_orthogonal_basis: - hoppingS_blocks = [] - else: - hoppingS_blocks = None - - for ib in range(len(bonds_hoppings)): - - ibond = bonds_hoppings[ib,0:7].astype(int) - #direction_vec = (self.__struct__.projected_struct.positions[ibond[3]] - # - self.__struct__.projected_struct.positions[ibond[1]] - # + np.dot(ibond[4:], self.__struct__.projected_struct.cell)) - #dist = np.linalg.norm(direction_vec) - #direction_vec = direction_vec/dist - direction_vec = bonds_hoppings[ib,8:11].astype(np.float32) - iatype = self.__struct__.proj_atom_symbols[ibond[1]] - jatype = self.__struct__.proj_atom_symbols[ibond[3]] - - sub_hamil_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]], dtype=self.dtype, device=self.device) - if not self.use_orthogonal_basis: - sub_over_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]], dtype=self.dtype, device=self.device) - - bondatomtype = iatype + '-' + jatype - - ist = 0 - for ish in self.__struct__.proj_atom_anglr_m[iatype]: - ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish)) - shidi = anglrMId[ishsymbol] - norbi = 2*shidi+1 - - jst = 0 - for jsh in self.__struct__.proj_atom_anglr_m[jatype]: - jshsymbol = ''.join(re.findall(r'[A-Za-z]',jsh)) - shidj = anglrMId[jshsymbol] - norbj = 2 * shidj + 1 - - idx = self.__struct__.bond_index_map[bondatomtype][ish+'-'+jsh] - if shidi < shidj: - tmpH = self.rot_HS(Htype=ishsymbol+jshsymbol, Hvalue=self.hoppings[ib][idx], Angvec=direction_vec) - # Hamilblock[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpH,dim0=0,dim1=1) - sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * th.transpose(tmpH,dim0=0,dim1=1) - if not self.use_orthogonal_basis: - tmpS = self.rot_HS(Htype=ishsymbol+jshsymbol, Hvalue=self.overlaps[ib][idx], Angvec=direction_vec) - # Soverblock[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpS,dim0=0,dim1=1) - sub_over_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * th.transpose(tmpS,dim0=0,dim1=1) - else: - tmpH = self.rot_HS(Htype=jshsymbol+ishsymbol, Hvalue=self.hoppings[ib][idx], Angvec=direction_vec) - sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = tmpH - if not self.use_orthogonal_basis: - tmpS = self.rot_HS(Htype=jshsymbol+ishsymbol, Hvalue = self.overlaps[ib][idx], Angvec = direction_vec) - sub_over_block[ist:ist+norbi, jst:jst+norbj] = tmpS - - jst = jst + norbj - ist = ist + norbi - - hoppingH_blocks.append(sub_hamil_block) - if not self.use_orthogonal_basis: - hoppingS_blocks.append(sub_over_block) - - return hoppingH_blocks, hoppingS_blocks, bonds_hoppings - - def get_hs_blocks(self, bonds_onsite = None, bonds_hoppings=None, onsite_envs=None): - onsiteH, onsiteS, bonds_onsite = self.get_hs_onsite(bonds_onsite=bonds_onsite, onsite_envs=onsite_envs) - hoppingH, hoppingS, bonds_hoppings = self.get_hs_hopping(bonds_hoppings=bonds_hoppings) - - self.all_bonds = np.concatenate([bonds_onsite[:,0:7],bonds_hoppings[:,0:7]],axis=0) - self.all_bonds = self.all_bonds.astype(int) - onsiteH.extend(hoppingH) - self.hamil_blocks = onsiteH - if not self.use_orthogonal_basis: - onsiteS.extend(hoppingS) - self.overlap_blocks = onsiteS - if self.soc: - self.soc_upup, self.soc_updown = self.get_soc_block(bonds_onsite=bonds_onsite) - - return True - - def hs_block_R2k(self, kpoints, 
HorS='H', time_symm=True): - '''The function takes in a list of Hamiltonian matrices for each bond, and a list of k-points, and - returns a list of Hamiltonian matrices for each k-point - - Parameters - ---------- - HorS - string, 'H' or 'S' to indicate for Hk or Sk calculation. - kpoints - the k-points in the path. - time_symm, optional - if True, the Hamiltonian is time-reversal symmetric, defaults to True (optional) - dtype, optional - 'tensor' or 'numpy', defaults to tensor (optional) - - Returns - ------- - A list of Hamiltonian or Overlap matrices for each k-point. - ''' - - numOrbs = np.array(self.num_orbs_per_atom) - totalOrbs = np.sum(numOrbs) - if HorS == 'H': - hijAll = self.hamil_blocks - elif HorS == 'S': - hijAll = self.overlap_blocks - else: - print("HorS should be 'H' or 'S' !") - - if self.soc: - Hk = th.zeros([len(kpoints), 2*totalOrbs, 2*totalOrbs], dtype = self.cdtype, device=self.device) - else: - Hk = th.zeros([len(kpoints), totalOrbs, totalOrbs], dtype = self.cdtype, device=self.device) - - for ik in range(len(kpoints)): - k = kpoints[ik] - hk = th.zeros([totalOrbs,totalOrbs],dtype = self.cdtype, device=self.device) - for ib in range(len(self.all_bonds)): - Rlatt = self.all_bonds[ib,4:7].astype(int) - i = self.all_bonds[ib,1].astype(int) - j = self.all_bonds[ib,3].astype(int) - ist = int(np.sum(numOrbs[0:i])) - ied = int(np.sum(numOrbs[0:i+1])) - jst = int(np.sum(numOrbs[0:j])) - jed = int(np.sum(numOrbs[0:j+1])) - if ib < len(numOrbs): - """ - len(numOrbs)= numatoms. the first numatoms are onsite energies. - if turn on timeSymm when generating the bond list . only i>= or <= j are included. - if turn off timeSymm when generating the bond list . all the i j are included. - for case 1, H = H+H^\dagger to get the full matrix, the the onsite one is doubled. - for case 2. no need to do H = H+H^dagger. since the matrix is already full. - """ - if time_symm: - hk[ist:ied,jst:jed] += 0.5 * hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt)) - else: - hk[ist:ied,jst:jed] += hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt)) - else: - hk[ist:ied,jst:jed] += hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt)) - if time_symm: - hk = hk + hk.T.conj() - if self.soc: - hk = torch.kron(A=torch.eye(2, device=self.device, dtype=self.dtype), B=hk) - Hk[ik] = hk - - if self.soc: - Hk[:, :totalOrbs, :totalOrbs] += self.soc_upup.unsqueeze(0) - Hk[:, totalOrbs:, totalOrbs:] += self.soc_upup.conj().unsqueeze(0) - Hk[:, :totalOrbs, totalOrbs:] += self.soc_updown.unsqueeze(0) - Hk[:, totalOrbs:, :totalOrbs] += self.soc_updown.conj().unsqueeze(0) - - Hk.contiguous() - - return Hk - - def Eigenvalues(self, kpoints, time_symm=True): - """ using the tight-binding H and S matrix calculate eigenvalues at kpoints. - - Args: - kpoints: the k-kpoints used to calculate the eigenvalues. - Note: must have the BondHBlock and BondSBlock - """ - hkmat = self.hs_block_R2k(kpoints=kpoints, HorS='H', time_symm=time_symm) - if not self.use_orthogonal_basis: - skmat = self.hs_block_R2k(kpoints=kpoints, HorS='S', time_symm=time_symm) - else: - skmat = torch.eye(hkmat.shape[1], dtype=self.cdtype).unsqueeze(0).repeat(hkmat.shape[0], 1, 1) - - chklowt = th.linalg.cholesky(skmat) - chklowtinv = th.linalg.inv(chklowt) - Heff = (chklowtinv @ hkmat @ th.transpose(chklowtinv,dim0=1,dim1=2).conj()) - # the factor 13.605662285137 * 2 from Hartree to eV. 
- # eigks = th.linalg.eigvalsh(Heff) * 13.605662285137 * 2 - eigks, Q = th.linalg.eigh(Heff) - eigks = eigks * 13.605662285137 * 2 - Qres = Q.detach() - # else: - # chklowt = np.linalg.cholesky(skmat) - # chklowtinv = np.linalg.inv(chklowt) - # Heff = (chklowtinv @ hkmat @ np.transpose(chklowtinv,(0,2,1)).conj()) - # eigks = np.linalg.eigvalsh(Heff) * 13.605662285137 * 2 - # Qres = 0 - - return eigks, Qres \ No newline at end of file diff --git a/dptb/nn/__init__.py b/dptb/nn/__init__.py new file mode 100644 index 00000000..4aa0820a --- /dev/null +++ b/dptb/nn/__init__.py @@ -0,0 +1,140 @@ +from .build import build_model +from .deeptb import DPTB +from .nnsk import NNSK + +__all__ = [ + build_model, + DPTB, + NNSK, +] +""" + +nn module is the model class for the graph neural network model, which is the core of the deeptb package. +It provide two interfaces which is used to interact with other module: +1. The build_model method, which is used to construct a model based on the model_options, common_options and run_options. + - the model options decides the structure of the model, such as the number of layers, the activation function, the number of neurons in each layer, etc. + - the common options contains some common parameters, such as the dtype, device, and the basis, which also related to the model + - the run options decide how to initialize the model. Whether it is from scratch, init from checkpoint, freeze or not, or whether to deploy it. + The build model method will return a model class and a config dict. + +2. A config dict of the model, which contains the essential information of the model to be initialized again. + +The build model method should composed of the following steps: +1. process the configs from user input and the config from the checkpoint (if any). +2. construct the model based on the configs. +3. process the config dict for the output dict. + +The deeptb model can be constructed by the following steps (which have been conpacted in deeptb.py): +1. choose the way to construct edge and atom embedding, either a descriptor, GNN or both. + - in: data with env and edge vectors, and atomic numbers + - out: data with edge and atom embedding + - user view: this can be defined as a tag in model_options +2. constructing the prediction layer, which named as sktb layer or e3tb layer, it is either a linear layer or a neural network + - in: data with edge and atom embedding + - out: data with the e3 irreducible matrix element, or the sk parameters +3. constructing hamiltonian model, either a SKTB or E3TB + - in: data with properties/parameters predicted + - out data with SK/E3 hamiltonian + +model_options = { + "embedding": { + "mode":"se2/gnn/se3...", + # mode specific + # se2 + "env_cutoff": 3.5, + "rs": float, + "rc": float, + "n_axis": int, + "radial_embedding": { + "neurons": [int], + "activation": str, + "if_batch_normalized": bool + } + # gnn + # se3 + }, + "prediction": { + "mode": "linear/nn", + # linear + # nn + "neurons": [int], + "activation": str, + "if_batch_normalized": bool, + "hamiltonian" = { + "method": "sktb/e3tb", + "rmax": 3.5, + "precision": float, # use to check if rmax is large enough + "soc": bool, + "overlap": bool, + # sktb + # e3tb + }, + }, + "nnsk":{ + "hopping_function": { + "formula": "varTang96/powerlaw/NRL", + ... 
+ }, + "onsite_function": { + "formula": "strain/uniform/NRL", + # strain + "strain_cutoff": float, + # NRL + "cutoff": float, + "decay_w": float, + "lambda": float + } + } +} +""" + +common_options = { + "basis": { + "B": "2s2p1d", + "N": "2s2p1d", + }, + "device": "cpu", + "dtype": "float32", + "r_max": 2.0, + "er_max": 4.0, + "oer_max": 6.0, +} + + +data_options = { + "train": { + + } +} + + +dptb_model_options = { + "embedding": { + "method": "se2", + "rs": 2.0, + "rc": 7.0, + "n_axis": 10, + "radial_embedding": { + "neurons": [128,128,20], + "activation": "tanh", + "if_batch_normalized": False, + }, + }, + "prediction":{ + "method": "nn", + "neurons": [256,256,256], + "activation": "tanh", + "if_batch_normalized": False, + "quantities": ["hamiltonian"], + "hamiltonian":{ + "method": "e3tb", + "precision": 1e-5, + "overlap": False, + }, + }, + "nnsk": { + "onsite": {"method": "strain", "rs":6.0, "w":0.1}, + "hopping": {"method": "powerlaw", "rs":3.2, "w": 0.15}, + "overlap": False + } +} \ No newline at end of file diff --git a/dptb/nn/base.py b/dptb/nn/base.py new file mode 100644 index 00000000..a11e0172 --- /dev/null +++ b/dptb/nn/base.py @@ -0,0 +1,464 @@ +from torch.nn import Linear +import torch +from dptb.data import AtomicDataDict +from typing import Optional, Any, Union, Callable, OrderedDict, List +from torch import Tensor +from dptb.utils.constants import dtype_dict +from dptb.utils.tools import _get_activation_fn +import torch.nn.functional as F +import torch.nn as nn + +class AtomicLinear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + in_field = AtomicDataDict.NODE_FEATURES_KEY, + out_field = AtomicDataDict.NODE_FEATURES_KEY, + dtype: Union[str, torch.dtype] = torch.float32, + device: Union[str, torch.device] = torch.device("cpu") + ): + super(AtomicLinear, self).__init__() + if isinstance(device, str): + device = torch.device(device) + if isinstance(dtype, str): + dtype = dtype_dict[dtype] + self.linear = Linear(in_features, out_features, dtype=dtype, device=device) + self.in_field = in_field + self.out_field = out_field + + def forward(self, data: AtomicDataDict.Type): + data[self.out_field] = self.linear(data[self.in_field]) + return data + +class Identity(torch.nn.Module): + def __init__( + self, + dtype: Union[str, torch.dtype] = torch.float32, + device: Union[str, torch.device] = torch.device("cpu"), + **kwargs, + ): + super(Identity, self).__init__() + + def forward(self, data: AtomicDataDict) -> AtomicDataDict: + return data + + +class AtomicMLP(torch.nn.Module): + def __init__( + self, + in_features, + hidden_features, + out_features, + in_field = AtomicDataDict.NODE_FEATURES_KEY, + out_field = AtomicDataDict.NODE_FEATURES_KEY, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + if_batch_normalized: bool = False, + device: Union[str, torch.device] = torch.device('cpu'), + dtype: Union[str, torch.dtype] = torch.float32 + ): + super(AtomicMLP, self).__init__() + if isinstance(device, str): + device = torch.device(device) + if isinstance(dtype, str): + dtype = dtype_dict[dtype] + self.in_layer = Linear( + in_features=in_features, + out_features=hidden_features, + device=device, + dtype=dtype) + + self.out_layer = Linear( + in_features=hidden_features, + out_features=out_features, + device=device, + dtype=dtype) + + if if_batch_normalized: + self.bn1 = torch.nn.BatchNorm1d(hidden_features) + self.bn2 = torch.nn.BatchNorm1d(out_features) + self.if_batch_normalized = if_batch_normalized + if isinstance(activation, str): 
+ self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + self.in_field = in_field + self.out_field = out_field + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super(AtomicMLP, self).__setstate__(state) + + def forward(self, data: AtomicDataDict.Type): + x = self.in_layer(data[self.in_field]) + if self.if_batch_normalized: + x = self.bn1(x) + x = self.activation(x) + x = self.out_layer(x) + if self.if_batch_normalized: + x = self.bn2(x) + data[self.out_field] = x + + return data + +class AtomicFFN(torch.nn.Module): + def __init__( + self, + config: List[dict], + in_field: AtomicDataDict.NODE_FEATURES_KEY, + out_field: AtomicDataDict.NODE_FEATURES_KEY, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + if_batch_normalized: bool = False, + device: Union[str, torch.device] = torch.device('cpu'), + dtype: Union[str, torch.dtype] = torch.float32, + **kwargs + ): + super(AtomicFFN, self).__init__() + self.layers = torch.nn.ModuleList([]) + for kk in range(len(config)-1): + if kk == 0: + self.layers.append( + AtomicMLP( + **config[kk], + in_field=in_field, + out_field=out_field, + if_batch_normalized=if_batch_normalized, + activation=activation, + device=device, + dtype=dtype + ) + ) + else: + self.layers.append( + AtomicMLP( + **config[kk], + in_field=out_field, + out_field=out_field, + if_batch_normalized=if_batch_normalized, + activation=activation, + device=device, + dtype=dtype + ) + ) + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + if config[-1].get('hidden_features') is None: + self.out_layer = AtomicLinear(in_features=config[-1]['in_features'], out_features=config[-1]['out_features'], in_field=out_field, out_field=out_field, device=device, dtype=dtype) + else: + self.out_layer = AtomicMLP(**config[-1], in_field=out_field, out_field=out_field, if_batch_normalized=False, activation=activation, device=device, dtype=dtype) + self.out_field = out_field + self.in_field = in_field + # self.out_norm = nn.LayerNorm(config[-1]['out_features'], elementwise_affine=True) + + def forward(self, data: AtomicDataDict.Type): + out_scale = self.out_scale(data[self.in_field]) + out_shift = self.out_shift(data[self.in_field]) + for layer in self.layers: + data = layer(data) + data[self.out_field] = self.activation(data[self.out_field]) + + data = self.out_layer(data) + # data[self.out_field] = self.out_norm(data[self.out_field]) + return data + + +class AtomicResBlock(torch.nn.Module): + def __init__(self, + in_features: int, + hidden_features: int, + out_features: int, + in_field = AtomicDataDict.NODE_FEATURES_KEY, + out_field = AtomicDataDict.NODE_FEATURES_KEY, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + if_batch_normalized: bool=False, + device: Union[str, torch.device] = torch.device('cpu'), + dtype: Union[str, torch.dtype] = torch.float32 + ): + + super(AtomicResBlock, self).__init__() + self.in_field = in_field + self.out_field = out_field + self.layer = AtomicMLP(in_features, hidden_features, out_features, in_field=in_field, out_field=out_field, if_batch_normalized=if_batch_normalized, device=device, dtype=dtype, activation=activation) + self.out_features = out_features + self.in_features = in_features + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def __setstate__(self, state): + if 'activation' not in state: + 
            state['activation'] = F.relu
+        super(AtomicResBlock, self).__setstate__(state)
+
+    def forward(self, data: AtomicDataDict.Type):
+        if self.in_features < self.out_features:
+            res = F.interpolate(data[self.in_field].unsqueeze(1), size=[self.out_features]).squeeze(1)
+        elif self.in_features == self.out_features:
+            res = data[self.in_field]
+        else:
+            res = F.adaptive_avg_pool1d(input=data[self.in_field], output_size=self.out_features)
+
+        data = self.layer(data)
+        data[self.out_field] = data[self.out_field] + res
+
+        data[self.out_field] = self.activation(data[self.out_field])
+
+        return data
+
+# The ResNet class is a neural network model that consists of multiple residual blocks and a final
+# output layer, with options for activation functions and batch normalization.
+
+class AtomicResNet(torch.nn.Module):
+    def __init__(
+        self,
+        config: List[dict],
+        in_field = AtomicDataDict.NODE_FEATURES_KEY,
+        out_field = AtomicDataDict.NODE_FEATURES_KEY,
+        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+        if_batch_normalized: bool = False,
+        device: Union[str, torch.device] = torch.device('cpu'),
+        dtype: Union[str, torch.dtype] = torch.float32,
+        **kwargs,
+    ):
+        """_summary_
+
+        Parameters
+        ----------
+        config : list
+            e.g. config = [
+                {'in_features': 3, 'hidden_features': 4, 'out_features': 8},
+                {'in_features': 8, 'hidden_features': 6, 'out_features': 4}
+            ]
+        activation : _type_
+            _description_
+        if_batch_normalized : bool, optional
+            _description_, by default False
+        device : str, optional
+            _description_, by default 'cpu'
+        dtype : _type_, optional
+            _description_, by default torch.float32
+        """
+        super(AtomicResNet, self).__init__()
+        self.in_field = in_field
+        self.out_field = out_field
+        self.layers = torch.nn.ModuleList([])
+        for kk in range(len(config)-1):
+            # the first layer takes the in_field as key, i.e. data[out_field] = layer(data[in_field]);
+            # the rest of the layers take the out_field as key, i.e. data[out_field] = layer(data[out_field]).
+            # That's why the in_field and out_field are set differently for the 1st layer and the rest of the layers.
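+            # Illustrative sketch only (field names here are hypothetical, not fixed by this diff):
+            # with in_field="node_features" and out_field="node_pred", a 3-entry config unrolls to
+            #   data["node_pred"] = block0(data["node_features"]); data["node_pred"] = act(data["node_pred"])
+            #   data["node_pred"] = block1(data["node_pred"]);     data["node_pred"] = act(data["node_pred"])
+            #   data["node_pred"] = out_layer(data["node_pred"])
+            # i.e. only the first block reads in_field; every later block (and out_layer) reads and writes out_field.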
+ if kk == 0: + self.layers.append( + AtomicResBlock( + **config[kk], + in_field=in_field, + out_field=out_field, + if_batch_normalized=if_batch_normalized, + activation=activation, + device=device, + dtype=dtype + ) + ) + else: + self.layers.append( + AtomicResBlock( + **config[kk], + in_field=out_field, + out_field=out_field, + if_batch_normalized=if_batch_normalized, + activation=activation, + device=device, + dtype=dtype + ) + ) + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + + if config[-1].get('hidden_feature') is None: + self.out_layer = AtomicLinear(in_features=config[-1]['in_features'], out_features=config[-1]['out_features'], in_field=out_field, out_field=out_field, device=device, dtype=dtype) + else: + self.out_layer = AtomicMLP(**config[-1], if_batch_normalized=False, in_field=in_field, out_field=out_field, activation=activation, device=device, dtype=dtype) + # self.out_norm = nn.LayerNorm(config[-1]['out_features'], elementwise_affine=True) + + def forward(self, data: AtomicDataDict.Type): + + for layer in self.layers: + data = layer(data) + data[self.out_field] = self.activation(data[self.out_field]) + data = self.out_layer(data) + # data[self.out_field] = self.out_norm(data[self.out_field]) + return data + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features, + out_features, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + if_batch_normalized=False, + device: Union[str, torch.device]=torch.device('cpu'), + dtype: Union[str, torch.dtype] = torch.float32, + ): + super(MLP, self).__init__() + + if isinstance(device, str): + device = torch.device(device) + if isinstance(dtype, str): + dtype = dtype_dict[dtype] + + self.in_layer = nn.Linear(in_features=in_features, out_features=hidden_features, device=device, dtype=dtype) + self.out_layer = nn.Linear(in_features=hidden_features, out_features=out_features, device=device, dtype=dtype) + + if if_batch_normalized: + self.bn1 = nn.BatchNorm1d(hidden_features) + self.bn2 = nn.BatchNorm1d(out_features) + self.if_batch_normalized = if_batch_normalized + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super(MLP, self).__setstate__(state) + + def forward(self, x): + x = self.in_layer(x) + if self.if_batch_normalized: + x = self.bn1(x) + x = self.activation(x) + x = self.out_layer(x) + if self.if_batch_normalized: + x = self.bn2(x) + + return x + +class FFN(nn.Module): + def __init__( + self, + config, + activation, + if_batch_normalized=False, + device: Union[str, torch.device]=torch.device('cpu'), + dtype: Union[str, torch.dtype] = torch.float32, + **kwargs + ): + super(FFN, self).__init__() + if isinstance(device, str): + device = torch.device(device) + if isinstance(dtype, str): + dtype = dtype_dict[dtype] + + self.layers = nn.ModuleList([]) + for kk in range(len(config)-1): + self.layers.append(MLP(**config[kk], if_batch_normalized=if_batch_normalized, activation=activation, device=device, dtype=dtype)) + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + if config[-1].get('hidden_features') is None: + self.out_layer = nn.Linear(in_features=config[-1]['in_features'], out_features=config[-1]['out_features'], device=device, dtype=dtype) + # 
nn.init.normal_(self.out_layer.weight, mean=0, std=1e-3) + # nn.init.normal_(self.out_layer.bias, mean=0, std=1e-3) + else: + self.out_layer = MLP(**config[-1], if_batch_normalized=False, activation=activation, device=device, dtype=dtype) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + x = self.activation(x) + + return self.out_layer(x) + + +class ResBlock(torch.nn.Module): + def __init__( + self, + in_features, + hidden_features, + out_features, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + if_batch_normalized=False, + device: Union[str, torch.device]=torch.device('cpu'), + dtype: Union[str, torch.dtype] = torch.float32, + ): + super(ResBlock, self).__init__() + if isinstance(device, str): + device = torch.device(device) + if isinstance(dtype, str): + dtype = dtype_dict[dtype] + + self.layer = MLP(in_features, hidden_features, out_features, if_batch_normalized=if_batch_normalized, device=device, dtype=dtype, activation=activation) + self.out_features = out_features + self.in_features = in_features + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def __setstate__(self, state): + pass + # super(ResBlock, self).__setstate__(state) + + def forward(self, x): + out = self.layer(x) + if self.in_features < self.out_features: + out = nn.functional.interpolate(x.unsqueeze(1), size=[self.out_features]).squeeze(1) + out + elif self.in_features == self.out_features: + out = x + out + else: + out = nn.functional.adaptive_avg_pool1d(input=x, output_size=self.out_features) + out + + out = self.activation(out) + + return out + +class ResNet(torch.nn.Module): + def __init__( + self, + config, + activation, + if_batch_normalized=False, + device: Union[str, torch.device]=torch.device('cpu'), + dtype: Union[str, torch.dtype] = torch.float32, + **kwargs + ): + super(ResNet, self).__init__() + if isinstance(device, str): + device = torch.device(device) + if isinstance(dtype, str): + dtype = dtype_dict[dtype] + + self.layers = torch.nn.ModuleList([]) + for kk in range(len(config)-1): + self.layers.append(ResBlock(**config[kk], if_batch_normalized=if_batch_normalized, activation=activation, device=device, dtype=dtype)) + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + + if config[-1].get('hidden_features') is None: + self.out_layer = nn.Linear(in_features=config[-1]['in_features'], out_features=config[-1]['out_features'], device=device, dtype=dtype) + # nn.init.normal_(self.out_layer.weight, mean=0, std=1e-3) + # nn.init.normal_(self.out_layer.bias, mean=0, std=1e-3) + else: + self.out_layer = MLP(**config[-1], if_batch_normalized=False, activation=activation, device=device, dtype=dtype) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + x = self.activation(x) + + return self.out_layer(x) \ No newline at end of file diff --git a/dptb/nn/build.py b/dptb/nn/build.py new file mode 100644 index 00000000..3ab9fa46 --- /dev/null +++ b/dptb/nn/build.py @@ -0,0 +1,101 @@ +from dptb.nn.deeptb import DPTB, MIX +import logging +from dptb.nn.nnsk import NNSK +import torch +from dptb.utils.tools import j_must_have + +log = logging.getLogger(__name__) + +def build_model(run_options, model_options: dict={}, common_options: dict={}, statistics: dict=None): + """ + The build model method should composed of the following steps: + 1. process the configs from user input and the config from the checkpoint (if any). 
+ 2. construct the model based on the configs. + 3. process the config dict for the output dict. + run_opt = { + "init_model": init_model, + "restart": restart, + "freeze": freeze, + "log_path": log_path, + "log_level": log_level, + "use_correction": use_correction + } + """ + # this is the + # process the model_options + assert not all((run_options.get("init_model"), run_options.get("restart"))), "You can only choose one of the init_model and restart options." + if any((run_options.get("init_model"), run_options.get("restart"))): + from_scratch = False + checkpoint = run_options.get("init_model") or run_options.get("restart") + else: + from_scratch = True + if not all((model_options, common_options)): + logging.error("You need to provide model_options and common_options when you are initializing a model from scratch.") + raise ValueError + + # decide whether to initialize a mixed model, or a deeptb model, or a nnsk model + init_deeptb = False + init_nnsk = False + init_mixed = False + + # load the model_options and common_options from checkpoint if not provided + if not from_scratch: + # init model from checkpoint + if len(model_options) == 0: + f = torch.load(checkpoint) + model_options = f["config"]["model_options"] + del f + + if len(common_options) == 0: + f = torch.load(checkpoint) + common_options = f["config"]["common_options"] + del f + + if all((all((model_options.get("embedding"), model_options.get("prediction"))), model_options.get("nnsk"))): + init_mixed = True + elif all((model_options.get("embedding"), model_options.get("prediction"))): + init_deeptb = True + elif model_options.get("nnsk"): + init_nnsk = True + else: + log.error("Model cannot be built without either one of the terms in model_options (embedding+prediction/nnsk).") + raise ValueError + + assert int(init_mixed) + int(init_deeptb) + int(init_nnsk) == 1, "You can only choose one of the mixed, deeptb, and nnsk options." + # check if the model is deeptb or nnsk + + # init deeptb + if from_scratch: + if init_deeptb: + model = DPTB(**model_options, **common_options) + + # do initialization from statistics if DPTB is e3tb and statistics is provided + if model.method == "e3tb" and statistics is not None: + scalar_mask = torch.BoolTensor([ir.dim==1 for ir in model.idp.orbpair_irreps]) + node_shifts = statistics["node"]["scalar_ave"] + node_scales = statistics["node"]["norm_ave"] + node_scales[:,scalar_mask] = statistics["node"]["scalar_std"] + + edge_shifts = statistics["edge"]["scalar_ave"] + edge_scales = statistics["edge"]["norm_ave"] + edge_scales[:,scalar_mask] = statistics["edge"]["scalar_std"] + model.node_prediction_h.set_scale_shift(scales=node_scales, shifts=node_shifts) + model.edge_prediction_h.set_scale_shift(scales=edge_scales, shifts=edge_shifts) + + if init_nnsk: + model = NNSK(**model_options["nnsk"], **common_options) + + if init_mixed: + model = MIX(**model_options, **common_options) + + else: + # load the model from the checkpoint + if init_deeptb: + model = DPTB.from_reference(checkpoint, **model_options, **common_options) + if init_nnsk: + model = NNSK.from_reference(checkpoint, **model_options["nnsk"], **common_options) + if init_mixed: + # mix model can be initilized with a mixed reference model or a nnsk model. 
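+            # Hedged usage sketch (the checkpoint path and option values below are hypothetical, not from this diff):
+            #   run_opt = {"init_model": "ckpt/mix.best.pth", "restart": None}
+            #   model = build_model(run_options=run_opt,
+            #                       model_options=jdata["model_options"],
+            #                       common_options=jdata["common_options"])
+            # When model_options contains "embedding", "prediction" and "nnsk" together, init_mixed is True,
+            # so this branch restores the mixed model from the given checkpoint.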
+ model = MIX.from_reference(checkpoint, **model_options, **common_options) + + return model diff --git a/dptb/nn/cutoff.py b/dptb/nn/cutoff.py new file mode 100644 index 00000000..76b78014 --- /dev/null +++ b/dptb/nn/cutoff.py @@ -0,0 +1,55 @@ +import math +import torch + + +@torch.jit.script +def cosine_cutoff(x: torch.Tensor, r_max: torch.Tensor, r_start_cos_ratio: float = 0.8): + """A piecewise cosine cutoff starting the cosine decay at r_decay_factor*r_max. + + Broadcasts over r_max. + """ + r_max, x = torch.broadcast_tensors(r_max.unsqueeze(-1), x.unsqueeze(0)) + r_decay: torch.Tensor = r_start_cos_ratio * r_max + # for x < r_decay, clamps to 1, for x > r_max, clamps to 0 + x = x.clamp(r_decay, r_max) + return 0.5 * (torch.cos((math.pi / (r_max - r_decay)) * (x - r_decay)) + 1.0) + + +@torch.jit.script +def polynomial_cutoff( + x: torch.Tensor, r_max: torch.Tensor, p: float = 6.0 +) -> torch.Tensor: + """Polynomial cutoff, as proposed in DimeNet: https://arxiv.org/abs/2003.03123 + + + Parameters + ---------- + r_max : tensor + Broadcasts over r_max. + + p : int + Power used in envelope function + """ + assert p >= 2.0 + r_max, x = torch.broadcast_tensors(r_max.unsqueeze(-1), x.unsqueeze(0)) + x = x / r_max + + out = 1.0 + out = out - (((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p)) + out = out + (p * (p + 2.0) * torch.pow(x, p + 1.0)) + out = out - ((p * (p + 1.0) / 2) * torch.pow(x, p + 2.0)) + + return out * (x < 1.0) + +@torch.jit.script +def polynomial_cutoff2( + r: torch.Tensor, rc: torch.Tensor, rs: torch.Tensor, +) -> torch.Tensor: + + r_ = torch.zeros_like(r) + r_[r