diff --git a/dptb/data/AtomicData.py b/dptb/data/AtomicData.py
new file mode 100644
index 00000000..0163a457
--- /dev/null
+++ b/dptb/data/AtomicData.py
@@ -0,0 +1,993 @@
+"""AtomicData: neighbor graphs in (periodic) real space.
+
+Authors: Albert Musaelian
+"""
+
+import warnings
+from copy import deepcopy
+from typing import Union, Tuple, Dict, Optional, List, Set, Sequence
+from collections.abc import Mapping
+import os
+
+import numpy as np
+import ase.neighborlist
+import ase
+from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator
+from ase.calculators.calculator import all_properties as ase_all_properties
+from ase.stress import voigt_6_to_full_3x3_stress, full_3x3_to_voigt_6_stress
+
+import torch
+import e3nn.o3
+
+from . import AtomicDataDict
+from .util import _TORCH_INTEGER_DTYPES
+from dptb.utils.torch_geometric.data import Data
+
+# A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case)
+PBC = Union[bool, Tuple[bool, bool, bool]]
+
+
+_DEFAULT_LONG_FIELDS: Set[str] = {
+ AtomicDataDict.EDGE_INDEX_KEY,
+ AtomicDataDict.ENV_INDEX_KEY, # new
+ AtomicDataDict.ONSITENV_INDEX_KEY, # new
+ AtomicDataDict.ATOMIC_NUMBERS_KEY,
+ AtomicDataDict.ATOM_TYPE_KEY,
+ AtomicDataDict.BATCH_KEY,
+}
+
+_DEFAULT_NODE_FIELDS: Set[str] = {
+ AtomicDataDict.POSITIONS_KEY,
+ AtomicDataDict.NODE_FEATURES_KEY,
+ AtomicDataDict.NODE_ATTRS_KEY,
+ AtomicDataDict.ATOMIC_NUMBERS_KEY,
+ AtomicDataDict.ATOM_TYPE_KEY,
+ AtomicDataDict.FORCE_KEY,
+ AtomicDataDict.PER_ATOM_ENERGY_KEY,
+ AtomicDataDict.NODE_HAMILTONIAN_KEY,
+ AtomicDataDict.NODE_OVERLAP_KEY,
+ AtomicDataDict.BATCH_KEY,
+}
+
+_DEFAULT_EDGE_FIELDS: Set[str] = {
+ AtomicDataDict.EDGE_CELL_SHIFT_KEY,
+ AtomicDataDict.EDGE_VECTORS_KEY,
+ AtomicDataDict.EDGE_LENGTH_KEY,
+ AtomicDataDict.EDGE_ATTRS_KEY,
+ AtomicDataDict.EDGE_EMBEDDING_KEY,
+ AtomicDataDict.EDGE_FEATURES_KEY,
+ AtomicDataDict.EDGE_CUTOFF_KEY,
+ AtomicDataDict.EDGE_ENERGY_KEY,
+ AtomicDataDict.EDGE_OVERLAP_KEY,
+ AtomicDataDict.EDGE_HAMILTONIAN_KEY,
+ AtomicDataDict.EDGE_TYPE_KEY,
+}
+
+_DEFAULT_ENV_FIELDS: Set[str] = {
+ AtomicDataDict.ENV_CELL_SHIFT_KEY,
+ AtomicDataDict.ENV_VECTORS_KEY,
+ AtomicDataDict.ENV_LENGTH_KEY,
+ AtomicDataDict.ENV_ATTRS_KEY,
+ AtomicDataDict.ENV_EMBEDDING_KEY,
+ AtomicDataDict.ENV_FEATURES_KEY,
+ AtomicDataDict.ENV_CUTOFF_KEY,
+}
+
+_DEFAULT_ONSITENV_FIELDS: Set[str] = {
+ AtomicDataDict.ONSITENV_CELL_SHIFT_KEY,
+ AtomicDataDict.ONSITENV_VECTORS_KEY,
+ AtomicDataDict.ONSITENV_LENGTH_KEY,
+ AtomicDataDict.ONSITENV_ATTRS_KEY,
+ AtomicDataDict.ONSITENV_EMBEDDING_KEY,
+ AtomicDataDict.ONSITENV_FEATURES_KEY,
+ AtomicDataDict.ONSITENV_CUTOFF_KEY,
+}
+
+_DEFAULT_GRAPH_FIELDS: Set[str] = {
+ AtomicDataDict.TOTAL_ENERGY_KEY,
+ AtomicDataDict.STRESS_KEY,
+ AtomicDataDict.VIRIAL_KEY,
+ AtomicDataDict.PBC_KEY,
+ AtomicDataDict.CELL_KEY,
+ AtomicDataDict.BATCH_PTR_KEY,
+ AtomicDataDict.KPOINT_KEY, # new
+ AtomicDataDict.HAMILTONIAN_KEY, # new
+ AtomicDataDict.OVERLAP_KEY, # new
+ AtomicDataDict.ENERGY_EIGENVALUE_KEY, # new
+ AtomicDataDict.ENERGY_WINDOWS_KEY, # new,
+ AtomicDataDict.BAND_WINDOW_KEY # new,
+}
+
+_NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS)
+_EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS)
+_ENV_FIELDS: Set[str] = set(_DEFAULT_ENV_FIELDS)
+_ONSITENV_FIELDS: Set[str] = set(_DEFAULT_ONSITENV_FIELDS)
+_GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS)
+_LONG_FIELDS: Set[str] = set(_DEFAULT_LONG_FIELDS)
+
+
+def register_fields(
+ node_fields: Sequence[str] = [],
+ edge_fields: Sequence[str] = [],
+ env_fields: Sequence[str] = [],
+ onsitenv_fields: Sequence[str] = [],
+ graph_fields: Sequence[str] = [],
+ long_fields: Sequence[str] = [],
+) -> None:
+
+ r"""Register fields as being per-atom, per-edge, or per-frame.
+
+ Args:
+ node_permute_fields: fields that are equivariant to node permutations.
+ edge_permute_fields: fields that are equivariant to edge permutations.
+ """
+
+ node_fields: set = set(node_fields)
+ edge_fields: set = set(edge_fields)
+ env_fields: set = set(env_fields)
+ onsitenv_fields: set = set(onsitenv_fields)
+ graph_fields: set = set(graph_fields)
+ long_fields: set = set(long_fields)
+ allfields = node_fields.union(edge_fields, graph_fields, env_fields, onsitenv_fields)
+ assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields)
+ _NODE_FIELDS.update(node_fields)
+ _EDGE_FIELDS.update(edge_fields)
+ _ENV_FIELDS.update(env_fields)
+ _ONSITENV_FIELDS.update(onsitenv_fields)
+ _GRAPH_FIELDS.update(graph_fields)
+ _LONG_FIELDS.update(long_fields)
+ if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < (
+ len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS)
+ ):
+ raise ValueError(
+ "At least one key was registered as more than one of node, edge, or graph!"
+ )
+
+
+def deregister_fields(*fields: Sequence[str]) -> None:
+ r"""Deregister a field registered with ``register_fields``.
+
+ Silently ignores fields that were never registered to begin with.
+
+ Args:
+ *fields: fields to deregister.
+ """
+ for f in fields:
+ assert f not in _DEFAULT_NODE_FIELDS, "Cannot deregister built-in field"
+ assert f not in _DEFAULT_EDGE_FIELDS, "Cannot deregister built-in field"
+ assert f not in _DEFAULT_GRAPH_FIELDS, "Cannot deregister built-in field"
+ assert f not in _DEFAULT_ENV_FIELDS, "Cannot deregister built-in field"
+ assert f not in _DEFAULT_ONSITENV_FIELDS, "Cannot deregister built-in field"
+ _NODE_FIELDS.discard(f)
+ _EDGE_FIELDS.discard(f)
+ _ENV_FIELDS.discard(f)
+ _ONSITENV_FIELDS.discard(f)
+ _GRAPH_FIELDS.discard(f)
+
+
+def _register_field_prefix(prefix: str) -> None:
+ """Re-register all registered fields as the same type, but with `prefix` added on."""
+ assert prefix.endswith("_")
+ register_fields(
+ node_fields=[prefix + e for e in _NODE_FIELDS],
+ edge_fields=[prefix + e for e in _EDGE_FIELDS],
+ env_fields=[prefix + e for e in _ENV_FIELDS],
+ onsitenv_fields=[prefix + e for e in _ONSITENV_FIELDS],
+ graph_fields=[prefix + e for e in _GRAPH_FIELDS],
+ long_fields=[prefix + e for e in _LONG_FIELDS],
+ )
+
+
+def _process_dict(kwargs, ignore_fields=[]):
+ """Convert a dict of data into correct dtypes/shapes according to key"""
+ # Deal with _some_ dtype issues
+ for k, v in kwargs.items():
+ if k in ignore_fields:
+ continue
+
+ if k in _LONG_FIELDS:
+ # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems)
+ # int32 would pass later checks, but is actually disallowed by torch
+ kwargs[k] = torch.as_tensor(v, dtype=torch.long)
+ elif isinstance(v, bool):
+ kwargs[k] = torch.as_tensor(v)
+ elif isinstance(v, np.ndarray):
+ if np.issubdtype(v.dtype, np.floating):
+ kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype())
+ else:
+ kwargs[k] = torch.as_tensor(v)
+ elif isinstance(v, list):
+ ele_dtype = np.array(v).dtype
+ if np.issubdtype(ele_dtype, np.floating):
+ kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype())
+ else:
+ kwargs[k] = torch.as_tensor(v)
+ elif np.issubdtype(type(v), np.floating):
+ # Force scalars to be tensors with a data dimension
+ # This makes them play well with irreps
+ kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype())
+ elif isinstance(v, torch.Tensor) and len(v.shape) == 0:
+ # ^ this tensor is a scalar; we need to give it
+ # a data dimension to play nice with irreps
+ kwargs[k] = v
+
+ if AtomicDataDict.BATCH_KEY in kwargs:
+ num_frames = kwargs[AtomicDataDict.BATCH_KEY].max() + 1
+ else:
+ num_frames = 1
+
+ for k, v in kwargs.items():
+ if k in ignore_fields:
+ continue
+
+ if len(v.shape) == 0:
+ kwargs[k] = v.unsqueeze(-1)
+ v = kwargs[k]
+
+ if k in set.union(_NODE_FIELDS, _EDGE_FIELDS) and len(v.shape) == 1:
+ kwargs[k] = v.unsqueeze(-1)
+ v = kwargs[k]
+
+ if (
+ k in _NODE_FIELDS
+ and AtomicDataDict.POSITIONS_KEY in kwargs
+ and v.shape[0] != kwargs[AtomicDataDict.POSITIONS_KEY].shape[0]
+ ):
+ raise ValueError(
+ f"{k} is a node field but has the wrong dimension {v.shape}"
+ )
+ elif (
+ k in _EDGE_FIELDS
+ and AtomicDataDict.EDGE_INDEX_KEY in kwargs
+ and v.shape[0] != kwargs[AtomicDataDict.EDGE_INDEX_KEY].shape[1]
+ ):
+ raise ValueError(
+ f"{k} is a edge field but has the wrong dimension {v.shape}"
+ )
+ elif (
+ k in _ENV_FIELDS
+ and AtomicDataDict.ENV_INDEX_KEY in kwargs
+ and v.shape[0] != kwargs[AtomicDataDict.ENV_INDEX_KEY].shape[1]
+ ):
+ raise ValueError(
+ f"{k} is a env field but has the wrong dimension {v.shape}"
+ )
+ elif (
+ k in _ONSITENV_FIELDS
+ and AtomicDataDict.ONSITENV_INDEX_KEY in kwargs
+ and v.shape[0] != kwargs[AtomicDataDict.ONSITENV_INDEX_KEY].shape[1]
+ ):
+ raise ValueError(
+ f"{k} is a env field but has the wrong dimension {v.shape}"
+ )
+ elif k in _GRAPH_FIELDS:
+ if num_frames > 1 and v.shape[0] != num_frames:
+ raise ValueError(f"Wrong shape for graph property {k}")
+
+
+class AtomicData(Data):
+ """A neighbor graph for points in (periodic triclinic) real space.
+
+ For typical cases either ``from_points`` or ``from_ase`` should be used to
+ construct a AtomicData; they also standardize and check their input much more
+ thoroughly.
+
+ In general, ``node_features`` are features or input information on the nodes that will be fed through and transformed by the network, while ``node_attrs`` are _encodings_ fixed, inherant attributes of the atoms themselves that remain constant through the network.
+ For example, a one-hot _encoding_ of atomic species is a node attribute, while some observed instantaneous property of that atom (current partial charge, for example), would be a feature.
+
+ In general, ``torch.Tensor`` arguments should be of consistant dtype. Numpy arrays will be converted to ``torch.Tensor``s; those of floating point dtype will be converted to ``torch.get_current_dtype()`` regardless of their original precision. Scalar values (Python scalars or ``torch.Tensor``s of shape ``()``) a resized to tensors of shape ``[1]``. Per-atom scalar values should be given with shape ``[N_at, 1]``.
+
+ ``AtomicData`` should be used for all data creation and manipulation outside of the model; inside of the model ``AtomicDataDict.Type`` is used.
+
+ Args:
+ pos (Tensor [n_nodes, 3]): Positions of the nodes.
+ edge_index (LongTensor [2, n_edges]): ``edge_index[0]`` is the per-edge
+ index of the source node and ``edge_index[1]`` is the target node.
+ edge_cell_shift (Tensor [n_edges, 3], optional): which periodic image
+ of the target point each edge goes to, relative to the source point.
+ cell (Tensor [1, 3, 3], optional): the periodic cell for
+ ``edge_cell_shift`` as the three triclinic cell vectors.
+ node_features (Tensor [n_atom, ...]): the input features of the nodes, optional
+ node_attrs (Tensor [n_atom, ...]): the attributes of the nodes, for instance the atom type, optional
+ batch (Tensor [n_atom]): the graph to which the node belongs, optional
+ atomic_numbers (Tensor [n_atom]): optional.
+ atom_type (Tensor [n_atom]): optional.
+ **kwargs: other data, optional.
+ """
+
+ def __init__(
+ self, irreps: Dict[str, e3nn.o3.Irreps] = {}, _validate: bool = True, **kwargs
+ ):
+
+ # empty init needed by get_example
+ if len(kwargs) == 0 and len(irreps) == 0:
+ super().__init__()
+ return
+
+ # Check the keys
+ if _validate:
+ AtomicDataDict.validate_keys(kwargs)
+ _process_dict(kwargs)
+
+ super().__init__(num_nodes=len(kwargs["pos"]), **kwargs)
+
+ if _validate:
+ # Validate shapes
+ assert self.pos.dim() == 2 and self.pos.shape[1] == 3
+ assert self.edge_index.dim() == 2 and self.edge_index.shape[0] == 2
+ if "edge_cell_shift" in self and self.edge_cell_shift is not None:
+ assert self.edge_cell_shift.shape == (self.num_edges, 3)
+ assert self.edge_cell_shift.dtype == self.pos.dtype
+ # TODO: should we add checks for env too ?
+ if "cell" in self and self.cell is not None:
+ assert (self.cell.shape == (3, 3)) or (
+ self.cell.dim() == 3 and self.cell.shape[1:] == (3, 3)
+ )
+ assert self.cell.dtype == self.pos.dtype
+ if "node_features" in self and self.node_features is not None:
+ assert self.node_features.shape[0] == self.num_nodes
+ assert self.node_features.dtype == self.pos.dtype
+ if "node_attrs" in self and self.node_attrs is not None:
+ assert self.node_attrs.shape[0] == self.num_nodes
+ assert self.node_attrs.dtype == self.pos.dtype
+
+ if (
+ AtomicDataDict.ATOMIC_NUMBERS_KEY in self
+ and self.atomic_numbers is not None
+ ):
+ assert self.atomic_numbers.dtype in _TORCH_INTEGER_DTYPES
+ if "batch" in self and self.batch is not None:
+ assert self.batch.dim() == 2 and self.batch.shape[0] == self.num_nodes
+ # Check that there are the right number of cells
+ if "cell" in self and self.cell is not None:
+ cell = self.cell.view(-1, 3, 3)
+ assert cell.shape[0] == self.batch.max() + 1
+
+ # Validate irreps
+ # __*__ is the only way to hide from torch_geometric
+ self.__irreps__ = AtomicDataDict._fix_irreps_dict(irreps)
+ for field, irreps in self.__irreps__:
+ if irreps is not None:
+ assert self[field].shape[-1] == irreps.dim
+
+ @classmethod
+ def from_points(
+ cls,
+ pos=None,
+ r_max: float = None,
+ self_interaction: bool = False,
+ cell=None,
+ pbc: Optional[PBC] = None,
+ er_max: Optional[float] = None,
+ oer_max: Optional[float] = None,
+ **kwargs,
+ ):
+ """Build neighbor graph from points, optionally with PBC.
+
+ Args:
+ pos (np.ndarray/torch.Tensor shape [N, 3]): node positions. If Tensor, must be on the CPU.
+ r_max (float): neighbor cutoff radius.
+ cell (ase.Cell/ndarray [3,3], optional): periodic cell for the points. Defaults to ``None``.
+ pbc (bool or 3-tuple of bool, optional): whether to apply periodic boundary conditions to all or each of
+ the three cell vector directions. Defaults to ``False``.
+ self_interaction (bool, optional): whether to include self edges for points. Defaults to ``False``. Note
+ that edges between the same atom in different periodic images are still included. (See
+ ``strict_self_interaction`` to control this behaviour.)
+ strict_self_interaction (bool): Whether to include *any* self interaction edges in the graph, even if the
+ two instances of the atom are in different periodic images. Defaults to True, should be True for most
+ applications.
+ **kwargs (optional): other fields to add. Keys listed in ``AtomicDataDict.*_KEY` will be treated specially.
+ """
+ if pos is None or r_max is None:
+ raise ValueError("pos and r_max must be given.")
+
+ if pbc is None:
+ if cell is not None:
+ raise ValueError(
+ "A cell was provided, but pbc weren't. Please explicitly probide PBC."
+ )
+ # there are no PBC if cell and pbc are not provided
+ pbc = False
+
+ if isinstance(pbc, bool):
+ pbc = (pbc,) * 3
+ else:
+ assert len(pbc) == 3
+
+ # TODO: We can only compute the edge vector one times with the largest radial distance among [r_max, er_max, oer_max]
+
+ pos = torch.as_tensor(pos, dtype=torch.get_default_dtype())
+
+ edge_index, edge_cell_shift, cell = neighbor_list_and_relative_vec(
+ pos=pos,
+ r_max=r_max,
+ self_interaction=self_interaction,
+ cell=cell,
+ reduce=False,
+ atomic_numbers=kwargs.get("atomic_numbers", None),
+ pbc=pbc,
+ )
+
+ # Make torch tensors for data:
+ if cell is not None:
+ kwargs[AtomicDataDict.CELL_KEY] = cell.view(3, 3)
+ kwargs[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = edge_cell_shift
+ if pbc is not None:
+ kwargs[AtomicDataDict.PBC_KEY] = torch.as_tensor(
+ pbc, dtype=torch.bool
+ ).view(3)
+
+ # add env index
+ if er_max is not None:
+ env_index, env_cell_shift, _ = neighbor_list_and_relative_vec(
+ pos=pos,
+ r_max=er_max,
+ self_interaction=self_interaction,
+ cell=cell,
+ reduce=False,
+ atomic_numbers=kwargs.get("atomic_numbers", None),
+ pbc=pbc,
+ )
+
+ if cell is not None:
+ kwargs[AtomicDataDict.ENV_CELL_SHIFT_KEY] = env_cell_shift
+ kwargs[AtomicDataDict.ENV_INDEX_KEY] = env_index
+
+ # add onsitenv index
+ if oer_max is not None:
+ onsitenv_index, onsitenv_cell_shift, _ = neighbor_list_and_relative_vec(
+ pos=pos,
+ r_max=oer_max,
+ self_interaction=self_interaction,
+ cell=cell,
+ reduce=False,
+ atomic_numbers=kwargs.get("atomic_numbers", None),
+ pbc=pbc
+ )
+
+ if cell is not None:
+ kwargs[AtomicDataDict.ONSITENV_CELL_SHIFT_KEY] = onsitenv_cell_shift
+ kwargs[AtomicDataDict.ONSITENV_INDEX_KEY] = onsitenv_index
+
+ return cls(edge_index=edge_index, pos=torch.as_tensor(pos), **kwargs)
+
+ @classmethod
+ def from_ase(
+ cls,
+ atoms,
+ r_max,
+ er_max: Optional[float] = None,
+ oer_max: Optional[float] = None,
+ key_mapping: Optional[Dict[str, str]] = {},
+ include_keys: Optional[list] = [],
+ **kwargs,
+ ):
+ """Build a ``AtomicData`` from an ``ase.Atoms`` object.
+
+ Respects ``atoms``'s ``pbc`` and ``cell``.
+
+ First tries to extract energies and forces from a single-point calculator associated with the ``Atoms`` if one is present and has those fields.
+ If either is not found, the method will look for ``energy``/``energies`` and ``force``/``forces`` in ``atoms.arrays``.
+
+ `get_atomic_numbers()` will be stored as the atomic_numbers attribute.
+
+ Args:
+ atoms (ase.Atoms): the input.
+ r_max (float): neighbor cutoff radius.
+ features (torch.Tensor shape [N, M], optional): per-atom M-dimensional feature vectors. If ``None`` (the
+ default), uses a one-hot encoding of the species present in ``atoms``.
+ include_keys (list): list of additional keys to include in AtomicData aside from the ones defined in
+ ase.calculators.calculator.all_properties. Optional
+ key_mapping (dict): rename ase property name to a new string name. Optional
+ **kwargs (optional): other arguments for the ``AtomicData`` constructor.
+
+ Returns:
+ A ``AtomicData``.
+ """
+ # from nequip.ase import NequIPCalculator
+
+ assert "pos" not in kwargs
+
+ default_args = set(
+ [
+ "numbers",
+ "positions",
+ ] # ase internal names for position and atomic_numbers
+ + ["pbc", "cell", "pos", "r_max", "er_max", "oer_max"] # arguments for from_points method
+ + list(kwargs.keys())
+ )
+ # the keys that are duplicated in kwargs are removed from the include_keys
+ include_keys = list(
+ set(include_keys + ase_all_properties + list(key_mapping.keys()))
+ - default_args
+ )
+
+ km = {
+ "forces": AtomicDataDict.FORCE_KEY,
+ "energy": AtomicDataDict.TOTAL_ENERGY_KEY,
+ }
+ km.update(key_mapping)
+ key_mapping = km
+
+ add_fields = {}
+
+ # Get info from atoms.arrays; lowest priority. copy first
+ add_fields = {
+ key_mapping.get(k, k): v
+ for k, v in atoms.arrays.items()
+ if k in include_keys
+ }
+
+ # Get info from atoms.info; second lowest priority.
+ add_fields.update(
+ {
+ key_mapping.get(k, k): v
+ for k, v in atoms.info.items()
+ if k in include_keys
+ }
+ )
+
+ # if atoms.calc is not None:
+
+ # if isinstance(
+ # atoms.calc, (SinglePointCalculator, SinglePointDFTCalculator)
+ # ):
+ # add_fields.update(
+ # {
+ # key_mapping.get(k, k): deepcopy(v)
+ # for k, v in atoms.calc.results.items()
+ # if k in include_keys
+ # }
+ # )
+ # elif isinstance(atoms.calc, NequIPCalculator):
+ # pass # otherwise the calculator breaks
+ # else:
+ # raise NotImplementedError(
+ # f"`from_ase` does not support calculator {atoms.calc}"
+ # )
+
+ add_fields[AtomicDataDict.ATOMIC_NUMBERS_KEY] = atoms.get_atomic_numbers()
+
+ # cell and pbc in kwargs can override the ones stored in atoms
+ cell = kwargs.pop("cell", atoms.get_cell())
+ pbc = kwargs.pop("pbc", atoms.pbc)
+
+ # handle ASE-style 6 element Voigt order stress
+ for key in (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY):
+ if key in add_fields:
+ if add_fields[key].shape == (3, 3):
+ # it's already 3x3, do nothing else
+ pass
+ elif add_fields[key].shape == (6,):
+ # it's Voigt order
+ add_fields[key] = voigt_6_to_full_3x3_stress(add_fields[key])
+ else:
+ raise RuntimeError(f"bad shape for {key}")
+
+ return cls.from_points(
+ pos=atoms.positions,
+ r_max=r_max,
+ er_max=er_max,
+ oer_max=oer_max,
+ cell=cell,
+ pbc=pbc,
+ **kwargs,
+ **add_fields,
+ )
+
+ def to_ase(
+ self,
+ type_mapper=None,
+ extra_fields: List[str] = [],
+ ) -> Union[List[ase.Atoms], ase.Atoms]:
+ """Build a (list of) ``ase.Atoms`` object(s) from an ``AtomicData`` object.
+
+ For each unique batch number provided in ``AtomicDataDict.BATCH_KEY``,
+ an ``ase.Atoms`` object is created. If ``AtomicDataDict.BATCH_KEY`` does not
+ exist in self, a single ``ase.Atoms`` object is created.
+
+ Args:
+ type_mapper: if provided, will be used to map ``ATOM_TYPES`` back into
+ elements, if the configuration of the ``type_mapper`` allows.
+ extra_fields: fields other than those handled explicitly (currently
+ those defining the structure as well as energy, per-atom energy,
+ and forces) to include in the output object. Per-atom (per-node)
+ quantities will be included in ``arrays``; per-graph and per-edge
+ quantities will be included in ``info``.
+
+ Returns:
+ A list of ``ase.Atoms`` objects if ``AtomicDataDict.BATCH_KEY`` is in self
+ and is not None. Otherwise, a single ``ase.Atoms`` object is returned.
+ """
+ positions = self.pos
+ edge_index = self[AtomicDataDict.EDGE_INDEX_KEY]
+ if positions.device != torch.device("cpu"):
+ raise TypeError(
+ "Explicitly move this `AtomicData` to CPU using `.to()` before calling `to_ase()`."
+ )
+ if AtomicDataDict.ATOMIC_NUMBERS_KEY in self:
+ atomic_nums = self.atomic_numbers
+ elif type_mapper is not None and type_mapper.has_chemical_symbols:
+ atomic_nums = type_mapper.untransform(self[AtomicDataDict.ATOM_TYPE_KEY])
+ else:
+ warnings.warn(
+ "AtomicData.to_ase(): self didn't contain atomic numbers... using atom_type as atomic numbers instead, but this means the chemical symbols in ASE (outputs) will be wrong"
+ )
+ atomic_nums = self[AtomicDataDict.ATOM_TYPE_KEY]
+ pbc = getattr(self, AtomicDataDict.PBC_KEY, None)
+ cell = getattr(self, AtomicDataDict.CELL_KEY, None)
+ batch = getattr(self, AtomicDataDict.BATCH_KEY, None)
+ energy = getattr(self, AtomicDataDict.TOTAL_ENERGY_KEY, None)
+ energies = getattr(self, AtomicDataDict.PER_ATOM_ENERGY_KEY, None)
+ force = getattr(self, AtomicDataDict.FORCE_KEY, None)
+ do_calc = any(
+ k in self
+ for k in [
+ AtomicDataDict.TOTAL_ENERGY_KEY,
+ AtomicDataDict.FORCE_KEY,
+ AtomicDataDict.PER_ATOM_ENERGY_KEY,
+ AtomicDataDict.STRESS_KEY,
+ ]
+ )
+
+ # exclude those that are special for ASE and that we process seperately
+ special_handling_keys = [
+ AtomicDataDict.POSITIONS_KEY,
+ AtomicDataDict.CELL_KEY,
+ AtomicDataDict.PBC_KEY,
+ AtomicDataDict.ATOMIC_NUMBERS_KEY,
+ AtomicDataDict.TOTAL_ENERGY_KEY,
+ AtomicDataDict.FORCE_KEY,
+ AtomicDataDict.PER_ATOM_ENERGY_KEY,
+ AtomicDataDict.STRESS_KEY,
+ ]
+ assert (
+ len(set(extra_fields).intersection(special_handling_keys)) == 0
+ ), f"Cannot specify keys handled in special ways ({special_handling_keys}) as `extra_fields` for atoms output--- they are output by default"
+
+ if cell is not None:
+ cell = cell.view(-1, 3, 3)
+ if pbc is not None:
+ pbc = pbc.view(-1, 3)
+
+ if batch is not None:
+ n_batches = batch.max() + 1
+ cell = cell.expand(n_batches, 3, 3) if cell is not None else None
+ pbc = pbc.expand(n_batches, 3) if pbc is not None else None
+ else:
+ n_batches = 1
+
+ batch_atoms = []
+ for batch_idx in range(n_batches):
+ if batch is not None:
+ mask = batch == batch_idx
+ mask = mask.view(-1)
+ # if both ends of the edge are in the batch, the edge is in the batch
+ edge_mask = mask[edge_index[0]] & mask[edge_index[1]]
+ else:
+ mask = slice(None)
+ edge_mask = slice(None)
+
+ mol = ase.Atoms(
+ numbers=atomic_nums[mask].view(-1), # must be flat for ASE
+ positions=positions[mask],
+ cell=cell[batch_idx] if cell is not None else None,
+ pbc=pbc[batch_idx] if pbc is not None else None,
+ )
+
+ if do_calc:
+ fields = {}
+ if energies is not None:
+ fields["energies"] = energies[mask].cpu().numpy()
+ if energy is not None:
+ fields["energy"] = energy[batch_idx].cpu().numpy()
+ if force is not None:
+ fields["forces"] = force[mask].cpu().numpy()
+ if AtomicDataDict.STRESS_KEY in self:
+ fields["stress"] = full_3x3_to_voigt_6_stress(
+ self["stress"].view(-1, 3, 3)[batch_idx].cpu().numpy()
+ )
+ mol.calc = SinglePointCalculator(mol, **fields)
+
+ # add other information
+ for key in extra_fields:
+ if key in _NODE_FIELDS:
+ # mask it
+ mol.arrays[key] = (
+ self[key][mask].cpu().numpy().reshape(mask.sum(), -1)
+ )
+ elif key in _EDGE_FIELDS:
+ mol.info[key] = (
+ self[key][edge_mask].cpu().numpy().reshape(edge_mask.sum(), -1)
+ )
+ elif key == AtomicDataDict.EDGE_INDEX_KEY:
+ mol.info[key] = self[key][:, edge_mask].cpu().numpy()
+ elif key in _GRAPH_FIELDS:
+ mol.info[key] = self[key][batch_idx].cpu().numpy().reshape(-1)
+ else:
+ raise RuntimeError(
+ f"Extra field `{key}` isn't registered as node/edge/graph"
+ )
+
+ batch_atoms.append(mol)
+
+ if batch is not None:
+ return batch_atoms
+ else:
+ assert len(batch_atoms) == 1
+ return batch_atoms[0]
+
+ def get_edge_vectors(data: Data) -> torch.Tensor:
+ data = AtomicDataDict.with_edge_vectors(AtomicData.to_AtomicDataDict(data))
+ return data[AtomicDataDict.EDGE_VECTORS_KEY]
+
+ def get_env_vectors(data: Data) -> torch.Tensor:
+ data = AtomicDataDict.with_env_vectors(AtomicData.to_AtomicDataDict(data))
+ return data[AtomicDataDict.ENV_VECTORS_KEY]
+
+ @staticmethod
+ def to_AtomicDataDict(
+ data: Union[Data, Mapping], exclude_keys=tuple()
+ ) -> AtomicDataDict.Type:
+ if isinstance(data, Data):
+ keys = data.keys
+ elif isinstance(data, Mapping):
+ keys = data.keys()
+ else:
+ raise ValueError(f"Invalid data `{repr(data)}`")
+
+ return {
+ k: data[k]
+ for k in keys
+ if (
+ k not in exclude_keys
+ and data[k] is not None
+ and isinstance(data[k], torch.Tensor)
+ )
+ }
+
+ @classmethod
+ def from_AtomicDataDict(cls, data: AtomicDataDict.Type):
+ # it's an AtomicDataDict, so don't validate-- assume valid:
+ return cls(_validate=False, **data)
+
+ @property
+ def irreps(self):
+ return self.__irreps__
+
+ def __cat_dim__(self, key, value):
+ if key == AtomicDataDict.EDGE_INDEX_KEY or key == AtomicDataDict.ENV_INDEX_KEY or key == AtomicDataDict.ONSITENV_INDEX_KEY:
+ return 1 # always cat in the edge dimension
+ elif key in _GRAPH_FIELDS:
+ # graph-level properties and so need a new batch dimension
+ return None
+ else:
+ return 0 # cat along node/edge dimension
+
+ def without_nodes(self, which_nodes):
+ """Return a copy of ``self`` with ``which_nodes`` removed.
+ The returned object may share references to some underlying data tensors with ``self``.
+ Args:
+ which_nodes (index tensor or boolean mask)
+ Returns:
+ A new data object.
+ """
+ which_nodes = torch.as_tensor(which_nodes)
+ if which_nodes.dtype == torch.bool:
+ mask = ~which_nodes
+ else:
+ mask = torch.ones(self.num_nodes, dtype=torch.bool)
+ mask[which_nodes] = False
+ assert mask.shape == (self.num_nodes,)
+ n_keeping = mask.sum()
+
+ # Only keep edges where both from and to are kept
+ edge_mask = mask[self.edge_index[0]] & mask[self.edge_index[1]]
+ if hasattr(self, AtomicDataDict.ENV_INDEX_KEY):
+ env_mask = mask[self.env_index[0]] & mask[self.env_index[1]]
+ if hasattr(self, AtomicDataDict.ONSITENV_INDEX_KEY):
+ onsitenv_mask = mask[self.onsitenv_index[0]] & mask[self.onsitenv_index[1]]
+ # Create an index mapping:
+ new_index = torch.full((self.num_nodes,), -1, dtype=torch.long)
+ new_index[mask] = torch.arange(n_keeping, dtype=torch.long)
+
+ new_dict = {}
+ for k in self.keys:
+ if k == AtomicDataDict.EDGE_INDEX_KEY:
+ new_dict[AtomicDataDict.EDGE_INDEX_KEY] = new_index[
+ self.edge_index[:, edge_mask]
+ ]
+ elif k == AtomicDataDict.EDGE_CELL_SHIFT_KEY:
+ new_dict[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = self.edge_cell_shift[
+ edge_mask
+ ]
+ elif k == AtomicDataDict.CELL_KEY:
+ new_dict[k] = self[k]
+ elif k == AtomicDataDict.ENV_INDEX_KEY:
+ new_dict[AtomicDataDict.ENV_INDEX_KEY] = new_index[
+ self.env_index[:, env_mask]
+ ]
+ elif k == AtomicDataDict.ENV_CELL_SHIFT_KEY:
+ new_dict[AtomicDataDict.ENV_CELL_SHIFT_KEY] = self.env_cell_shift[
+ env_mask
+ ]
+ elif k == AtomicDataDict.ONSITENV_INDEX_KEY:
+ new_dict[AtomicDataDict.ONSITENV_INDEX_KEY] = new_index[
+ self.onsitenv_index[:, onsitenv_mask]
+ ]
+ elif k == AtomicDataDict.ONSITENV_CELL_SHIFT_KEY:
+ new_dict[AtomicDataDict.ONSITENV_CELL_SHIFT_KEY] = self.onsitenv_cell_shift[
+ onsitenv_mask
+ ]
+ else:
+ if isinstance(self[k], torch.Tensor) and len(self[k]) == self.num_nodes:
+ new_dict[k] = self[k][mask]
+ else:
+ new_dict[k] = self[k]
+
+ new_dict["irreps"] = self.__irreps__
+
+ return type(self)(**new_dict)
+
+
+_ERROR_ON_NO_EDGES: bool = os.environ.get("NEQUIP_ERROR_ON_NO_EDGES", "true").lower()
+assert _ERROR_ON_NO_EDGES in ("true", "false")
+_ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true"
+
+
+def neighbor_list_and_relative_vec(
+ pos,
+ r_max,
+ self_interaction=False,
+ reduce=True,
+ atomic_numbers=None,
+ cell=None,
+ pbc=False,
+):
+ """Create neighbor list and neighbor vectors based on radial cutoff.
+
+ Create neighbor list (``edge_index``) and relative vectors
+ (``edge_attr``) based on radial cutoff.
+
+ Edges are given by the following convention:
+ - ``edge_index[0]`` is the *source* (convolution center).
+ - ``edge_index[1]`` is the *target* (neighbor).
+
+ Thus, ``edge_index`` has the same convention as the relative vectors:
+ :math:`\\vec{r}_{source, target}`
+
+ If the input positions are a tensor with ``requires_grad == True``,
+ the output displacement vectors will be correctly attached to the inputs
+ for autograd.
+
+ All outputs are Tensors on the same device as ``pos``; this allows future
+ optimization of the neighbor list on the GPU.
+
+ Args:
+ pos (shape [N, 3]): Positional coordinate; Tensor or numpy array. If Tensor, must be on CPU.
+ r_max (float): Radial cutoff distance for neighbor finding.
+ cell (numpy shape [3, 3]): Cell for periodic boundary conditions. Ignored if ``pbc == False``.
+ pbc (bool or 3-tuple of bool): Whether the system is periodic in each of the three cell dimensions.
+ self_interaction (bool): Whether or not to include same periodic image self-edges in the neighbor list.
+ strict_self_interaction (bool): Whether to include *any* self interaction edges in the graph, even if the two
+ instances of the atom are in different periodic images. Defaults to True, should be True for most applications.
+
+ Returns:
+ edge_index (torch.tensor shape [2, num_edges]): List of edges.
+ edge_cell_shift (torch.tensor shape [num_edges, 3]): Relative cell shift
+ vectors. Returned only if cell is not None.
+ cell (torch.Tensor [3, 3]): the cell as a tensor on the correct device.
+ Returned only if cell is not None.
+ """
+ if isinstance(pbc, bool):
+ pbc = (pbc,) * 3
+
+ # Either the position or the cell may be on the GPU as tensors
+ if isinstance(pos, torch.Tensor):
+ temp_pos = pos.detach().cpu().numpy()
+ out_device = pos.device
+ out_dtype = pos.dtype
+ else:
+ temp_pos = np.asarray(pos)
+ out_device = torch.device("cpu")
+ out_dtype = torch.get_default_dtype()
+
+ # Right now, GPU tensors require a round trip
+ if out_device.type != "cpu":
+ warnings.warn(
+ "Currently, neighborlists require a round trip to the CPU. Please pass CPU tensors if possible."
+ )
+
+ # Get a cell on the CPU no matter what
+ if isinstance(cell, torch.Tensor):
+ temp_cell = cell.detach().cpu().numpy()
+ cell_tensor = cell.to(device=out_device, dtype=out_dtype)
+ elif cell is not None:
+ temp_cell = np.asarray(cell)
+ cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype)
+ else:
+ # ASE will "complete" this correctly.
+ temp_cell = np.zeros((3, 3), dtype=temp_pos.dtype)
+ cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype)
+
+ # ASE dependent part
+ temp_cell = ase.geometry.complete_cell(temp_cell)
+
+ first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list(
+ "ijS",
+ pbc,
+ temp_cell,
+ temp_pos,
+ cutoff=float(r_max),
+ self_interaction=self_interaction, # we want edges from atom to itself in different periodic images!
+ use_scaled_positions=False,
+ )
+
+ # Eliminate true self-edges that don't cross periodic boundaries
+ # if not self_interaction:
+ # bad_edge = first_idex == second_idex
+ # bad_edge &= np.all(shifts == 0, axis=1)
+ # keep_edge = ~bad_edge
+ # if _ERROR_ON_NO_EDGES and (not np.any(keep_edge)):
+ # raise ValueError(
+ # f"Every single atom has no neighbors within the cutoff r_max={r_max} (after eliminating self edges, no edges remain in this system)"
+ # )
+ # first_idex = first_idex[keep_edge]
+ # second_idex = second_idex[keep_edge]
+ # shifts = shifts[keep_edge]
+
+ """
+ bond list is: i, j, shift; but i j shift and j i -shift are the same bond. so we need to remove the duplicate bonds.s
+ first for i != j; we only keep i < j; then the j i -shift will be removed.
+ then, for i == j; we only keep i i shift and remove i i -shift.
+ """
+ # 1. for i != j, keep i < j
+ assert atomic_numbers is not None
+ atomic_numbers = torch.as_tensor(atomic_numbers, dtype=torch.long)
+ mask = first_idex <= second_idex
+ first_idex = first_idex[mask]
+ second_idex = second_idex[mask]
+ shifts = shifts[mask]
+
+ # 2. for i == j
+
+ mask = torch.ones(len(first_idex), dtype=torch.bool)
+ mask[first_idex == second_idex] = False
+ # get index bool type ~mask for i == j.
+ o_first_idex = first_idex[~mask]
+ o_second_idex = second_idex[~mask]
+ o_shift = shifts[~mask]
+ o_mask = mask[~mask] # this is all False, with length being the number all the bonds with i == j.
+
+
+ # using the dict key to remove the duplicate bonds, because it is O(1) to check if a key is in the dict.
+ rev_dict = {}
+ for i in range(len(o_first_idex)):
+ key = str(o_first_idex[i])+str(o_shift[i])
+ key_rev = str(o_first_idex[i])+str(-o_shift[i])
+ rev_dict[key] = True
+ # key_rev is the reverse key of key, if key_rev is in the dict, then the bond is duplicate.
+ # so, only when key_rev is not in the dict, we keep the bond. that is when rev_dict.get(key_rev, False) is False, we set o_mast = True.
+ if not (rev_dict.get(key_rev, False) and rev_dict.get(key, False)):
+ o_mask[i] = True
+ del rev_dict
+ del o_first_idex
+ del o_second_idex
+ del o_shift
+ mask[~mask] = o_mask
+ del o_mask
+
+ first_idex = torch.LongTensor(first_idex[mask], device=out_device)
+ second_idex = torch.LongTensor(second_idex[mask], device=out_device)
+ shifts = torch.as_tensor(shifts[mask], dtype=out_dtype, device=out_device)
+
+ if not reduce:
+ first_idex, second_idex = torch.cat((first_idex, second_idex), dim=0), torch.cat((second_idex, first_idex), dim=0)
+ shifts = torch.cat((shifts, -shifts), dim=0)
+
+ # Build output:
+ edge_index = torch.vstack(
+ (torch.LongTensor(first_idex), torch.LongTensor(second_idex))
+ )
+
+ return edge_index, shifts, cell_tensor
diff --git a/dptb/data/AtomicDataDict.py b/dptb/data/AtomicDataDict.py
new file mode 100644
index 00000000..6fc45c9b
--- /dev/null
+++ b/dptb/data/AtomicDataDict.py
@@ -0,0 +1,233 @@
+"""nequip.data.jit: TorchScript functions for dealing with AtomicData.
+
+These TorchScript functions operate on ``Dict[str, torch.Tensor]`` representations
+of the ``AtomicData`` class which are produced by ``AtomicData.to_AtomicDataDict()``.
+
+Authors: Albert Musaelian
+"""
+from typing import Dict, Any
+
+import torch
+import torch.jit
+
+from e3nn import o3
+
+# Make the keys available in this module
+from ._keys import * # noqa: F403, F401
+
+# Also import the module to use in TorchScript, this is a hack to avoid bug:
+# https://github.com/pytorch/pytorch/issues/52312
+from . import _keys
+
+# Define a type alias
+Type = Dict[str, torch.Tensor]
+
+
+def validate_keys(keys, graph_required=True):
+ # Validate combinations
+ if graph_required:
+ if not (_keys.POSITIONS_KEY in keys and _keys.EDGE_INDEX_KEY in keys):
+ raise KeyError("At least pos and edge_index must be supplied")
+ if _keys.EDGE_CELL_SHIFT_KEY in keys and "cell" not in keys:
+ raise ValueError("If `edge_cell_shift` given, `cell` must be given.")
+
+
+_SPECIAL_IRREPS = [None]
+
+
+def _fix_irreps_dict(d: Dict[str, Any]):
+ return {k: (i if i in _SPECIAL_IRREPS else o3.Irreps(i)) for k, i in d.items()}
+
+
+def _irreps_compatible(ir1: Dict[str, o3.Irreps], ir2: Dict[str, o3.Irreps]):
+ return all(ir1[k] == ir2[k] for k in ir1 if k in ir2)
+
+
+@torch.jit.script
+def with_edge_vectors(data: Type, with_lengths: bool = True) -> Type:
+ """Compute the edge displacement vectors for a graph.
+
+ If ``data.pos.requires_grad`` and/or ``data.cell.requires_grad``, this
+ method will return edge vectors correctly connected in the autograd graph.
+
+ Returns:
+ Tensor [n_edges, 3] edge displacement vectors
+ """
+ if _keys.EDGE_VECTORS_KEY in data:
+ if with_lengths and _keys.EDGE_LENGTH_KEY not in data:
+ data[_keys.EDGE_LENGTH_KEY] = torch.linalg.norm(
+ data[_keys.EDGE_VECTORS_KEY], dim=-1
+ )
+
+ return data
+ else:
+ # Build it dynamically
+ # Note that this is
+ # (1) backwardable, because everything (pos, cell, shifts)
+ # is Tensors.
+ # (2) works on a Batch constructed from AtomicData
+ pos = data[_keys.POSITIONS_KEY]
+ edge_index = data[_keys.EDGE_INDEX_KEY]
+ edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
+ if _keys.CELL_KEY in data:
+ # ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero.
+ # -1 gives a batch dim no matter what
+ cell = data[_keys.CELL_KEY].view(-1, 3, 3)
+ edge_cell_shift = data[_keys.EDGE_CELL_SHIFT_KEY]
+ if cell.shape[0] > 1:
+ batch = data[_keys.BATCH_KEY]
+ # Cell has a batch dimension
+ # note the ASE cell vectors as rows convention
+ edge_vec = edge_vec + torch.einsum(
+ "ni,nij->nj", edge_cell_shift, cell[batch[edge_index[0]]]
+ )
+ # TODO: is there a more efficient way to do the above without
+ # creating an [n_edge] and [n_edge, 3, 3] tensor?
+ else:
+ # Cell has either no batch dimension, or a useless one,
+ # so we can avoid creating the large intermediate cell tensor.
+ # Note that we do NOT check that the batch array, if it is present,
+ # is trivial — but this does need to be consistent.
+ edge_vec = edge_vec + torch.einsum(
+ "ni,ij->nj",
+ edge_cell_shift,
+ cell.squeeze(0), # remove batch dimension
+ )
+
+ data[_keys.EDGE_VECTORS_KEY] = edge_vec
+ if with_lengths:
+ data[_keys.EDGE_LENGTH_KEY] = torch.linalg.norm(edge_vec, dim=-1)
+ return data
+
+@torch.jit.script
+def with_env_vectors(data: Type, with_lengths: bool = True) -> Type:
+ """Compute the edge displacement vectors for a graph.
+
+ If ``data.pos.requires_grad`` and/or ``data.cell.requires_grad``, this
+ method will return edge vectors correctly connected in the autograd graph.
+
+ Returns:
+ Tensor [n_edges, 3] edge displacement vectors
+ """
+ if _keys.ENV_VECTORS_KEY in data:
+ if with_lengths and _keys.ENV_LENGTH_KEY not in data:
+ data[_keys.ENV_LENGTH_KEY] = torch.linalg.norm(
+ data[_keys.ENV_VECTORS_KEY], dim=-1
+ )
+ return data
+ else:
+ # Build it dynamically
+ # Note that this is
+ # (1) backwardable, because everything (pos, cell, shifts)
+ # is Tensors.
+ # (2) works on a Batch constructed from AtomicData
+ pos = data[_keys.POSITIONS_KEY]
+ env_index = data[_keys.ENV_INDEX_KEY]
+ env_vec = pos[env_index[1]] - pos[env_index[0]]
+ if _keys.CELL_KEY in data:
+ # ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero.
+ # -1 gives a batch dim no matter what
+ cell = data[_keys.CELL_KEY].view(-1, 3, 3)
+ env_cell_shift = data[_keys.ENV_CELL_SHIFT_KEY]
+ if cell.shape[0] > 1:
+ batch = data[_keys.BATCH_KEY]
+ # Cell has a batch dimension
+ # note the ASE cell vectors as rows convention
+ env_vec = env_vec + torch.einsum(
+ "ni,nij->nj", env_cell_shift, cell[batch[env_index[0]]]
+ )
+ # TODO: is there a more efficient way to do the above without
+ # creating an [n_edge] and [n_edge, 3, 3] tensor?
+ else:
+ # Cell has either no batch dimension, or a useless one,
+ # so we can avoid creating the large intermediate cell tensor.
+ # Note that we do NOT check that the batch array, if it is present,
+ # is trivial — but this does need to be consistent.
+ env_vec = env_vec + torch.einsum(
+ "ni,ij->nj",
+ env_cell_shift,
+ cell.squeeze(0), # remove batch dimension
+ )
+ data[_keys.ENV_VECTORS_KEY] = env_vec
+ if with_lengths:
+ data[_keys.ENV_LENGTH_KEY] = torch.linalg.norm(env_vec, dim=-1)
+ return data
+
+@torch.jit.script
+def with_onsitenv_vectors(data: Type, with_lengths: bool = True) -> Type:
+ """Compute the edge displacement vectors for a graph.
+
+ If ``data.pos.requires_grad`` and/or ``data.cell.requires_grad``, this
+ method will return edge vectors correctly connected in the autograd graph.
+
+ Returns:
+ Tensor [n_edges, 3] edge displacement vectors
+ """
+ if _keys.ONSITENV_VECTORS_KEY in data:
+ if with_lengths and _keys.ONSITENV_LENGTH_KEY not in data:
+ data[_keys.ONSITENV_LENGTH_KEY] = torch.linalg.norm(
+ data[_keys.ONSITENV_VECTORS_KEY], dim=-1
+ )
+ return data
+ else:
+ # Build it dynamically
+ # Note that this is
+ # (1) backwardable, because everything (pos, cell, shifts)
+ # is Tensors.
+ # (2) works on a Batch constructed from AtomicData
+ pos = data[_keys.POSITIONS_KEY]
+ env_index = data[_keys.ONSITENV_INDEX_KEY]
+ env_vec = pos[env_index[1]] - pos[env_index[0]]
+ if _keys.CELL_KEY in data:
+ # ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero.
+ # -1 gives a batch dim no matter what
+ cell = data[_keys.CELL_KEY].view(-1, 3, 3)
+ env_cell_shift = data[_keys.ONSITENV_CELL_SHIFT_KEY]
+ if cell.shape[0] > 1:
+ batch = data[_keys.BATCH_KEY]
+ # Cell has a batch dimension
+ # note the ASE cell vectors as rows convention
+ env_vec = env_vec + torch.einsum(
+ "ni,nij->nj", env_cell_shift, cell[batch[env_index[0]]]
+ )
+ # TODO: is there a more efficient way to do the above without
+ # creating an [n_edge] and [n_edge, 3, 3] tensor?
+ else:
+ # Cell has either no batch dimension, or a useless one,
+ # so we can avoid creating the large intermediate cell tensor.
+ # Note that we do NOT check that the batch array, if it is present,
+ # is trivial — but this does need to be consistent.
+ env_vec = env_vec + torch.einsum(
+ "ni,ij->nj",
+ env_cell_shift,
+ cell.squeeze(0), # remove batch dimension
+ )
+ data[_keys.ONSITENV_VECTORS_KEY] = env_vec
+ if with_lengths:
+ data[_keys.ONSITENV_LENGTH_KEY] = torch.linalg.norm(env_vec, dim=-1)
+ return data
+
+
+@torch.jit.script
+def with_batch(data: Type) -> Type:
+ """Get batch Tensor.
+
+ If this AtomicDataPrimitive has no ``batch``, one of all zeros will be
+ allocated and returned.
+ """
+ if _keys.BATCH_KEY in data:
+ return data
+ else:
+ pos = data[_keys.POSITIONS_KEY]
+ batch = torch.zeros(len(pos), dtype=torch.long, device=pos.device)
+ data[_keys.BATCH_KEY] = batch
+ # ugly way to make a tensor of [0, len(pos)], but it avoids transfers or casts
+ data[_keys.BATCH_PTR_KEY] = torch.arange(
+ start=0,
+ end=len(pos) + 1,
+ step=len(pos),
+ dtype=torch.long,
+ device=pos.device,
+ )
+
+ return data
diff --git a/dptb/data/__init__.py b/dptb/data/__init__.py
new file mode 100644
index 00000000..2efca699
--- /dev/null
+++ b/dptb/data/__init__.py
@@ -0,0 +1,49 @@
+from .AtomicData import (
+ AtomicData,
+ PBC,
+ register_fields,
+ deregister_fields,
+ _register_field_prefix,
+ _NODE_FIELDS,
+ _EDGE_FIELDS,
+ _GRAPH_FIELDS,
+ _LONG_FIELDS,
+)
+from .dataset import (
+ AtomicDataset,
+ AtomicInMemoryDataset,
+ NpzDataset,
+ ASEDataset,
+ HDF5Dataset,
+ ABACUSDataset,
+ ABACUSInMemoryDataset,
+ DefaultDataset
+)
+from .dataloader import DataLoader, Collater, PartialSampler
+from .build import dataset_from_config
+from .test_data import EMTTestDataset
+
+__all__ = [
+ AtomicData,
+ PBC,
+ register_fields,
+ deregister_fields,
+ _register_field_prefix,
+ AtomicDataset,
+ AtomicInMemoryDataset,
+ NpzDataset,
+ ASEDataset,
+ HDF5Dataset,
+ ABACUSDataset,
+ ABACUSInMemoryDataset,
+ DefaultDataset,
+ DataLoader,
+ Collater,
+ PartialSampler,
+ dataset_from_config,
+ _NODE_FIELDS,
+ _EDGE_FIELDS,
+ _GRAPH_FIELDS,
+ _LONG_FIELDS,
+ EMTTestDataset,
+]
diff --git a/dptb/data/_keys.py b/dptb/data/_keys.py
new file mode 100644
index 00000000..187d4c79
--- /dev/null
+++ b/dptb/data/_keys.py
@@ -0,0 +1,125 @@
+"""Keys for dictionaries/AtomicData objects.
+
+This is a seperate module to compensate for a TorchScript bug that can only recognize constants when they are accessed as attributes of an imported module.
+"""
+
+import sys
+from typing import List
+
+if sys.version_info[1] >= 8:
+ from typing import Final
+else:
+ from typing_extensions import Final
+
+# == Define allowed keys as constants ==
+# The positions of the atoms in the system
+POSITIONS_KEY: Final[str] = "pos"
+# The [2, n_edge] index tensor giving center -> neighbor relations
+EDGE_INDEX_KEY: Final[str] = "edge_index"
+# The [2, n_env] index tensor giving center -> neighbor relations
+ENV_INDEX_KEY: Final[str] = "env_index"
+# The [2, n_onsitenv] index tensor giving center -> neighbor relations
+ONSITENV_INDEX_KEY: Final[str] = "onsitenv_index"
+# A [n_edge, 3] tensor of how many periodic cells each env crosses in each cell vector
+ENV_CELL_SHIFT_KEY: Final[str] = "env_cell_shift"
+# A [n_edge, 3] tensor of how many periodic cells each edge crosses in each cell vector
+EDGE_CELL_SHIFT_KEY: Final[str] = "edge_cell_shift"
+# [n_batch, 3, 3] or [3, 3] tensor where rows are the cell vectors
+ONSITENV_CELL_SHIFT_KEY: Final[str] = "onsitenv_cell_shift"
+# [n_batch, 3, 3] or [3, 3] tensor where rows are the cell vectors
+CELL_KEY: Final[str] = "cell"
+# [n_kpoints, 3] or [n_batch, nkpoints, 3] tensor
+KPOINT_KEY = "kpoint"
+
+HAMILTONIAN_KEY = "hamiltonian"
+
+OVERLAP_KEY = "overlap"
+# [n_batch, 3] bool tensor
+PBC_KEY: Final[str] = "pbc"
+# [n_atom, 1] long tensor
+ATOMIC_NUMBERS_KEY: Final[str] = "atomic_numbers"
+# [n_atom, 1] long tensor
+ATOM_TYPE_KEY: Final[str] = "atom_types"
+# [n_batch, n_kpoint, n_orb]
+ENERGY_EIGENVALUE_KEY: Final[str] = "eigenvalue"
+
+# [n_batch, 2]
+ENERGY_WINDOWS_KEY = "ewindow"
+BAND_WINDOW_KEY = "bwindow"
+
+BASIC_STRUCTURE_KEYS: Final[List[str]] = [
+ POSITIONS_KEY,
+ EDGE_INDEX_KEY,
+ EDGE_CELL_SHIFT_KEY,
+ CELL_KEY,
+ PBC_KEY,
+ ATOM_TYPE_KEY,
+ ATOMIC_NUMBERS_KEY,
+]
+
+# A [n_edge, 3] tensor of displacement vectors associated to edges
+EDGE_VECTORS_KEY: Final[str] = "edge_vectors"
+# A [n_edge, 3] tensor of displacement vectors associated to envs
+ENV_VECTORS_KEY: Final[str] = "env_vectors"
+# A [n_edge, 3] tensor of displacement vectors associated to onsitenvs
+ONSITENV_VECTORS_KEY: Final[str] = "onsitenv_vectors"
+# A [n_edge] tensor of the lengths of EDGE_VECTORS
+EDGE_LENGTH_KEY: Final[str] = "edge_lengths"
+# A [n_edge] tensor of the lengths of ENV_VECTORS
+ENV_LENGTH_KEY: Final[str] = "env_lengths"
+# A [n_edge] tensor of the lengths of ONSITENV_VECTORS
+ONSITENV_LENGTH_KEY: Final[str] = "onsitenv_lengths"
+# [n_edge, dim] (possibly equivariant) attributes of each edge
+EDGE_ATTRS_KEY: Final[str] = "edge_attrs"
+ENV_ATTRS_KEY: Final[str] = "env_attrs"
+ONSITENV_ATTRS_KEY: Final[str] = "onsitenv_attrs"
+# [n_edge, dim] invariant embedding of the edges
+EDGE_EMBEDDING_KEY: Final[str] = "edge_embedding"
+ENV_EMBEDDING_KEY: Final[str] = "env_embedding"
+ONSITENV_EMBEDDING_KEY: Final[str] = "onsitenv_embedding"
+EDGE_FEATURES_KEY: Final[str] = "edge_features"
+ENV_FEATURES_KEY: Final[str] = "env_features"
+ONSITENV_FEATURES_KEY: Final[str] = "onsitenv_features"
+# [n_edge, 1] invariant of the radial cutoff envelope for each edge, allows reuse of cutoff envelopes
+EDGE_CUTOFF_KEY: Final[str] = "edge_cutoff"
+# [n_edge, 1] invariant of the radial cutoff envelope for each env edge, allows reuse of cutoff envelopes
+ENV_CUTOFF_KEY: Final[str] = "env_cutoff"
+# [n_edge, 1] invariant of the radial cutoff envelope for each onsitenv edge, allows reuse of cutoff envelopes
+ONSITENV_CUTOFF_KEY: Final[str] = "onsitenv_cutoff"
+# edge energy as in Allegro
+EDGE_ENERGY_KEY: Final[str] = "edge_energy"
+EDGE_OVERLAP_KEY: Final[str] = "edge_overlap"
+NODE_OVERLAP_KEY: Final[str] = "node_overlap"
+EDGE_HAMILTONIAN_KEY: Final[str] = "edge_hamiltonian"
+NODE_HAMILTONIAN_KEY: Final[str] = "node_hamiltonian"
+
+NODE_FEATURES_KEY: Final[str] = "node_features"
+NODE_ATTRS_KEY: Final[str] = "node_attrs"
+EDGE_TYPE_KEY: Final[str] = "edge_type"
+
+PER_ATOM_ENERGY_KEY: Final[str] = "atomic_energy"
+TOTAL_ENERGY_KEY: Final[str] = "total_energy"
+FORCE_KEY: Final[str] = "forces"
+PARTIAL_FORCE_KEY: Final[str] = "partial_forces"
+STRESS_KEY: Final[str] = "stress"
+VIRIAL_KEY: Final[str] = "virial"
+
+ALL_ENERGY_KEYS: Final[List[str]] = [
+ EDGE_ENERGY_KEY,
+ PER_ATOM_ENERGY_KEY,
+ TOTAL_ENERGY_KEY,
+ FORCE_KEY,
+ PARTIAL_FORCE_KEY,
+ STRESS_KEY,
+ VIRIAL_KEY,
+]
+
+BATCH_KEY: Final[str] = "batch"
+BATCH_PTR_KEY: Final[str] = "ptr"
+
+# Make a list of allowed keys
+ALLOWED_KEYS: List[str] = [
+ getattr(sys.modules[__name__], k)
+ for k in sys.modules[__name__].__dict__.keys()
+ if k.endswith("_KEY")
+]
diff --git a/dptb/data/build.py b/dptb/data/build.py
new file mode 100644
index 00000000..f419c15f
--- /dev/null
+++ b/dptb/data/build.py
@@ -0,0 +1,191 @@
+import inspect
+import os
+from copy import deepcopy
+import glob
+from importlib import import_module
+
+from dptb.data.dataset import DefaultDataset
+from dptb import data
+from dptb.data.transforms import TypeMapper, OrbitalMapper
+from dptb.data import AtomicDataset, register_fields
+from dptb.utils import instantiate, get_w_prefix
+from dptb.utils.tools import j_loader
+from dptb.utils.argcheck import normalize_setinfo
+
+
+def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
+ """initialize database based on a config instance
+
+ It needs dataset type name (case insensitive),
+ and all the parameters needed in the constructor.
+
+ Examples see tests/data/test_dataset.py TestFromConfig
+ and tests/datasets/test_simplest.py
+
+ Args:
+
+ config (dict, nequip.utils.Config): dict/object that store all the parameters
+ prefix (str): Optional. The prefix of all dataset parameters
+
+ Return:
+
+ dataset (nequip.data.AtomicDataset)
+ """
+
+ config_dataset = config.get(prefix, None)
+ if config_dataset is None:
+ raise KeyError(f"Dataset with prefix `{prefix}` isn't present in this config!")
+
+ if inspect.isclass(config_dataset):
+ # user define class
+ class_name = config_dataset
+ else:
+ try:
+ module_name = ".".join(config_dataset.split(".")[:-1])
+ class_name = ".".join(config_dataset.split(".")[-1:])
+ class_name = getattr(import_module(module_name), class_name)
+ except Exception:
+ # ^ TODO: don't catch all Exception
+ # default class defined in nequip.data or nequip.dataset
+ dataset_name = config_dataset.lower()
+
+ class_name = None
+ for k, v in inspect.getmembers(data, inspect.isclass):
+ if k.endswith("Dataset"):
+ if k.lower() == dataset_name:
+ class_name = v
+ if k[:-7].lower() == dataset_name:
+ class_name = v
+ elif k.lower() == dataset_name:
+ class_name = v
+
+ if class_name is None:
+ raise NameError(f"dataset type {dataset_name} does not exists")
+
+ # if dataset r_max is not found, use the universal r_max
+ atomicdata_options_key = "AtomicData_options"
+ prefixed_eff_key = f"{prefix}_{atomicdata_options_key}"
+ config[prefixed_eff_key] = get_w_prefix(
+ atomicdata_options_key, {}, prefix=prefix, arg_dicts=config
+ )
+ config[prefixed_eff_key]["r_max"] = get_w_prefix(
+ "r_max",
+ prefix=prefix,
+ arg_dicts=[config[prefixed_eff_key], config],
+ )
+
+ config[prefixed_eff_key]["er_max"] = get_w_prefix(
+ "er_max",
+ prefix=prefix,
+ arg_dicts=[config[prefixed_eff_key], config],
+ )
+
+ config[prefixed_eff_key]["oer_max"] = get_w_prefix(
+ "oer_max",
+ prefix=prefix,
+ arg_dicts=[config[prefixed_eff_key], config],
+ )
+
+ # Build a TypeMapper from the config
+ type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config)
+
+ # Register fields:
+ # This might reregister fields, but that's OK:
+ instantiate(register_fields, all_args=config)
+
+ instance, _ = instantiate(
+ class_name,
+ prefix=prefix,
+ positional_args={"type_mapper": type_mapper},
+ optional_args=config,
+ )
+
+ return instance
+
+
+def build_dataset(set_options, common_options):
+
+ dataset_type = set_options.get("type", "DefaultDataset")
+
+ # input in set_option for Default Dataset:
+ # "root": main dir storing all trajectory folders.
+ # that is, each subfolder of root contains a trajectory.
+ # "prefix": optional, load selected trajectory folders.
+ # "get_Hamiltonian": load the Hamiltonian file to edges of the graph or not.
+ # "get_eigenvalues": load the eigenvalues to the graph or not.
+ # "setinfo": MUST HAVE, the name of the json file used to build dataset.
+ # Example:
+ # "train": {
+ # "type": "DefaultDataset",
+ # "root": "foo/bar/data_files_here",
+ # "prefix": "traj",
+ # "setinfo": "with_pbc.json"
+ # }
+ if dataset_type == "DefaultDataset":
+ # See if we can get a OrbitalMapper.
+ if "basis" in common_options:
+ idp = OrbitalMapper(common_options["basis"])
+ else:
+ idp = None
+
+ # Explore the dataset's folder structure.
+ root = set_options["root"]
+ prefix = set_options.get("prefix", None)
+ include_folders = []
+ for dir_name in os.listdir(root):
+ dir_path = os.path.join(root, dir_name)
+ if os.path.isdir(dir_path):
+ # If the `processed_dataset` or other folder is here too, they do not have the proper traj data files.
+ # And we will have problem in generating TrajData!
+ # So we test it here: the data folder must have `.dat` or `.traj` file.
+ if glob.glob(os.path.join(dir_path, '*.dat')) or glob.glob(os.path.join(dir_path, '*.traj')):
+ if prefix is not None:
+ if dir_name[:len(prefix)] == prefix:
+ include_folders.append(dir_name)
+ else:
+ include_folders.append(dir_name)
+
+ # We need to check the `setinfo.json` very carefully here.
+ # Different `setinfo` points to different dataset,
+ # even if the data files in `root` are basically the same.
+ info_files = {}
+
+ # See if a public info is provided.
+ if "info.json" in os.listdir(root):
+ public_info = j_loader(os.path.join(root, "info.json"))
+ public_info = normalize_setinfo(public_info)
+ print("A public `info.json` file is provided, and will be used by the subfolders who do not have their own `info.json` file.")
+ else:
+ public_info = None
+
+ # Load info in each trajectory folders seperately.
+ for file in include_folders:
+ if "info.json" in os.listdir(os.path.join(root, file)):
+ # use info provided in this trajectory.
+ info = j_loader(os.path.join(root, file, "info.json"))
+ info = normalize_setinfo(info)
+ info_files[file] = info
+ elif public_info is not None:
+ # use public info instead
+ # yaml will not dump correctly if this is not a deepcopy.
+ info_files[file] = deepcopy(public_info)
+ else:
+ # no info for this file
+ raise Exception(f"info.json is not properly provided for `{file}`.")
+
+ # We will sort the info_files here.
+ # The order itself is not important, but must be consistant for the same list.
+ info_files = {key: info_files[key] for key in sorted(info_files)}
+
+ dataset = DefaultDataset(
+ root=root,
+ type_mapper=idp,
+ get_Hamiltonian=set_options.get("get_Hamiltonian", False),
+ get_eigenvalues=set_options.get("get_eigenvalues", False),
+ info_files = info_files
+ )
+
+ else:
+ raise ValueError(f"Not support dataset type: {type}.")
+
+ return dataset
diff --git a/dptb/data/dataloader.py b/dptb/data/dataloader.py
new file mode 100644
index 00000000..1f383f1a
--- /dev/null
+++ b/dptb/data/dataloader.py
@@ -0,0 +1,163 @@
+from typing import List, Optional, Iterator
+
+import torch
+from torch.utils.data import Sampler
+
+from dptb.utils.torch_geometric import Batch, Data, Dataset
+
+class Collater(object):
+ """Collate a list of ``AtomicData``.
+
+ Args:
+ exclude_keys: keys to ignore in the input, not copying to the output
+ """
+
+ def __init__(
+ self,
+ exclude_keys: List[str] = [],
+ ):
+ self._exclude_keys = set(exclude_keys)
+
+ @classmethod
+ def for_dataset(
+ cls,
+ dataset,
+ exclude_keys: List[str] = [],
+ ):
+ """Construct a collater appropriate to ``dataset``."""
+ return cls(
+ exclude_keys=exclude_keys,
+ )
+
+ def collate(self, batch: List[Data]) -> Batch:
+ """Collate a list of data"""
+ out = Batch.from_data_list(batch, exclude_keys=self._exclude_keys)
+ return out
+
+ def __call__(self, batch: List[Data]) -> Batch:
+ """Collate a list of data"""
+ return self.collate(batch)
+
+ @property
+ def exclude_keys(self):
+ return list(self._exclude_keys)
+
+
+class DataLoader(torch.utils.data.DataLoader):
+ def __init__(
+ self,
+ dataset,
+ batch_size: int = 1,
+ shuffle: bool = False,
+ exclude_keys: List[str] = [],
+ **kwargs,
+ ):
+ if "collate_fn" in kwargs:
+ del kwargs["collate_fn"]
+
+ super(DataLoader, self).__init__(
+ dataset,
+ batch_size,
+ shuffle,
+ collate_fn=Collater.for_dataset(dataset, exclude_keys=exclude_keys),
+ **kwargs,
+ )
+
+
+class PartialSampler(Sampler[int]):
+ r"""Samples elements without replacement, but divided across a number of calls to `__iter__`.
+
+ To ensure deterministic reproducibility and restartability, dataset permutations are generated
+ from a combination of the overall seed and the epoch number. As a result, the caller must
+ tell this sampler the epoch number before each time `__iter__` is called by calling
+ `my_partial_sampler.step_epoch(epoch_number_about_to_run)` each time.
+
+ This sampler decouples epochs from the dataset size and cycles through the dataset over as
+ many (partial) epochs as it may take. As a result, the _dataset_ epoch can change partway
+ through a training epoch.
+
+ Args:
+ data_source (Dataset): dataset to sample from
+ shuffle (bool): whether to shuffle the dataset each time the _entire_ dataset is consumed
+ num_samples_per_epoch (int): number of samples to draw in each call to `__iter__`.
+ If `None`, defaults to `len(data_source)`.
+ generator (Generator): Generator used in sampling.
+ """
+ data_source: Dataset
+ num_samples_per_epoch: int
+ shuffle: bool
+ _epoch: int
+ _prev_epoch: int
+
+ def __init__(
+ self,
+ data_source: Dataset,
+ shuffle: bool = True,
+ num_samples_per_epoch: Optional[int] = None,
+ generator=None,
+ ) -> None:
+ self.data_source = data_source
+ self.shuffle = shuffle
+ if num_samples_per_epoch is None:
+ num_samples_per_epoch = self.num_samples_total
+ self.num_samples_per_epoch = num_samples_per_epoch
+ assert self.num_samples_per_epoch <= self.num_samples_total
+ assert self.num_samples_per_epoch >= 1
+ self.generator = generator
+ self._epoch = None
+ self._prev_epoch = None
+
+ @property
+ def num_samples_total(self) -> int:
+ # dataset size might change at runtime
+ return len(self.data_source)
+
+ def step_epoch(self, epoch: int) -> None:
+ self._epoch = epoch
+
+ def __iter__(self) -> Iterator[int]:
+ assert self._epoch is not None
+ assert (self._prev_epoch is None) or (self._epoch == self._prev_epoch + 1)
+ assert self._epoch >= 0
+
+ full_epoch_i, start_sample_i = divmod(
+ # how much data we've already consumed:
+ self._epoch * self.num_samples_per_epoch,
+ # how much data there is the dataset:
+ self.num_samples_total,
+ )
+
+ if self.shuffle:
+ temp_rng = torch.Generator()
+ # Get new randomness for each _full_ time through the dataset
+ # This is deterministic w.r.t. the combination of dataset seed and epoch number
+ # Both of which persist across restarts
+ # (initial_seed() is restored by set_state())
+ temp_rng.manual_seed(self.generator.initial_seed() + full_epoch_i)
+ full_order_this = torch.randperm(self.num_samples_total, generator=temp_rng)
+ # reseed the generator for the _next_ epoch to get the shuffled order of the
+ # _next_ dataset epoch to pad out this one for completing any partial batches
+ # at the end:
+ temp_rng.manual_seed(self.generator.initial_seed() + full_epoch_i + 1)
+ full_order_next = torch.randperm(self.num_samples_total, generator=temp_rng)
+ del temp_rng
+ else:
+ full_order_this = torch.arange(self.num_samples_total)
+ # without shuffling, the next epoch has the same sampling order as this one:
+ full_order_next = full_order_this
+
+ full_order = torch.cat((full_order_this, full_order_next), dim=0)
+ del full_order_next, full_order_this
+
+ this_segment_indexes = full_order[
+ start_sample_i : start_sample_i + self.num_samples_per_epoch
+ ]
+ # because we cycle into indexes from the next dataset epoch,
+ # we should _always_ be able to get num_samples_per_epoch
+ assert len(this_segment_indexes) == self.num_samples_per_epoch
+ yield from this_segment_indexes
+
+ self._prev_epoch = self._epoch
+
+ def __len__(self) -> int:
+ return self.num_samples_per_epoch
diff --git a/dptb/data/dataset/__init__.py b/dptb/data/dataset/__init__.py
new file mode 100644
index 00000000..c8dca670
--- /dev/null
+++ b/dptb/data/dataset/__init__.py
@@ -0,0 +1,21 @@
+from ._base_datasets import AtomicDataset, AtomicInMemoryDataset
+from ._ase_dataset import ASEDataset
+from ._npz_dataset import NpzDataset
+from ._hdf5_dataset import HDF5Dataset
+from ._abacus_dataset import ABACUSDataset, ABACUSInMemoryDataset
+from ._deeph_dataset import DeePHE3Dataset
+from ._default_dataset import DefaultDataset
+
+
+__all__ = [
+ DefaultDataset,
+ DeePHE3Dataset,
+ ABACUSInMemoryDataset,
+ ABACUSDataset,
+ ASEDataset,
+ AtomicDataset,
+ AtomicInMemoryDataset,
+ NpzDataset,
+ HDF5Dataset
+ ]
+
diff --git a/dptb/data/dataset/_abacus_dataset.py b/dptb/data/dataset/_abacus_dataset.py
new file mode 100644
index 00000000..bcdb1c62
--- /dev/null
+++ b/dptb/data/dataset/_abacus_dataset.py
@@ -0,0 +1,109 @@
+from typing import Dict, Any, List, Callable, Union, Optional
+import os
+
+import numpy as np
+import h5py
+
+import torch
+
+from .. import (
+ AtomicData,
+ AtomicDataDict,
+)
+
+from ..transforms import TypeMapper, OrbitalMapper
+from ._base_datasets import AtomicDataset, AtomicInMemoryDataset
+#from dptb.nn.hamiltonian import E3Hamiltonian
+from dptb.data.interfaces.ham_to_feature import ham_block_to_feature
+
+orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"}
+
+def _abacus_h5_reader(h5file_path, AtomicData_options):
+ data = h5py.File(h5file_path, "r")
+ atomic_data = AtomicData.from_points(
+ pos = data["pos"][:],
+ cell = data["cell"][:],
+ atomic_numbers = data["atomic_numbers"][:],
+ **AtomicData_options,
+ )
+ if "hamiltonian_blocks" in data:
+ basis = {}
+ for key, value in data["basis"].items():
+ basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)]
+ idp = OrbitalMapper(basis)
+ # e3 = E3Hamiltonian(idp=idp, decompose=True)
+ ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False))
+ # with torch.no_grad():
+ # atomic_data = e3(atomic_data.to_dict())
+ # atomic_data = AtomicData.from_dict(atomic_data)
+
+ if "eigenvalues" in data and "kpionts" in data:
+ atomic_data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(data["kpoints"][:], dtype=torch.get_default_dtype())
+ atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(data["eigenvalues"][:], dtype=torch.get_default_dtype())
+ return atomic_data
+
+# Lazy loading class, built for large dataset.
+
+class ABACUSDataset(AtomicDataset):
+
+ def __init__(
+ self,
+ root: str,
+ preprocess_dir: str,
+ AtomicData_options: Dict[str, Any] = {},
+ type_mapper: Optional[TypeMapper] = None,
+ ):
+ super().__init__(root=root, type_mapper=type_mapper)
+ self.preprocess_dir = preprocess_dir
+ self.file_name = np.loadtxt(os.path.join(self.preprocess_dir, 'AtomicData_file.txt'), dtype=str)
+ self.AtomicData_options = AtomicData_options
+ self.num_examples = len(self.file_name)
+
+ def get(self, idx):
+ name = self.file_name[idx]
+ h5_file = os.path.join(self.preprocess_dir, name)
+ atomic_data = _abacus_h5_reader(h5_file, self.AtomicData_options)
+ return atomic_data
+
+ def len(self) -> int:
+ return self.num_examples
+
+# In memory version.
+
+class ABACUSInMemoryDataset(AtomicInMemoryDataset):
+
+ def __init__(
+ self,
+ root: str,
+ preprocess_dir: str,
+ url: Optional[str] = None,
+ AtomicData_options: Dict[str, Any] = {},
+ include_frames: Optional[List[int]] = None,
+ type_mapper: TypeMapper = None,
+ ):
+ self.preprocess_dir = preprocess_dir
+ self.file_name = np.loadtxt(os.path.join(self.preprocess_dir, 'AtomicData_file.txt'), dtype=str)
+
+ super(ABACUSInMemoryDataset, self).__init__(
+ file_name=self.file_name,
+ url=url,
+ root=root,
+ AtomicData_options=AtomicData_options,
+ include_frames=include_frames,
+ type_mapper=type_mapper,
+ )
+
+ def get_data(self):
+ data = []
+ for name in self.file_name:
+ h5_file = os.path.join(self.preprocess_dir, name)
+ data.append(_abacus_h5_reader(h5_file, self.AtomicData_options))
+ return data
+
+ @property
+ def raw_file_names(self):
+ return "AtomicData.h5"
+
+ @property
+ def raw_dir(self):
+ return self.root
\ No newline at end of file
diff --git a/dptb/data/dataset/_abacus_dataset_mem.py b/dptb/data/dataset/_abacus_dataset_mem.py
new file mode 100644
index 00000000..de4263ae
--- /dev/null
+++ b/dptb/data/dataset/_abacus_dataset_mem.py
@@ -0,0 +1,109 @@
+from typing import Dict, Any, List, Callable, Union, Optional
+import os
+
+import numpy as np
+import h5py
+
+import torch
+
+from .. import (
+ AtomicData,
+ AtomicDataDict,
+)
+from ..transforms import TypeMapper, OrbitalMapper
+from ._base_datasets import AtomicInMemoryDataset
+from dptb.nn.hamiltonian import E3Hamiltonian
+from dptb.data.interfaces.ham_to_feature import ham_block_to_feature
+from dptb.data.interfaces.abacus import recursive_parse
+
+orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"}
+
+def _abacus_h5_reader(h5file_path, AtomicData_options):
+ data = h5py.File(h5file_path, "r")
+ atomic_data = AtomicData.from_points(
+ pos = data["pos"][:],
+ cell = data["cell"][:],
+ atomic_numbers = data["atomic_numbers"][:],
+ **AtomicData_options,
+ )
+ if data["hamiltonian_blocks"]:
+ basis = {}
+ for key, value in data["basis"].items():
+ basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)]
+ idp = OrbitalMapper(basis)
+ # e3 = E3Hamiltonian(idp=idp, decompose=True)
+ ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False))
+ # with torch.no_grad():
+ # atomic_data = e3(atomic_data.to_dict())
+ # atomic_data = AtomicData.from_dict(atomic_data)
+
+ if data.get("eigenvalue") and data.get("kpoint"):
+ atomic_data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(data["kpoint"][:], dtype=torch.get_default_dtype())
+ atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(data["eigenvalue"][:], dtype=torch.get_default_dtype())
+ return atomic_data
+
+
+class ABACUSInMemoryDataset(AtomicInMemoryDataset):
+
+ def __init__(
+ self,
+ root: str,
+ abacus_args: Dict[str, Union[str,bool]] = {
+ "input_dir": None,
+ "preprocess_dir": None,
+ "only_overlap": False,
+ "get_Ham": False,
+ "add_overlap": False,
+ "get_eigenvalues": False,
+ },
+ file_name: Optional[str] = None,
+ url: Optional[str] = None,
+ AtomicData_options: Dict[str, Any] = {},
+ include_frames: Optional[List[int]] = None,
+ type_mapper: TypeMapper = None,
+ key_mapping: Dict[str, str] = {
+ "pos": AtomicDataDict.POSITIONS_KEY,
+ "energy": AtomicDataDict.TOTAL_ENERGY_KEY,
+ "atomic_numbers": AtomicDataDict.ATOMIC_NUMBERS_KEY,
+ "kpoints": AtomicDataDict.KPOINT_KEY,
+ "eigenvalues": AtomicDataDict.ENERGY_EIGENVALUE_KEY,
+ },
+ ):
+ if file_name is not None:
+ self.file_name = file_name
+ else:
+ self.abacus_args = abacus_args
+ assert self.abacus_args.get("input_dir") is not None, "ABACUS calculation results MUST be provided."
+ if self.abacus_args.get("preprocess_dir") is None:
+ print("Creating new preprocess dictionary...")
+ os.mkdir(os.path.join(root, "preprocess"))
+ self.abacus_args["preprocess_dir"] = os.path.join(root, "preprocess")
+ self.key_mapping = key_mapping
+
+ print("Begin parsing ABACUS output...")
+ h5_filenames = recursive_parse(**self.abacus_args)
+ self.file_name = h5_filenames
+ print("Finished parsing ABACUS output.")
+
+ super(ABACUSInMemoryDataset, self).__init__(
+ file_name=self.file_name,
+ url=url,
+ root=root,
+ AtomicData_options=AtomicData_options,
+ include_frames=include_frames,
+ type_mapper=type_mapper,
+ )
+
+ def get_data(self):
+ data = []
+ for h5_file in self.file_name:
+ data.append(_abacus_h5_reader(h5_file, self.AtomicData_options))
+ return data
+
+ @property
+ def raw_file_names(self):
+ return "AtomicData.h5"
+
+ @property
+ def raw_dir(self):
+ return self.root
\ No newline at end of file
diff --git a/dptb/data/dataset/_ase_dataset.py b/dptb/data/dataset/_ase_dataset.py
new file mode 100644
index 00000000..6200ebe5
--- /dev/null
+++ b/dptb/data/dataset/_ase_dataset.py
@@ -0,0 +1,238 @@
+import tempfile
+import functools
+import itertools
+from os.path import dirname, basename, abspath
+from typing import Dict, Any, List, Union, Optional, Sequence
+
+import ase
+import ase.io
+
+import torch
+import torch.multiprocessing as mp
+
+
+from dptb.utils.multiprocessing import num_tasks
+from .. import AtomicData
+from ..transforms import TypeMapper
+from ._base_datasets import AtomicInMemoryDataset
+
+
+def _ase_dataset_reader(
+ rank: int,
+ world_size: int,
+ tmpdir: str,
+ ase_kwargs: dict,
+ atomicdata_kwargs: dict,
+ include_frames,
+ global_options: dict,
+) -> Union[str, List[AtomicData]]:
+ """Parallel reader for all frames in file."""
+ if world_size > 1:
+ from nequip.utils._global_options import _set_global_options
+
+ # ^ avoid import loop
+ # we only `multiprocessing` if world_size > 1
+ _set_global_options(global_options)
+ # interleave--- in theory it is better for performance for the ranks
+ # to read consecutive blocks, but the way ASE is written the whole
+ # file gets streamed through all ranks anyway, so just trust the OS
+ # to cache things sanely, which it will.
+ # ASE handles correctly the case where there are no frames in index
+ # and just gives an empty list, so that will succeed:
+ index = slice(rank, None, world_size)
+ if include_frames is None:
+ # count includes 0, 1, ..., inf
+ include_frames = itertools.count()
+
+ datas = []
+ # stream them from ase too using iread
+ for i, atoms in enumerate(ase.io.iread(**ase_kwargs, index=index, parallel=False)):
+ global_index = rank + (world_size * i)
+ datas.append(
+ (
+ global_index,
+ AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs)
+ if global_index in include_frames
+ # in-memory dataset will ignore this later, but needed for indexing to work out
+ else None,
+ )
+ )
+ # Save to a tempfile---
+ # there can be a _lot_ of tensors here, and rather than dealing with
+ # the complications of running out of file descriptors and setting
+ # sharing methods, since this is a one time thing, just make it simple
+ # and avoid shared memory entirely.
+ if world_size > 1:
+ path = f"{tmpdir}/rank{rank}.pth"
+ torch.save(datas, path)
+ return path
+ else:
+ return datas
+
+
+class ASEDataset(AtomicInMemoryDataset):
+ """
+
+ Args:
+ ase_args (dict): arguments for ase.io.read
+ include_keys (list): in addition to forces and energy, the keys that needs to
+ be parsed into dataset
+ The data stored in ase.atoms.Atoms.array has the lowest priority,
+ and it will be overrided by data in ase.atoms.Atoms.info
+ and ase.atoms.Atoms.calc.results. Optional
+ key_mapping (dict): rename some of the keys to the value str. Optional
+
+ Example: Given an atomic data stored in "H2.extxyz" that looks like below:
+
+ ```H2.extxyz
+ 2
+ Properties=species:S:1:pos:R:3 energy=-10 user_label=2.0 pbc="F F F"
+ H 0.00000000 0.00000000 0.00000000
+ H 0.00000000 0.00000000 1.02000000
+ ```
+
+ The yaml input should be
+
+ ```
+ dataset: ase
+ dataset_file_name: H2.extxyz
+ ase_args:
+ format: extxyz
+ include_keys:
+ - user_label
+ key_mapping:
+ user_label: label0
+ chemical_symbols:
+ - H
+ ```
+
+ for VASP parser, the yaml input should be
+ ```
+ dataset: ase
+ dataset_file_name: OUTCAR
+ ase_args:
+ format: vasp-out
+ key_mapping:
+ free_energy: total_energy
+ chemical_symbols:
+ - H
+ ```
+
+ """
+
+ def __init__(
+ self,
+ root: str,
+ ase_args: dict = {},
+ file_name: Optional[str] = None,
+ url: Optional[str] = None,
+ AtomicData_options: Dict[str, Any] = {},
+ include_frames: Optional[List[int]] = None,
+ type_mapper: TypeMapper = None,
+ key_mapping: Optional[dict] = None,
+ include_keys: Optional[List[str]] = None,
+ ):
+ self.ase_args = {}
+ self.ase_args.update(getattr(type(self), "ASE_ARGS", dict()))
+ self.ase_args.update(ase_args)
+ assert "index" not in self.ase_args
+ assert "filename" not in self.ase_args
+
+ self.include_keys = include_keys
+ self.key_mapping = key_mapping
+
+ super().__init__(
+ file_name=file_name,
+ url=url,
+ root=root,
+ AtomicData_options=AtomicData_options,
+ include_frames=include_frames,
+ type_mapper=type_mapper,
+ )
+
+ @classmethod
+ def from_atoms_list(cls, atoms: Sequence[ase.Atoms], **kwargs):
+ """Make an ``ASEDataset`` from a list of ``ase.Atoms`` objects.
+
+ If `root` is not provided, a temporary directory will be used.
+
+ Please note that this is a convinience method that does NOT avoid a round-trip to disk; the provided ``atoms`` will be written out to a file.
+
+ Ignores ``kwargs["file_name"]`` if it is provided.
+
+ Args:
+ atoms
+ **kwargs: passed through to the constructor
+ Returns:
+ The constructed ``ASEDataset``.
+ """
+ if "root" not in kwargs:
+ tmpdir = tempfile.TemporaryDirectory()
+ kwargs["root"] = tmpdir.name
+ else:
+ tmpdir = None
+ kwargs["file_name"] = tmpdir.name + "/atoms.xyz"
+ atoms = list(atoms)
+ # Write them out
+ ase.io.write(kwargs["file_name"], atoms, format="extxyz")
+ # Read them in
+ obj = cls(**kwargs)
+ if tmpdir is not None:
+ # Make it keep a reference to the tmpdir to keep it alive
+ # When the dataset is garbage collected, the tmpdir will
+ # be too, and will (hopefully) get deleted eventually.
+ # Or at least by end of program...
+ obj._tmpdir_ref = tmpdir
+ return obj
+
+ @property
+ def raw_file_names(self):
+ return [basename(self.file_name)]
+
+ @property
+ def raw_dir(self):
+ return dirname(abspath(self.file_name))
+
+ def get_data(self):
+ ase_args = {"filename": self.raw_dir + "/" + self.raw_file_names[0]}
+ ase_args.update(self.ase_args)
+
+ # skip the None arguments
+ kwargs = dict(
+ include_keys=self.include_keys,
+ key_mapping=self.key_mapping,
+ )
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+ kwargs.update(self.AtomicData_options)
+ n_proc = num_tasks()
+ with tempfile.TemporaryDirectory() as tmpdir:
+ from nequip.utils._global_options import _get_latest_global_options
+
+ # ^ avoid import loop
+ reader = functools.partial(
+ _ase_dataset_reader,
+ world_size=n_proc,
+ tmpdir=tmpdir,
+ ase_kwargs=ase_args,
+ atomicdata_kwargs=kwargs,
+ include_frames=self.include_frames,
+ # get the global options of the parent to initialize the worker correctly
+ global_options=_get_latest_global_options(),
+ )
+ if n_proc > 1:
+ # things hang for some obscure OpenMP reason on some systems when using `fork` method
+ ctx = mp.get_context("forkserver")
+ with ctx.Pool(processes=n_proc) as p:
+ # map it over the `rank` argument
+ datas = p.map(reader, list(range(n_proc)))
+ # clean up the pool before loading the data
+ datas = [torch.load(d) for d in datas]
+ datas = sum(datas, [])
+ # un-interleave the datas
+ datas = sorted(datas, key=lambda e: e[0])
+ else:
+ datas = reader(rank=0)
+ # datas here is already in order, stride 1 start 0
+ # no need to un-interleave
+ # return list of AtomicData:
+ return [e[1] for e in datas]
diff --git a/dptb/data/dataset/_base_datasets.py b/dptb/data/dataset/_base_datasets.py
new file mode 100644
index 00000000..8c43c1c0
--- /dev/null
+++ b/dptb/data/dataset/_base_datasets.py
@@ -0,0 +1,665 @@
+import numpy as np
+import logging
+import inspect
+import itertools
+import yaml
+import hashlib
+import math
+from typing import Tuple, Dict, Any, List, Callable, Union, Optional
+
+import torch
+
+from torch_runstats.scatter import scatter_std, scatter_mean
+
+from dptb.utils.torch_geometric import Batch, Dataset
+from dptb.utils.tools import download_url, extract_zip
+
+import dptb
+from dptb.data import (
+ AtomicData,
+ AtomicDataDict,
+ _NODE_FIELDS,
+ _EDGE_FIELDS,
+ _GRAPH_FIELDS,
+)
+from dptb.utils.batch_ops import bincount
+from dptb.utils.regressor import solver
+from dptb.utils.savenload import atomic_write
+from ..transforms import TypeMapper
+
+
+class AtomicDataset(Dataset):
+ """The base class for all NequIP datasets."""
+
+ root: str
+ dtype: torch.dtype
+
+ def __init__(
+ self,
+ root: str,
+ type_mapper: Optional[TypeMapper] = None,
+ ):
+ self.dtype = torch.get_default_dtype()
+ super().__init__(root=root, transform=type_mapper)
+
+ def statistics(
+ self,
+ fields: List[Union[str, Callable]],
+ modes: List[str],
+ stride: int = 1,
+ unbiased: bool = True,
+ kwargs: Optional[Dict[str, dict]] = {},
+ ) -> List[tuple]:
+ # TODO: If needed, this can eventually be implimented for general AtomicDataset by computing an online running mean and using Welford's method for a stable running standard deviation: https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/
+ # That would be needed if we have lazy loading datasets.
+ # TODO: When lazy-loading datasets are implimented, how to deal with statistics, sampling, and subsets?
+ raise NotImplementedError("not implimented for general AtomicDataset yet")
+
+ @property
+ def type_mapper(self) -> Optional[TypeMapper]:
+ # self.transform is always a TypeMapper
+ return self.transform
+
+ def _get_parameters(self) -> Dict[str, Any]:
+ """Get a dict of the parameters used to build this dataset."""
+ pnames = list(inspect.signature(self.__init__).parameters)
+ IGNORE_KEYS = {
+ # the type mapper is applied after saving, not before, so doesn't matter to cache validity
+ "type_mapper"
+ }
+ params = {
+ k: getattr(self, k)
+ for k in pnames
+ if k not in IGNORE_KEYS and hasattr(self, k)
+ }
+ # Add other relevant metadata:
+ params["dtype"] = str(self.dtype)
+ params["nequip_version"] = dptb.__version__
+
+ return params
+
+ @property
+ def processed_dir(self) -> str:
+ # We want the file name to change when the parameters change
+ # So, first we get all parameters:
+ params = self._get_parameters()
+ # Make some kind of string of them:
+ # we don't care about this possibly changing between python versions,
+ # since a change in python version almost certainly means a change in
+ # versions of other things too, and is a good reason to recompute
+ buffer = yaml.dump(params).encode("ascii")
+ # And hash it:
+ param_hash = hashlib.sha1(buffer).hexdigest()
+ return f"{self.root}/processed_dataset_{param_hash}"
+
+
+class AtomicInMemoryDataset(AtomicDataset):
+ r"""Base class for all datasets that fit in memory.
+
+ Please note that, as a ``pytorch_geometric`` dataset, it must be backed by some kind of disk storage.
+ By default, the raw file will be stored at root/raw and the processed torch
+ file will be at root/process.
+
+ Subclasses must implement:
+ - ``raw_file_names``
+ - ``get_data()``
+
+ Subclasses may implement:
+ - ``download()`` or ``self.url`` or ``ClassName.URL``
+
+ Args:
+ root (str, optional): Root directory where the dataset should be saved. Defaults to current working directory.
+ file_name (str, optional): file name of data source. only used in children class
+ url (str, optional): url to download data source
+ AtomicData_options (dict, optional): extra key that are not stored in data but needed for AtomicData initialization
+ include_frames (list, optional): the frames to process with the constructor.
+ type_mapper (TypeMapper): the transformation to map atomic information to species index. Optional
+ """
+
+ def __init__(
+ self,
+ root: str,
+ file_name: Optional[str] = None,
+ url: Optional[str] = None,
+ AtomicData_options: Dict[str, Any] = {},
+ include_frames: Optional[List[int]] = None,
+ type_mapper: Optional[TypeMapper] = None,
+ ):
+ # TO DO, this may be simplified
+ # See if a subclass defines some inputs
+ self.file_name = (
+ getattr(type(self), "FILE_NAME", None) if file_name is None else file_name
+ )
+ self.url = getattr(type(self), "URL", url)
+
+ self.AtomicData_options = AtomicData_options
+ self.include_frames = include_frames
+
+ self.data = None
+
+ # !!! don't delete this block.
+ # otherwise the inherent children class
+ # will ignore the download function here
+ class_type = type(self)
+ if class_type != AtomicInMemoryDataset:
+ if "download" not in self.__class__.__dict__:
+ class_type.download = AtomicInMemoryDataset.download
+ if "process" not in self.__class__.__dict__:
+ class_type.process = AtomicInMemoryDataset.process
+
+ # Initialize the InMemoryDataset, which runs download and process
+ # See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets
+ # Then pre-process the data if disk files are not found
+ super().__init__(root=root, type_mapper=type_mapper)
+ if self.data is None:
+ self.data, include_frames = torch.load(self.processed_paths[0])
+ if not np.all(include_frames == self.include_frames):
+ raise ValueError(
+ f"the include_frames is changed. "
+ f"please delete the processed folder and rerun {self.processed_paths[0]}"
+ )
+
+ def len(self):
+ if self.data is None:
+ return 0
+ return self.data.num_graphs
+
+ @property
+ def raw_file_names(self):
+ raise NotImplementedError()
+
+ @property
+ def processed_file_names(self) -> List[str]:
+ return ["data.pth", "params.yaml"]
+
+ def get_data(
+ self,
+ ) -> Union[Tuple[Dict[str, Any], Dict[str, Any]], List[AtomicData]]:
+ """Get the data --- called from ``process()``, can assume that ``raw_file_names()`` exist.
+
+ Note that parameters for graph construction such as ``pbc`` and ``r_max`` should be included here as (likely, but not necessarily, fixed) fields.
+
+ Returns:
+ A dict:
+ fields: dict
+ mapping a field name ('pos', 'cell') to a list-like sequence of tensor-like objects giving that field's value for each example.
+ Or:
+ data_list: List[AtomicData]
+ """
+ raise NotImplementedError
+
+ def download(self):
+ if (not hasattr(self, "url")) or (self.url is None):
+ # Don't download, assume present. Later could have FileNotFound if the files don't actually exist
+ pass
+ else:
+ download_path = download_url(self.url, self.raw_dir)
+ if download_path.endswith(".zip"):
+ extract_zip(download_path, self.raw_dir)
+
+ def process(self):
+ data = self.get_data() ## get data returns either a list of AtomicData class or a data dict
+ if isinstance(data, list):
+
+ # It's a data list
+ data_list = data
+ if not (self.include_frames is None or data_list is None):
+ data_list = [data_list[i] for i in self.include_frames] # 可以选择数据集中加载的序号
+ assert all(isinstance(e, AtomicData) for e in data_list)
+ assert all(AtomicDataDict.BATCH_KEY not in e for e in data_list)
+
+ fields = {}
+
+ elif isinstance(data, dict):
+ # It's fields
+ # Get our data
+ fields = data
+
+ # check keys
+ all_keys = set(fields.keys())
+ assert AtomicDataDict.BATCH_KEY not in all_keys
+ # Check bad key combinations, but don't require that this be a graph yet.
+ AtomicDataDict.validate_keys(all_keys, graph_required=False)
+
+ # check dimesionality
+ num_examples = set([len(a) for a in fields.values()])
+ if not len(num_examples) == 1:
+ raise ValueError(
+ f"This dataset is invalid: expected all fields to have same length (same number of examples), but they had shapes { {f: v.shape for f, v in fields.items() } }"
+ )
+ num_examples = next(iter(num_examples))
+
+ include_frames = self.include_frames
+ if include_frames is None:
+ include_frames = range(num_examples)
+
+ # Make AtomicData from it:
+ if AtomicDataDict.EDGE_INDEX_KEY in all_keys:
+ # This is already a graph, just build it
+ constructor = AtomicData
+ else:
+ # do neighborlist from points
+ constructor = AtomicData.from_points
+ assert "r_max" in self.AtomicData_options
+ assert AtomicDataDict.POSITIONS_KEY in all_keys
+
+ data_list = [
+ constructor(
+ **{
+ **{f: v[i] for f, v in fields.items()},
+ **self.AtomicData_options,
+ }
+ )
+ for i in include_frames
+ ]
+
+
+ else:
+ raise ValueError("Invalid return from `self.get_data()`")
+
+ # Batch it for efficient saving
+ # This limits an AtomicInMemoryDataset to a maximum of LONG_MAX atoms _overall_, but that is a very big number and any dataset that large is probably not "InMemory" anyway
+ data = Batch.from_data_list(data_list)
+ del data_list
+ del fields
+
+ total_MBs = sum(item.numel() * item.element_size() for _, item in data) / (
+ 1024 * 1024
+ )
+ logging.info(
+ f"Loaded data: {data}\n processed data size: ~{total_MBs:.2f} MB"
+ )
+ del total_MBs
+
+ # use atomic writes to avoid race conditions between
+ # different trainings that use the same dataset
+ # since those separate trainings should all produce the same results,
+ # it doesn't matter if they overwrite each others cached'
+ # datasets. It only matters that they don't simultaneously try
+ # to write the _same_ file, corrupting it.
+ with atomic_write(self.processed_paths[0], binary=True) as f:
+ torch.save((data, self.include_frames), f)
+ with atomic_write(self.processed_paths[1], binary=False) as f:
+ yaml.dump(self._get_parameters(), f)
+
+ logging.info("Cached processed data to disk")
+
+ self.data = data
+
+ def get(self, idx):
+ return self.data.get_example(idx)
+
+ def _selectors(
+ self,
+ stride: int = 1,
+ ):
+ if self._indices is not None:
+ graph_selector = torch.as_tensor(self._indices)[::stride]
+ # note that self._indices is _not_ necessarily in order,
+ # while self.data --- which we take our arrays from ---
+ # is always in the original order.
+ # In particular, the values of `self.data.batch`
+ # are indexes in the ORIGINAL order
+ # thus we need graph level properties to also be in the original order
+ # so that batch values index into them correctly
+ # since self.data.batch is always sorted & contiguous
+ # (because of Batch.from_data_list)
+ # we sort it:
+ graph_selector, _ = torch.sort(graph_selector)
+ else:
+ graph_selector = torch.arange(0, self.len(), stride)
+
+ node_selector = torch.as_tensor(
+ np.in1d(self.data.batch.numpy(), graph_selector.numpy())
+ )
+
+ edge_index = self.data[AtomicDataDict.EDGE_INDEX_KEY]
+ edge_selector = node_selector[edge_index[0]] & node_selector[edge_index[1]]
+
+ return (graph_selector, node_selector, edge_selector)
+
+ def statistics(
+ self,
+ fields: List[Union[str, Callable]],
+ modes: List[str],
+ stride: int = 1,
+ unbiased: bool = True,
+ kwargs: Optional[Dict[str, dict]] = {},
+ ) -> List[tuple]:
+ """Compute the statistics of ``fields`` in the dataset.
+
+ If the values at the fields are vectors/multidimensional, they must be of fixed shape and elementwise statistics will be computed.
+
+ Args:
+ fields: the names of the fields to compute statistics for.
+ Instead of a field name, a callable can also be given that reuturns a quantity to compute the statisics for.
+
+ If a callable is given, it will be called with a (possibly batched) ``Data``-like object and must return a sequence of points to add to the set over which the statistics will be computed.
+ The callable must also return a string, one of ``"node"`` or ``"graph"``, indicating whether the value it returns is a per-node or per-graph quantity.
+ PLEASE NOTE: the argument to the callable may be "batched", and it may not be batched "contiguously": ``batch`` and ``edge_index`` may have "gaps" in their values.
+
+ For example, to compute the overall statistics of the x,y, and z components of a per-node vector ``force`` field:
+
+ data.statistics([lambda data: (data.force.flatten(), "node")])
+
+ The above computes the statistics over a set of size 3N, where N is the total number of nodes in the dataset.
+
+ modes: the statistic to compute for each field. Valid options are:
+ - ``count``
+ - ``rms``
+ - ``mean_std``
+ - ``per_atom_*``
+ - ``per_species_*``
+
+ stride: the stride over the dataset while computing statistcs.
+
+ unbiased: whether to use unbiased for standard deviations.
+
+ kwargs: other options for individual statistics modes.
+
+ Returns:
+ List of statistics. For fields of floating dtype the statistics are the two-tuple (mean, std); for fields of integer dtype the statistics are a one-tuple (bincounts,)
+ """
+
+ # Short circut:
+ assert len(modes) == len(fields)
+ if len(fields) == 0:
+ return []
+
+ graph_selector, node_selector, edge_selector = self._selectors(stride=stride)
+
+ num_graphs = len(graph_selector)
+ num_nodes = node_selector.sum()
+ num_edges = edge_selector.sum()
+
+ if self.transform is not None:
+ # pre-transform the data so that statistics process transformed data
+ data_transformed = self.transform(self.data.to_dict(), types_required=False)
+ else:
+ data_transformed = self.data.to_dict()
+ # pre-select arrays
+ # this ensures that all following computations use the right data
+ all_keys = set()
+ selectors = {}
+ for k in data_transformed.keys():
+ all_keys.add(k)
+ if k in _NODE_FIELDS:
+ selectors[k] = node_selector
+ elif k in _GRAPH_FIELDS:
+ selectors[k] = graph_selector
+ elif k == AtomicDataDict.EDGE_INDEX_KEY:
+ selectors[k] = (slice(None, None, None), edge_selector)
+ elif k in _EDGE_FIELDS:
+ selectors[k] = edge_selector
+ # TODO: do the batch indexes, edge_indexes, etc. after selection need to be
+ # "compacted" to subtract out their offsets? For now, we just punt this
+ # onto the writer of the callable field.
+ # apply selector to actual data
+ data_transformed = {
+ k: data_transformed[k][selectors[k]]
+ for k in data_transformed.keys()
+ if k in selectors
+ }
+
+ atom_types: Optional[torch.Tensor] = None
+ out: list = []
+ for ifield, field in enumerate(fields):
+ if callable(field):
+ # make a joined thing? so it includes fixed fields
+ arr, arr_is_per = field(data_transformed)
+ arr = arr.to(self.dtype) # all statistics must be on floating
+ assert arr_is_per in ("node", "graph", "edge")
+ else:
+ if field not in all_keys:
+ raise RuntimeError(
+ f"The field key `{field}` is not present in this dataset"
+ )
+ if field not in selectors:
+ # this means field is not selected and so not available
+ raise RuntimeError(
+ f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. If it is per-node or per-graph, please register it as such using `nequip.data.register_fields`"
+ )
+ arr = data_transformed[field]
+ if field in _NODE_FIELDS:
+ arr_is_per = "node"
+ elif field in _GRAPH_FIELDS:
+ arr_is_per = "graph"
+ elif field in _EDGE_FIELDS:
+ arr_is_per = "edge"
+ else:
+ raise RuntimeError
+
+ # Check arr
+ if arr is None:
+ raise ValueError(
+ f"Cannot compute statistics over field `{field}` whose value is None!"
+ )
+ if not isinstance(arr, torch.Tensor):
+ if np.issubdtype(arr.dtype, np.floating):
+ arr = torch.as_tensor(arr, dtype=self.dtype)
+ else:
+ arr = torch.as_tensor(arr)
+ if arr_is_per == "node":
+ arr = arr.view(num_nodes, -1)
+ elif arr_is_per == "graph":
+ arr = arr.view(num_graphs, -1)
+ elif arr_is_per == "edge":
+ arr = arr.view(num_edges, -1)
+
+ ana_mode = modes[ifield]
+ # compute statistics
+ if ana_mode == "count":
+ # count integers
+ uniq, counts = torch.unique(
+ torch.flatten(arr), return_counts=True, sorted=True
+ )
+ out.append((uniq, counts))
+ elif ana_mode == "rms":
+ # root-mean-square
+ out.append((torch.sqrt(torch.mean(arr * arr)),))
+
+ elif ana_mode == "mean_std":
+ # mean and std
+ if len(arr) < 2:
+ raise ValueError(
+ "Can't do per species standard deviation without at least two samples"
+ )
+ mean = torch.mean(arr, dim=0)
+ std = torch.std(arr, dim=0, unbiased=unbiased)
+ out.append((mean, std))
+
+ elif ana_mode == "absmax":
+ out.append((arr.abs().max(),))
+
+ elif ana_mode.startswith("per_species_"):
+ # per-species
+ algorithm_kwargs = kwargs.pop(field + ana_mode, {})
+
+ ana_mode = ana_mode[len("per_species_") :]
+
+ if atom_types is None:
+ atom_types = data_transformed[AtomicDataDict.ATOM_TYPE_KEY]
+
+ results = self._per_species_statistics(
+ ana_mode,
+ arr,
+ arr_is_per=arr_is_per,
+ batch=data_transformed[AtomicDataDict.BATCH_KEY],
+ atom_types=atom_types,
+ unbiased=unbiased,
+ algorithm_kwargs=algorithm_kwargs,
+ )
+ out.append(results)
+
+ elif ana_mode.startswith("per_atom_"):
+ # per-atom
+ # only makes sense for a per-graph quantity
+ if arr_is_per != "graph":
+ raise ValueError(
+ f"It doesn't make sense to ask for `{ana_mode}` since `{field}` is not per-graph"
+ )
+ ana_mode = ana_mode[len("per_atom_") :]
+ results = self._per_atom_statistics(
+ ana_mode=ana_mode,
+ arr=arr,
+ batch=data_transformed[AtomicDataDict.BATCH_KEY],
+ unbiased=unbiased,
+ )
+ out.append(results)
+
+ else:
+ raise NotImplementedError(f"Cannot handle statistics mode {ana_mode}")
+
+ return out
+
+ @staticmethod
+ def _per_atom_statistics(
+ ana_mode: str,
+ arr: torch.Tensor,
+ batch: torch.Tensor,
+ unbiased: bool = True,
+ ):
+ """Compute "per-atom" statistics that are normalized by the number of atoms in the system.
+
+ Only makes sense for a graph-level quantity (checked by .statistics).
+ """
+ # using unique_consecutive handles the non-contiguous selected batch index
+ _, N = torch.unique_consecutive(batch, return_counts=True)
+ N = N.unsqueeze(-1)
+ assert N.ndim == 2
+ assert N.shape == (len(arr), 1)
+ assert arr.ndim >= 2
+ data_dim = arr.shape[1:]
+ arr = arr / N
+ assert arr.shape == (len(N),) + data_dim
+ if ana_mode == "mean_std":
+ if len(arr) < 2:
+ raise ValueError(
+ "Can't do standard deviation without at least two samples"
+ )
+ mean = torch.mean(arr, dim=0)
+ std = torch.std(arr, unbiased=unbiased, dim=0)
+ return mean, std
+ elif ana_mode == "rms":
+ return (torch.sqrt(torch.mean(arr.square())),)
+ elif ana_mode == "absmax":
+ return (torch.max(arr.abs()),)
+ else:
+ raise NotImplementedError(
+ f"{ana_mode} for per-atom analysis is not implemented"
+ )
+
+ def _per_species_statistics(
+ self,
+ ana_mode: str,
+ arr: torch.Tensor,
+ arr_is_per: str,
+ atom_types: torch.Tensor,
+ batch: torch.Tensor,
+ unbiased: bool = True,
+ algorithm_kwargs: Optional[dict] = {},
+ ):
+ """Compute "per-species" statistics.
+
+ For a graph-level quantity, models it as a linear combintation of the number of atoms of different types in the graph.
+
+ For a per-node quantity, computes the expected statistic but for each type instead of over all nodes.
+ """
+ N = bincount(atom_types.squeeze(-1), batch)
+ assert N.ndim == 2 # [batch, n_type]
+ N = N[(N > 0).any(dim=1)] # deal with non-contiguous batch indexes
+ assert arr.ndim >= 2
+ if arr_is_per == "graph":
+
+ if ana_mode != "mean_std":
+ raise NotImplementedError(
+ f"{ana_mode} for per species analysis is not implemented for shape {arr.shape}"
+ )
+
+ N = N.type(self.dtype)
+
+ return solver(N, arr, **algorithm_kwargs)
+
+ elif arr_is_per == "node":
+ arr = arr.type(self.dtype)
+
+ if ana_mode == "mean_std":
+ # There need to be at least two occurances of each atom type in the
+ # WHOLE dataset, not in any given frame:
+ if torch.any(N.sum(dim=0) < 2):
+ raise ValueError(
+ "Can't do per species standard deviation without at least two samples per species"
+ )
+ mean = scatter_mean(arr, atom_types, dim=0)
+ assert mean.shape[1:] == arr.shape[1:] # [N, dims] -> [type, dims]
+ assert len(mean) == N.shape[1]
+ std = scatter_std(arr, atom_types, dim=0, unbiased=unbiased)
+ assert std.shape == mean.shape
+ return mean, std
+ elif ana_mode == "rms":
+ square = scatter_mean(arr.square(), atom_types, dim=0)
+ assert square.shape[1:] == arr.shape[1:] # [N, dims] -> [type, dims]
+ assert len(square) == N.shape[1]
+ dims = len(square.shape) - 1
+ for i in range(dims):
+ square = square.mean(axis=-1)
+ return (torch.sqrt(square),)
+ else:
+ raise NotImplementedError(
+ f"Statistics mode {ana_mode} isn't yet implemented for per_species_"
+ )
+
+ else:
+ raise NotImplementedError
+
+ def rdf(
+ self, bin_width: float, stride: int = 1
+ ) -> Dict[Tuple[int, int], Tuple[np.ndarray, np.ndarray]]:
+ """Compute the pairwise RDFs of the dataset.
+
+ Args:
+ bin_width: width of the histogram bin in distance units
+ stride: stride of data to include
+
+ Returns:
+ dictionary mapping `(type1, type2)` to tuples of `(hist, bin_edges)` in the style of `np.histogram`.
+ """
+ graph_selector, node_selector, edge_selector = self._selectors(stride=stride)
+
+ data = AtomicData.to_AtomicDataDict(self.data)
+ data = AtomicDataDict.with_edge_vectors(data, with_lengths=True)
+
+ results = {}
+
+ types = self.type_mapper(data)[AtomicDataDict.ATOM_TYPE_KEY]
+
+ edge_types = torch.index_select(
+ types, 0, data[AtomicDataDict.EDGE_INDEX_KEY].reshape(-1)
+ ).view(2, -1)
+ types_center = edge_types[0].numpy()
+ types_neigh = edge_types[1].numpy()
+
+ r_max: float = self.AtomicData_options["r_max"]
+ # + 1 to always have a zero bin at the end
+ n_bins: int = int(math.ceil(r_max / bin_width)) + 1
+ # +1 since these are bin_edges including rightmost
+ bins = bin_width * np.arange(n_bins + 1)
+
+ for type1, type2 in itertools.combinations_with_replacement(
+ range(self.type_mapper.num_types), 2
+ ):
+ # Try to do as much of this as possible in-place
+ mask = types_center == type1
+ np.logical_and(mask, types_neigh == type2, out=mask)
+ np.logical_and(mask, edge_selector, out=mask)
+ mask = mask.astype(np.int32)
+ results[(type1, type2)] = np.histogram(
+ data[AtomicDataDict.EDGE_LENGTH_KEY],
+ weights=mask,
+ bins=bins,
+ density=True,
+ )
+ # RDF is symmetric
+ results[(type2, type1)] = results[(type1, type2)]
+
+ return results
diff --git a/dptb/data/dataset/_deeph_dataset.py b/dptb/data/dataset/_deeph_dataset.py
new file mode 100644
index 00000000..2485ce1b
--- /dev/null
+++ b/dptb/data/dataset/_deeph_dataset.py
@@ -0,0 +1,79 @@
+from typing import Dict, Any, List, Callable, Union, Optional
+import os
+import numpy as np
+import h5py
+
+import torch
+
+from .. import (
+ AtomicData,
+ AtomicDataDict,
+)
+from ..transforms import TypeMapper, OrbitalMapper
+from ._base_datasets import AtomicDataset
+from dptb.nn.hamiltonian import E3Hamiltonian
+from dptb.data.interfaces.ham_to_feature import openmx_to_deeptb
+
+orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"}
+
+class DeePHE3Dataset(AtomicDataset):
+
+ def __init__(
+ self,
+ root: str,
+ key_mapping: Dict[str, str] = {
+ "pos": AtomicDataDict.POSITIONS_KEY,
+ "energy": AtomicDataDict.TOTAL_ENERGY_KEY,
+ "atomic_numbers": AtomicDataDict.ATOMIC_NUMBERS_KEY,
+ "kpoints": AtomicDataDict.KPOINT_KEY,
+ "eigenvalues": AtomicDataDict.ENERGY_EIGENVALUE_KEY,
+ },
+ preprocess_path: str = None,
+ subdir_names: Optional[str] = None,
+ AtomicData_options: Dict[str, Any] = {},
+ type_mapper: Optional[TypeMapper] = None,
+ ):
+ super().__init__(root=root, type_mapper=type_mapper)
+ self.key_mapping = key_mapping
+ self.key_list = list(key_mapping.keys())
+ self.value_list = list(key_mapping.values())
+ self.subdir_names = subdir_names
+ self.preprocess_path = preprocess_path
+
+ self.AtomicData_options = AtomicData_options
+ # self.r_max = AtomicData_options["r_max"]
+ # self.er_max = AtomicData_options["er_max"]
+ # self.oer_max = AtomicData_options["oer_max"]
+ # self.pbc = AtomicData_options["pbc"]
+
+ self.index = None
+ self.num_examples = len(subdir_names)
+
+ def get(self, idx):
+ file_name = self.subdir_names[idx]
+ file = os.path.join(self.preprocess_path, file_name)
+
+ if os.path.exists(os.path.join(file, "AtomicData.pth")):
+ atomic_data = torch.load(os.path.join(file, "AtomicData.pth"))
+ else:
+ atomic_data = AtomicData.from_points(
+ pos = np.loadtxt(os.path.join(file, "site_positions.dat")).T,
+ cell = np.loadtxt(os.path.join(file, "lat.dat")).T,
+ atomic_numbers = np.loadtxt(os.path.join(file, "element.dat")),
+ **self.AtomicData_options,
+ )
+
+ idp = self.type_mapper
+ # e3 = E3Hamiltonian(idp=idp, decompose=True)
+
+ openmx_to_deeptb(atomic_data, idp, os.path.join(file, "./hamiltonians.h5"))
+ # with torch.no_grad():
+ # atomic_data = e3(atomic_data.to_dict())
+ # atomic_data = AtomicData.from_dict(atomic_data)
+
+ torch.save(atomic_data, os.path.join(file, "AtomicData.pth"))
+
+ return atomic_data
+
+ def len(self) -> int:
+ return self.num_examples
\ No newline at end of file
diff --git a/dptb/data/dataset/_default_dataset.py b/dptb/data/dataset/_default_dataset.py
new file mode 100644
index 00000000..4a0d7f30
--- /dev/null
+++ b/dptb/data/dataset/_default_dataset.py
@@ -0,0 +1,425 @@
+from typing import Dict, Any, List, Callable, Union, Optional
+import os
+import glob
+
+import numpy as np
+import h5py
+from ase import Atoms
+from ase.io import Trajectory
+
+import torch
+
+from .. import (
+ AtomicData,
+ AtomicDataDict,
+)
+from ..transforms import TypeMapper, OrbitalMapper
+from ._base_datasets import AtomicDataset, AtomicInMemoryDataset
+#from dptb.nn.hamiltonian import E3Hamiltonian
+from dptb.data.interfaces.ham_to_feature import ham_block_to_feature
+from dptb.utils.tools import j_loader
+from dptb.data.AtomicDataDict import with_edge_vectors
+from dptb.nn.hamiltonian import E3Hamiltonian
+
+class _TrajData(object):
+ '''
+ Input files format in a trajectory (shape):
+ "info.json": optional, includes infomation in the data files.
+ can be provided in the base (upper level) folder, or assign in each trajectory.
+ "cell.dat": fixed cell (3, 3) or variable cells (nframes, 3, 3). Unit: Angstrom
+ "atomic_numbers.dat": (natoms) or (nframes, natoms)
+ "positions.dat": concentrate all positions in one file, (nframes * natoms, 3). Can be cart or frac.
+
+ Optional data files:
+ "eigenvalues.npy": concentrate all engenvalues in one file, (nframes, nkpoints, nbands)
+ "kpoints.npy": MUST be provided when loading `eigenvalues.npy`, (nkpoints, 3) or (nframes, nkpints, 3)
+ "hamiltonians.h5": h5 file storing atom-wise hamiltonian blocks labeled by frames id and `i0_jR_Rx_Ry_Rz`.
+ "overlaps.h5": the same format of overlap blocks as `hamiltonians.h5`
+ '''
+
+ def __init__(self,
+ root: str,
+ AtomicData_options: Dict[str, Any] = {},
+ get_Hamiltonian = False,
+ get_eigenvalues = False,
+ info = None,
+ _clear = False):
+ self.root = root
+ self.AtomicData_options = AtomicData_options
+ self.info = info
+
+ self.data = {}
+ # load cell
+ cell = np.loadtxt(os.path.join(root, "cell.dat"))
+ if cell.shape[0] == 3:
+ # same cell size, then copy it to all frames.
+ cell = np.expand_dims(cell, axis=0)
+ self.data["cell"] = np.broadcast_to(cell, (self.info["nframes"], 3, 3))
+ elif cell.shape[0] == self.info["nframes"] * 3:
+ self.data["cell"] = cell.reshape(self.info["nframes"], 3, 3)
+ else:
+ raise ValueError("Wrong cell dimensions.")
+
+ # load atomic numbers
+ atomic_numbers = np.loadtxt(os.path.join(root, "atomic_numbers.dat"))
+ if atomic_numbers.shape[0] == self.info["natoms"]:
+ # same atomic_numbers, copy it to all frames.
+ atomic_numbers = np.expand_dims(atomic_numbers, axis=0)
+ self.data["atomic_numbers"] = np.broadcast_to(atomic_numbers, (self.info["nframes"],
+ self.info["natoms"]))
+ elif atomic_numbers.shape[0] == self.info["natoms"] * self.info["nframes"]:
+ self.data["atomic_numbers"] = atomic_numbers.reshape(self.info["nframes"],
+ self.info["natoms"])
+ else:
+ raise ValueError("Wrong atomic_number dimensions.")
+
+ # load positions, stored as cartesion no matter what provided.
+ pos = np.loadtxt(os.path.join(root, "positions.dat"))
+ assert pos.shape[0] == self.info["nframes"] * self.info["natoms"]
+ pos = pos.reshape(self.info["nframes"], self.info["natoms"], 3)
+ # ase use cartesian by default.
+ if self.info["pos_type"] == "cart" or self.info["pos_type"] == "ase":
+ self.data["pos"] = pos
+ elif self.info["pos_type"] == "frac":
+ self.data["pos"] = pos @ self.data["cell"]
+ else:
+ raise NameError("Position type must be cart / frac.")
+
+ # load optional data files
+ if os.path.exists(os.path.join(self.root, "eigenvalues.npy")) and get_eigenvalues==True:
+ assert "bandinfo" in self.info, "`bandinfo` must be provided in `info.json` for loading eigenvalues."
+ assert os.path.exists(os.path.join(self.root, "kpoints.npy"))
+ kpoints = np.load(os.path.join(self.root, "kpoints.npy"))
+ if kpoints.ndim == 2:
+ # only one frame or same kpoints, then copy it to all frames.
+ # shape: (nkpoints, 3)
+ kpoints = np.expand_dims(kpoints, axis=0)
+ self.data["kpoints"] = np.broadcast_to(kpoints, (self.info["nframes"],
+ kpoints.shape[1], 3))
+ elif kpoints.ndim == 3 and kpoints.shape[0] == self.info["nframes"]:
+ # array of kpoints, (nframes, nkpoints, 3)
+ self.data["kpoints"] = kpoints
+ else:
+ raise ValueError("Wrong kpoint dimensions.")
+ eigenvalues = np.load(os.path.join(self.root, "eigenvalues.npy"))
+ # special case: trajectory contains only one frame
+ if eigenvalues.ndim == 2:
+ eigenvalues = np.expand_dims(eigenvalues, axis=0)
+ assert eigenvalues.shape[0] == self.info["nframes"]
+ assert eigenvalues.shape[1] == self.data["kpoints"].shape[1]
+ self.data["eigenvalues"] = eigenvalues
+ if os.path.exists(os.path.join(self.root, "hamiltonians.h5")) and get_Hamiltonian==True:
+ self.data["hamiltonian_blocks"] = h5py.File(os.path.join(self.root, "hamiltonians.h5"), "r")
+ if os.path.exists(os.path.join(self.root, "overlaps.h5")):
+ self.data["overlap_blocks"] = h5py.File(os.path.join(self.root, "overlaps.h5"), "r")
+
+ # this is used to clear the tmp files to load ase trajectory only.
+ if _clear:
+ os.remove(os.path.join(root, "positions.dat"))
+ os.remove(os.path.join(root, "cell.dat"))
+ os.remove(os.path.join(root, "atomic_numbers.dat"))
+
+ @classmethod
+ def from_ase_traj(cls,
+ root: str,
+ AtomicData_options: Dict[str, Any] = {},
+ get_Hamiltonian = False,
+ get_eigenvalues = False,
+ info = None):
+
+ traj_file = glob.glob(f"{root}/*.traj")
+ assert len(traj_file) == 1, print("only one ase trajectory file can be provided.")
+ traj = Trajectory(traj_file[0], 'r')
+ positions = []
+ cell = []
+ atomic_numbers = []
+ for atoms in traj:
+ positions.append(atoms.get_positions())
+ cell.append(atoms.get_cell())
+ atomic_numbers.append(atoms.get_atomic_numbers())
+ positions = np.array(positions)
+ positions = positions.reshape(-1, 3)
+ cell = np.array(cell)
+ cell = cell.reshape(-1, 3)
+ atomic_numbers = np.array(atomic_numbers)
+ atomic_numbers = atomic_numbers.reshape(-1, 1)
+ np.savetxt(os.path.join(root, "positions.dat"), positions)
+ np.savetxt(os.path.join(root, "cell.dat"), cell)
+ np.savetxt(os.path.join(root, "atomic_numbers.dat"), atomic_numbers, fmt='%d')
+
+ return cls(root=root,
+ AtomicData_options=AtomicData_options,
+ get_Hamiltonian=get_Hamiltonian,
+ get_eigenvalues=get_eigenvalues,
+ info=info,
+ _clear=True)
+
+ def toAtomicDataList(self, idp: TypeMapper = None):
+ data_list = []
+ for frame in range(self.info["nframes"]):
+ atomic_data = AtomicData.from_points(
+ pos = self.data["pos"][frame][:],
+ cell = self.data["cell"][frame][:],
+ atomic_numbers = self.data["atomic_numbers"][frame],
+ # pbc is stored in AtomicData_options now.
+ #pbc = self.info["pbc"],
+ **self.AtomicData_options)
+ if "hamiltonian_blocks" in self.data:
+ assert idp is not None, "LCAO Basis must be provided in `common_option` for loading Hamiltonian."
+ if "overlap_blocks" not in self.data:
+ self.data["overlap_blocks"] = [False] * self.info["nframes"]
+ # e3 = E3Hamiltonian(idp=idp, decompose=True)
+ ham_block_to_feature(atomic_data, idp,
+ self.data["hamiltonian_blocks"][str(frame+1)],
+ self.data["overlap_blocks"][str(frame+1)])
+
+ # TODO: initialize the edge and node feature tempretely, there should be a better way.
+ else:
+ # just temporarily initialize the edge and node feature to zeros, to let the batch collate work.
+ atomic_data[AtomicDataDict.EDGE_FEATURES_KEY] = torch.zeros(atomic_data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], 1)
+ atomic_data[AtomicDataDict.NODE_FEATURES_KEY] = torch.zeros(atomic_data[AtomicDataDict.POSITIONS_KEY].shape[0], 1)
+ atomic_data[AtomicDataDict.EDGE_OVERLAP_KEY] = torch.zeros(atomic_data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], 1)
+ # with torch.no_grad():
+ # atomic_data = e3(atomic_data.to_dict())
+ # atomic_data = AtomicData.from_dict(atomic_data)
+ if "eigenvalues" in self.data and "kpoints" in self.data:
+ assert "bandinfo" in self.info, "`bandinfo` must be provided in `info.json` for loading eigenvalues."
+ bandinfo = self.info["bandinfo"]
+ atomic_data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(self.data["kpoints"][frame][:],
+ dtype=torch.get_default_dtype())
+ if bandinfo["emin"] is not None and bandinfo["emax"] is not None:
+ atomic_data[AtomicDataDict.ENERGY_WINDOWS_KEY] = torch.as_tensor([bandinfo["emin"], bandinfo["emax"]],
+ dtype=torch.get_default_dtype())
+ if bandinfo["band_min"] is not None and bandinfo["band_max"] is not None:
+ atomic_data[AtomicDataDict.BAND_WINDOW_KEY] = torch.as_tensor([bandinfo["band_min"], bandinfo["band_max"]],
+ dtype=torch.long)
+ # atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(self.data["eigenvalues"][frame][:, bandinfo["band_min"]:bandinfo["band_max"]],
+ # dtype=torch.get_default_dtype())
+ atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(self.data["eigenvalues"][frame],
+ dtype=torch.get_default_dtype())
+ data_list.append(atomic_data)
+ return data_list
+
+
+class DefaultDataset(AtomicInMemoryDataset):
+
+ def __init__(
+ self,
+ root: str,
+ info_files: Dict[str, Dict],
+ url: Optional[str] = None, # seems useless but can't be remove
+ include_frames: Optional[List[int]] = None, # maybe support in future
+ type_mapper: TypeMapper = None,
+ get_Hamiltonian: bool = False,
+ get_eigenvalues: bool = False,
+ ):
+ self.root = root
+ self.url = url
+ self.info_files = info_files
+ # The following flags are stored to label dataset.
+ self.get_Hamiltonian = get_Hamiltonian
+ self.get_eigenvalues = get_eigenvalues
+
+ # load all data files
+ self.raw_data = []
+ for file in self.info_files.keys():
+ # get the info here
+ info = info_files[file]
+ assert "AtomicData_options" in info
+ AtomicData_options = info["AtomicData_options"]
+ assert "r_max" in AtomicData_options
+ assert "pbc" in AtomicData_options
+ if info["pos_type"] == "ase":
+ subdata = _TrajData.from_ase_traj(os.path.join(self.root, file),
+ AtomicData_options,
+ get_Hamiltonian,
+ get_eigenvalues,
+ info=info)
+ else:
+ subdata = _TrajData(os.path.join(self.root, file),
+ AtomicData_options,
+ get_Hamiltonian,
+ get_eigenvalues,
+ info=info)
+ self.raw_data.append(subdata)
+
+ # The AtomicData_options is never used here.
+ # Because we always return a list of AtomicData object in `get_data()`.
+ # That is, AtomicInMemoryDataset will not use AtomicData_options to build any AtomicData here.
+ super().__init__(
+ file_name=None, # this seems not important too.
+ url=url,
+ root=root,
+ AtomicData_options={}, # we do not pass anything here.
+ include_frames=include_frames,
+ type_mapper=type_mapper,
+ )
+
+ def get_data(self):
+ all_data = []
+ for subdata in self.raw_data:
+ # the type_mapper here is loaded in PyG `dataset` type as `transform` attritube
+ # so the OrbitalMapper can be accessed by self.transform here
+ subdata_list = subdata.toAtomicDataList(self.transform)
+ all_data += subdata_list
+ return all_data
+
+ @property
+ def raw_file_names(self):
+ # TODO: this is not implemented.
+ return "Null"
+
+ @property
+ def raw_dir(self):
+ # TODO: this is not implemented.
+ return self.root
+
+ def E3statistics(self, decay=False):
+ assert self.transform is not None
+ idp = self.transform
+
+ if self.data[AtomicDataDict.EDGE_FEATURES_KEY].abs().sum() < 1e-7:
+ return None
+
+ typed_dataset = idp(self.data.clone().to_dict())
+ e3h = E3Hamiltonian(basis=idp.basis, decompose=True)
+ with torch.no_grad():
+ typed_dataset = e3h(typed_dataset)
+
+ stats = {}
+ stats["node"] = self._E3nodespecies_stat(typed_dataset=typed_dataset)
+ stats["edge"] = self._E3edgespecies_stat(typed_dataset=typed_dataset, decay=decay)
+
+ return stats
+
+ def _E3edgespecies_stat(self, typed_dataset, decay):
+ # we get the bond type marked dataset first
+ idp = self.transform
+ typed_dataset = typed_dataset
+
+ idp.get_irreps(no_parity=False)
+ irrep_slices = idp.orbpair_irreps.slices()
+
+ features = typed_dataset["edge_features"]
+ hopping_block_mask = idp.mask_to_erme[typed_dataset["edge_type"].flatten()]
+ typed_hopping = {}
+ for bt, tp in idp.bond_to_type.items():
+ hopping_tp_mask = hopping_block_mask[typed_dataset["edge_type"].flatten().eq(tp)]
+ hopping_tp = features[typed_dataset["edge_type"].flatten().eq(tp)]
+ filtered_vecs = torch.where(hopping_tp_mask, hopping_tp, torch.tensor(float('nan')))
+ typed_hopping[bt] = filtered_vecs
+
+ sorted_irreps = idp.orbpair_irreps.sort()[0].simplify()
+ n_scalar = sorted_irreps[0].mul if sorted_irreps[0].ir.l == 0 else 0
+
+ # calculate norm & mean
+ typed_norm = {}
+ typed_norm_ave = torch.ones(len(idp.bond_to_type), idp.orbpair_irreps.num_irreps)
+ typed_norm_std = torch.zeros(len(idp.bond_to_type), idp.orbpair_irreps.num_irreps)
+ typed_scalar_ave = torch.ones(len(idp.bond_to_type), n_scalar)
+ typed_scalar_std = torch.zeros(len(idp.bond_to_type), n_scalar)
+ for bt, tp in idp.bond_to_type.items():
+ norms_per_irrep = []
+ count_scalar = 0
+ for ir, s in enumerate(irrep_slices):
+ sub_tensor = typed_hopping[bt][:, s]
+ # dump the nan blocks here
+ if sub_tensor.shape[-1] == 1:
+ count_scalar += 1
+ if not torch.isnan(sub_tensor).all():
+ # update the mean and ave
+ norms = torch.norm(sub_tensor, p=2, dim=1) # shape: [n_edge]
+ if sub_tensor.shape[-1] == 1:
+ # it's a scalar
+ typed_scalar_ave[tp][count_scalar-1] = sub_tensor.mean()
+ typed_scalar_std[tp][count_scalar-1] = sub_tensor.std()
+ typed_norm_ave[tp][ir] = norms.mean()
+ typed_norm_std[tp][ir] = norms.std()
+ else:
+ norms = torch.ones_like(sub_tensor[:, 0])
+
+ if decay:
+ norms_per_irrep.append(norms)
+
+ assert count_scalar <= n_scalar
+ # shape of typed_norm: (n_irreps, n_edges)
+
+ if decay:
+ typed_norm[bt] = torch.stack(norms_per_irrep)
+
+ edge_stats = {
+ "norm_ave": typed_norm_ave,
+ "norm_std": typed_norm_std,
+ "scalar_ave": typed_scalar_ave,
+ "scalar_std": typed_scalar_std,
+ }
+
+ if decay:
+ typed_dataset = with_edge_vectors(typed_dataset)
+ decay = {}
+ for bt, tp in idp.bond_to_type.items():
+ decay_bt = {}
+ lengths_bt = typed_dataset["edge_lengths"][typed_dataset["edge_type"].flatten().eq(tp)]
+ sorted_lengths, indices = lengths_bt.sort() # from small to large
+ # sort the norms by irrep l
+ sorted_norms = typed_norm[bt][idp.orbpair_irreps.sort().inv, :]
+ # sort the norms by edge length
+ sorted_norms = sorted_norms[:, indices]
+ decay_bt["edge_length"] = sorted_lengths
+ decay_bt["norm_decay"] = sorted_norms
+ decay[bt] = decay_bt
+
+ edge_stats["decay"] = decay
+
+ return edge_stats
+
+ def _E3nodespecies_stat(self, typed_dataset):
+ # we get the type marked dataset first
+ idp = self.transform
+ typed_dataset = typed_dataset
+
+ idp.get_irreps(no_parity=False)
+ irrep_slices = idp.orbpair_irreps.slices()
+
+ sorted_irreps = idp.orbpair_irreps.sort()[0].simplify()
+ n_scalar = sorted_irreps[0].mul if sorted_irreps[0].ir.l == 0 else 0
+
+ features = typed_dataset["node_features"]
+ onsite_block_mask = idp.mask_to_nrme[typed_dataset["atom_types"].flatten()]
+ typed_onsite = {}
+ for at, tp in idp.chemical_symbol_to_type.items():
+ onsite_tp_mask = onsite_block_mask[typed_dataset["atom_types"].flatten().eq(tp)]
+ onsite_tp = features[typed_dataset["atom_types"].flatten().eq(tp)]
+ filtered_vecs = torch.where(onsite_tp_mask, onsite_tp, torch.tensor(float('nan')))
+ typed_onsite[at] = filtered_vecs
+
+ # calculate norm & mean
+ typed_norm_ave = torch.ones(len(idp.chemical_symbol_to_type), idp.orbpair_irreps.num_irreps)
+ typed_norm_std = torch.zeros(len(idp.chemical_symbol_to_type), idp.orbpair_irreps.num_irreps)
+ typed_scalar_ave = torch.ones(len(idp.chemical_symbol_to_type), n_scalar)
+ typed_scalar_std = torch.zeros(len(idp.chemical_symbol_to_type), n_scalar)
+ for at, tp in idp.chemical_symbol_to_type.items():
+ count_scalar = 0
+ for ir, s in enumerate(irrep_slices):
+ sub_tensor = typed_onsite[at][:, s]
+ # dump the nan blocks here
+ if sub_tensor.shape[-1] == 1:
+ count_scalar += 1
+ if not torch.isnan(sub_tensor).all():
+
+ norms = torch.norm(sub_tensor, p=2, dim=1)
+ typed_norm_ave[tp][ir] = norms.mean()
+ typed_norm_std[tp][ir] = norms.std()
+ if s.stop - s.start == 1:
+ # it's a scalar
+ typed_scalar_ave[tp][count_scalar-1] = sub_tensor.mean()
+ typed_scalar_std[tp][count_scalar-1] = sub_tensor.std()
+
+ edge_stats = {
+ "norm_ave": typed_norm_ave,
+ "norm_std": typed_norm_std,
+ "scalar_ave": typed_scalar_ave,
+ "scalar_std": typed_scalar_std,
+ }
+
+ return edge_stats
\ No newline at end of file
diff --git a/dptb/data/dataset/_hdf5_dataset.py b/dptb/data/dataset/_hdf5_dataset.py
new file mode 100644
index 00000000..5fce39e2
--- /dev/null
+++ b/dptb/data/dataset/_hdf5_dataset.py
@@ -0,0 +1,171 @@
+from typing import Dict, Any, List, Callable, Union, Optional
+from collections import defaultdict
+import numpy as np
+
+import torch
+
+from .. import (
+ AtomicData,
+ AtomicDataDict,
+)
+from ..transforms import TypeMapper
+from ._base_datasets import AtomicDataset
+
+
+class HDF5Dataset(AtomicDataset):
+ """A dataset that loads data from a HDF5 file.
+
+ This class is useful for very large datasets that cannot fit in memory. It
+ efficiently loads data from disk as needed without everything needing to be
+ in memory at once.
+
+ To use this, ``file_name`` should point to the HDF5 file, or alternatively a
+ semicolon separated list of multiple files. Each group in the file contains
+ samples that all have the same number of atoms. Typically there is one
+ group for each unique number of atoms, but that is not required. Each group
+ should contain arrays whose length equals the number of samples, one for each
+ type of data. The names of the arrays can be specified with ``key_mapping``.
+
+ Args:
+ key_mapping (Dict[str, str]): mapping of array names in the HDF5 file to ``AtomicData`` keys
+ file_name (string): a semicolon separated list of HDF5 files.
+ """
+
+ def __init__(
+ self,
+ root: str,
+ key_mapping: Dict[str, str] = {
+ "pos": AtomicDataDict.POSITIONS_KEY,
+ "energy": AtomicDataDict.TOTAL_ENERGY_KEY,
+ "forces": AtomicDataDict.FORCE_KEY,
+ "atomic_numbers": AtomicDataDict.ATOMIC_NUMBERS_KEY,
+ "types": AtomicDataDict.ATOM_TYPE_KEY,
+ },
+ file_name: Optional[str] = None,
+ AtomicData_options: Dict[str, Any] = {},
+ type_mapper: Optional[TypeMapper] = None,
+ ):
+ super().__init__(root=root, type_mapper=type_mapper)
+ self.key_mapping = key_mapping
+ self.key_list = list(key_mapping.keys())
+ self.value_list = list(key_mapping.values())
+ self.file_name = file_name
+ self.r_max = AtomicData_options["r_max"]
+ self.index = None
+ self.num_frames = 0
+ import h5py
+
+ files = [h5py.File(f, "r") for f in self.file_name.split(";")]
+ for file in files:
+ for group_name in file:
+ for key in self.key_list:
+ if key in file[group_name]:
+ self.num_frames += len(file[group_name][key])
+ break
+ file.close()
+
+ def setup_index(self):
+ import h5py
+
+ files = [h5py.File(f, "r") for f in self.file_name.split(";")]
+ self.has_forces = False
+ self.index = []
+ for file in files:
+ for group_name in file:
+ group = file[group_name]
+ values = [None] * len(self.key_list)
+ samples = 0
+ for i, key in enumerate(self.key_list):
+ if key in group:
+ values[i] = group[key]
+ samples = len(values[i])
+ for i in range(samples):
+ self.index.append(tuple(values + [i]))
+
+ def len(self) -> int:
+ return self.num_frames
+
+ def get(self, idx: int) -> AtomicData:
+ if self.index is None:
+ self.setup_index()
+ data = self.index[idx]
+ i = data[-1]
+ args = {"r_max": self.r_max}
+ for j, value in enumerate(self.value_list):
+ if data[j] is not None:
+ args[value] = data[j][i]
+ return AtomicData.from_points(**args)
+
+ def statistics(
+ self,
+ fields: List[Union[str, Callable]],
+ modes: List[str],
+ stride: int = 1,
+ unbiased: bool = True,
+ kwargs: Optional[Dict[str, dict]] = {},
+ ) -> List[tuple]:
+ assert len(modes) == len(fields)
+ # TODO: use RunningStats
+ if len(fields) == 0:
+ return []
+ if self.index is None:
+ self.setup_index()
+ results = []
+ indices = self.indices()
+ if stride != 1:
+ indices = list(indices)[::stride]
+ for field, mode in zip(fields, modes):
+ count = 0
+ if mode == "rms":
+ total = 0.0
+ elif mode in ("mean_std", "per_atom_mean_std"):
+ total = [0.0, 0.0]
+ elif mode == "count":
+ counts = defaultdict(int)
+ else:
+ raise NotImplementedError(f"Analysis mode '{mode}' is not implemented")
+ for index in indices:
+ data = self.index[index]
+ i = data[-1]
+ if field in self.value_list:
+ values = data[self.value_list.index(field)][i]
+ elif callable(field):
+ values, _ = field(self.get(index))
+ values = np.asarray(values)
+ else:
+ raise RuntimeError(
+ f"The field key `{field}` is not present in this dataset"
+ )
+ length = len(values.flatten())
+ if length == 1:
+ values = np.array([values])
+ if mode == "rms":
+ total += np.sum(values * values)
+ count += length
+ elif mode == "count":
+ for v in values:
+ counts[v] += 1
+ else:
+ if mode == "per_atom_mean_std":
+ values /= len(data[0][i])
+ for v in values:
+ count += 1
+ delta1 = v - total[0]
+ total[0] += delta1 / count
+ delta2 = v - total[0]
+ total[1] += delta1 * delta2
+ if mode == "rms":
+ results.append(torch.tensor((np.sqrt(total / count),)))
+ elif mode == "count":
+ values = sorted(counts.keys())
+ results.append(
+ (torch.tensor(values), torch.tensor([counts[v] for v in values]))
+ )
+ else:
+ results.append(
+ (
+ torch.tensor(total[0]),
+ torch.tensor(np.sqrt(total[1] / (count - 1))),
+ )
+ )
+ return results
diff --git a/dptb/data/dataset/_npz_dataset.py b/dptb/data/dataset/_npz_dataset.py
new file mode 100644
index 00000000..84080604
--- /dev/null
+++ b/dptb/data/dataset/_npz_dataset.py
@@ -0,0 +1,143 @@
+import numpy as np
+from os.path import dirname, basename, abspath
+from typing import Dict, Any, List, Optional
+
+
+from .. import AtomicDataDict, _LONG_FIELDS, _NODE_FIELDS, _GRAPH_FIELDS
+from ..transforms import TypeMapper
+from ._base_datasets import AtomicInMemoryDataset
+
+
+class NpzDataset(AtomicInMemoryDataset):
+ """Load data from an npz file.
+
+ To avoid loading unneeded data, keys are ignored by default unless they are in ``key_mapping``, ``include_keys``,
+ or ``npz_fixed_fields_keys``.
+
+ Args:
+ key_mapping (Dict[str, str]): mapping of npz keys to ``AtomicData`` keys. Optional
+ include_keys (list): the attributes to be processed and stored. Optional
+ npz_fixed_field_keys: the attributes that only have one instance but apply to all frames. Optional
+ Note that the mapped keys (as determined by the _values_ in ``key_mapping``) should be used in
+ ``npz_fixed_field_keys``, not the original npz keys from before mapping. If an npz key is not
+ present in ``key_mapping``, it is mapped to itself, and this point is not relevant.
+
+ Example: Given a npz file with 10 configurations, each with 14 atoms.
+
+ position: (10, 14, 3)
+ force: (10, 14, 3)
+ energy: (10,)
+ Z: (14)
+ user_label1: (10) # per config
+ user_label2: (10, 14, 3) # per atom
+
+ The input yaml should be
+
+ ```yaml
+ dataset: npz
+ dataset_file_name: example.npz
+ include_keys:
+ - user_label1
+ - user_label2
+ npz_fixed_field_keys:
+ - cell
+ - atomic_numbers
+ key_mapping:
+ position: pos
+ force: forces
+ energy: total_energy
+ Z: atomic_numbers
+ graph_fields:
+ - user_label1
+ node_fields:
+ - user_label2
+ ```
+
+ """
+
+ def __init__(
+ self,
+ root: str,
+ key_mapping: Dict[str, str] = {
+ "positions": AtomicDataDict.POSITIONS_KEY,
+ "energy": AtomicDataDict.TOTAL_ENERGY_KEY,
+ "force": AtomicDataDict.FORCE_KEY,
+ "forces": AtomicDataDict.FORCE_KEY,
+ "Z": AtomicDataDict.ATOMIC_NUMBERS_KEY,
+ "atomic_number": AtomicDataDict.ATOMIC_NUMBERS_KEY,
+ },
+ include_keys: List[str] = [],
+ npz_fixed_field_keys: List[str] = [],
+ file_name: Optional[str] = None,
+ url: Optional[str] = None,
+ AtomicData_options: Dict[str, Any] = {},
+ include_frames: Optional[List[int]] = None,
+ type_mapper: TypeMapper = None,
+ ):
+ self.key_mapping = key_mapping
+ self.npz_fixed_field_keys = npz_fixed_field_keys
+ self.include_keys = include_keys
+
+ super().__init__(
+ file_name=file_name,
+ url=url,
+ root=root,
+ AtomicData_options=AtomicData_options,
+ include_frames=include_frames,
+ type_mapper=type_mapper,
+ )
+
+ @property
+ def raw_file_names(self):
+ return [basename(self.file_name)]
+
+ @property
+ def raw_dir(self):
+ return dirname(abspath(self.file_name))
+
+ def get_data(self):
+ # get data returns either a list of AtomicData class or a data dict
+
+ data = np.load(self.raw_dir + "/" + self.raw_file_names[0], allow_pickle=True)
+
+ # only the keys explicitly mentioned in the yaml file will be parsed
+ keys = set(list(self.key_mapping.keys()))
+
+ keys.update(self.npz_fixed_field_keys)
+ keys.update(self.include_keys)
+ keys = keys.intersection(set(list(data.keys())))
+
+ mapped = {self.key_mapping.get(k, k): data[k] for k in keys}
+
+ for intkey in _LONG_FIELDS:
+ if intkey in mapped:
+ mapped[intkey] = mapped[intkey].astype(np.int64)
+
+ fields = {k: v for k, v in mapped.items() if k not in self.npz_fixed_field_keys}
+ num_examples, num_atoms, n_dim = fields[AtomicDataDict.POSITIONS_KEY].shape
+ assert n_dim == 3
+
+ # now we replicate and add the fixed fields:
+ for fixed_field in self.npz_fixed_field_keys:
+ orig = mapped[fixed_field]
+ if fixed_field in _NODE_FIELDS:
+ assert orig.ndim >= 1 # [n_atom, feature_dims]
+ assert orig.shape[0] == num_atoms
+ replicated = np.expand_dims(orig, 0)
+ replicated = np.tile(
+ replicated,
+ (num_examples,) + (1,) * len(replicated.shape[1:]),
+ ) # [n_example, n_atom, feature_dims]
+ elif fixed_field in _GRAPH_FIELDS:
+ # orig is [feature_dims]
+ replicated = np.expand_dims(orig, 0)
+ replicated = np.tile(
+ replicated,
+ (num_examples,) + (1,) * len(replicated.shape[1:]),
+ ) # [n_example, feature_dims]
+ else:
+ raise KeyError(
+ f"npz_fixed_field_keys contains `{fixed_field}`, but it isn't registered as a node or graph field"
+ )
+ fields[fixed_field] = replicated
+ return fields
diff --git a/dptb/nnet/__init__.py b/dptb/data/interfaces/__init__.py
similarity index 100%
rename from dptb/nnet/__init__.py
rename to dptb/data/interfaces/__init__.py
diff --git a/dptb/data/interfaces/abacus.py b/dptb/data/interfaces/abacus.py
new file mode 100644
index 00000000..6123ee5d
--- /dev/null
+++ b/dptb/data/interfaces/abacus.py
@@ -0,0 +1,378 @@
+# Modified from script 'abasus_get_data.py' for interface from ABACUS to DeepH-pack
+# To use this script, please add 'out_mat_hs2 1' in ABACUS INPUT File
+# Current version is capable of coping with f-orbitals
+
+import os
+import glob
+import json
+import re
+from collections import Counter
+from tqdm import tqdm
+
+import numpy as np
+from scipy.sparse import csr_matrix
+from scipy.linalg import block_diag
+import h5py
+import ase
+
+orbitalId = {0:'s',1:'p',2:'d',3:'f'}
+Bohr2Ang = 0.529177249
+
+
+class OrbAbacus2DeepTB:
+ def __init__(self):
+ self.Us_abacus2deeptb = {}
+ self.Us_abacus2deeptb[0] = np.eye(1)
+ self.Us_abacus2deeptb[1] = np.eye(3)[[2, 0, 1]] # 0, 1, -1 -> -1, 0, 1
+ self.Us_abacus2deeptb[2] = np.eye(5)[[4, 2, 0, 1, 3]] # 0, 1, -1, 2, -2 -> -2, -1, 0, 1, 2
+ self.Us_abacus2deeptb[3] = np.eye(7)[[6, 4, 2, 0, 1, 3, 5]] # -3,-2,-1,0,1,2,3
+
+ minus_dict = {
+ 1: [0, 2],
+ 2: [1, 3],
+ 3: [0, 2, 4, 6],
+ }
+
+ for k, v in minus_dict.items():
+ self.Us_abacus2deeptb[k][v] *= -1 # add phase (-1)^m
+
+ def get_U(self, l):
+ if l > 3:
+ raise NotImplementedError("Only support l = s, p, d, f")
+ return self.Us_abacus2deeptb[l]
+
+ def transform(self, mat, l_lefts, l_rights):
+ block_lefts = block_diag(*[self.get_U(l_left) for l_left in l_lefts])
+ block_rights = block_diag(*[self.get_U(l_right) for l_right in l_rights])
+ return block_lefts @ mat @ block_rights.T
+
+def recursive_parse(input_path,
+ preprocess_dir,
+ data_name="OUT.ABACUS",
+ only_overlap=False,
+ parse_Hamiltonian=False,
+ add_overlap=False,
+ parse_eigenvalues=False,
+ prefix="data"):
+ """
+ Parse ABACUS single point SCF calculation outputs.
+ Input:
+ `input_dir`: target dictionary(ies) containing "OUT.ABACUS" folder.
+ can be wildcard characters or a string list.
+ `preprocess_dir`: output dictionary of all processed data files.
+ `data_name`: output dictionary name of ABACUS, by default "OUT.ABACUS".
+ `only_overlap`: usually `False`.
+ set to `True` if the calculation job is getting overlap matrix ONLY.
+ `parse_Hamiltonian`: determine whether parsing the Hamiltonian `.csr` file or not.
+ `add_overlap`: determine whether parsing the overlap `.csr` file or not.
+ `parse_Hamiltonian` must be true to add overlap.
+ `parse_eigenvalues`: determine whether parsing `kpoints.dat` and `BAND_1.dat` or not.
+ that is, the k-points will always be loaded with bands.
+ `prefix`: prefix of the processed data folders' names.
+ """
+ if isinstance(input_path, list) and all(isinstance(item, str) for item in input_path):
+ input_path = input_path
+ else:
+ input_path = glob.glob(input_path)
+ preprocess_dir = os.path.abspath(preprocess_dir)
+ os.makedirs(preprocess_dir, exist_ok=True)
+ # h5file_names = []
+
+ folders = [item for item in input_path if os.path.isdir(item)]
+
+ with tqdm(total=len(folders)) as pbar:
+ for index, folder in enumerate(folders):
+ datafiles = os.listdir(folder)
+ if data_name in datafiles:
+ # The follwing `if` block is used by us only.
+ if os.path.exists(os.path.join(folder, data_name, "hscsr.tgz")):
+ os.system("cd "+os.path.join(folder, data_name) + " && tar -zxvf hscsr.tgz && mv OUT.ABACUS/* ./")
+ try:
+ _abacus_parse(folder,
+ os.path.join(preprocess_dir, f"{prefix}.{index}"),
+ data_name,
+ only_S=only_overlap,
+ get_Ham=parse_Hamiltonian,
+ add_overlap=add_overlap,
+ get_eigenvalues=parse_eigenvalues)
+ #h5file_names.append(os.path.join(file, "AtomicData.h5"))
+ pbar.update(1)
+ except Exception as e:
+ print(f"Error in {folder}/{data_name}: {e}")
+ continue
+ #return h5file_names
+
+def _abacus_parse(input_path,
+ output_path,
+ data_name,
+ only_S=False,
+ get_Ham=False,
+ add_overlap=False,
+ get_eigenvalues=False):
+
+ input_path = os.path.abspath(input_path)
+ output_path = os.path.abspath(output_path)
+ os.makedirs(output_path, exist_ok=True)
+
+ def find_target_line(f, target):
+ line = f.readline()
+ while line:
+ if target in line:
+ return line
+ line = f.readline()
+ return None
+ if only_S:
+ log_file_name = "running_get_S.log"
+ else:
+ log_file_name = "running_scf.log"
+
+ with open(os.path.join(input_path, data_name, log_file_name), 'r') as f_chk:
+ lines = f_chk.readlines()
+ if not lines or " Total Time :" not in lines[-1]:
+ raise ValueError(f"Job is not normal ending!")
+
+ with open(os.path.join(input_path, data_name, log_file_name), 'r') as f:
+ f.readline()
+ line = f.readline()
+ # assert "WELCOME TO ABACUS" in line
+ assert find_target_line(f, "READING UNITCELL INFORMATION") is not None, 'Cannot find "READING UNITCELL INFORMATION" in log file'
+ num_atom_type = int(f.readline().split()[-1])
+
+ assert find_target_line(f, "lattice constant (Bohr)") is not None
+ lattice_constant = float(f.readline().split()[-1]) # unit is Angstrom, didn't read (Bohr) here.
+
+ site_norbits_dict = {}
+ orbital_types_dict = {}
+ for index_type in range(num_atom_type):
+ tmp = find_target_line(f, "READING ATOM TYPE")
+ assert tmp is not None, 'Cannot find "ATOM TYPE" in log file'
+ assert tmp.split()[-1] == str(index_type + 1)
+ if tmp is None:
+ raise Exception(f"Cannot find ATOM {index_type} in {log_file_name}")
+
+ line = f.readline()
+ assert "atom label =" in line
+ atom_label = line.split()[-1]
+ assert atom_label in ase.data.atomic_numbers, "Atom label should be in periodic table"
+ atom_type = ase.data.atomic_numbers[atom_label]
+
+ current_site_norbits = 0
+ current_orbital_types = []
+ while True:
+ line = f.readline()
+ if "number of zeta" in line:
+ tmp = line.split()
+ L = int(tmp[0][2:-1])
+ num_L = int(tmp[-1])
+ current_site_norbits += (2 * L + 1) * num_L
+ current_orbital_types.extend([L] * num_L)
+ else:
+ break
+ site_norbits_dict[atom_type] = current_site_norbits
+ orbital_types_dict[atom_type] = current_orbital_types
+
+ #print(orbital_types_dict)
+
+ line = find_target_line(f, "TOTAL ATOM NUMBER")
+ assert line is not None, 'Cannot find "TOTAL ATOM NUMBER" in log file'
+ nsites = int(line.split()[-1])
+
+ line = find_target_line(f, " COORDINATES")
+ assert line is not None, 'Cannot find "DIRECT COORDINATES" or "CARTESIAN COORDINATES" in log file'
+ if "DIRECT" in line:
+ coords_type = "direct"
+ elif "CARTESIAN" in line:
+ coords_type = "cartesian"
+ else:
+ raise ValueError('Cannot find "DIRECT COORDINATES" or "CARTESIAN COORDINATES" in log file')
+
+ assert "atom" in f.readline()
+ frac_coords = np.zeros((nsites, 3))
+ site_norbits = np.zeros(nsites, dtype=int)
+ element = np.zeros(nsites, dtype=int)
+ for index_site in range(nsites):
+ line = f.readline()
+ tmp = line.split()
+ assert "tau" in tmp[0]
+ atom_label = ''.join(re.findall(r'[A-Za-z]', tmp[0][5:]))
+ assert atom_label in ase.data.atomic_numbers, "Atom label should be in periodic table"
+ element[index_site] = ase.data.atomic_numbers[atom_label]
+ site_norbits[index_site] = site_norbits_dict[element[index_site]]
+ frac_coords[index_site, :] = np.array(tmp[1:4])
+ norbits = int(np.sum(site_norbits))
+ site_norbits_cumsum = np.cumsum(site_norbits)
+
+ assert find_target_line(f, "Lattice vectors: (Cartesian coordinate: in unit of a_0)") is not None
+ lattice = np.zeros((3, 3))
+ for index_lat in range(3):
+ lattice[index_lat, :] = np.array(f.readline().split())
+ if coords_type == "cartesian":
+ frac_coords = frac_coords @ np.matrix(lattice).I # get frac_coords anyway
+ lattice = lattice * lattice_constant
+
+ if only_S:
+ spinful = False
+ else:
+ line = find_target_line(f, "NSPIN")
+ assert line is not None, 'Cannot find "NSPIN" in log file'
+ if "NSPIN == 1" in line:
+ spinful = False
+ elif "NSPIN == 4" in line:
+ spinful = True
+ else:
+ raise ValueError(f'{line} is not supported')
+
+ np.savetxt(os.path.join(output_path, "cell.dat"), lattice)
+ np.savetxt(os.path.join(output_path, "rcell.dat"), np.linalg.inv(lattice) * 2 * np.pi)
+ cart_coords = frac_coords @ lattice
+ np.savetxt(os.path.join(output_path, "positions.dat").format(output_path), cart_coords)
+ np.savetxt(os.path.join(output_path, "atomic_numbers.dat"), element, fmt='%d')
+ #info = {'nsites' : nsites, 'isorthogonal': False, 'isspinful': spinful, 'norbits': norbits}
+ #with open('{}/info.json'.format(output_path), 'w') as info_f:
+ # json.dump(info, info_f)
+ with open(os.path.join(output_path, "basis.dat"), 'w') as f:
+ for atomic_number in element:
+ counter = Counter(orbital_types_dict[atomic_number])
+ num_orbs = [counter[i] for i in range(4)] # s, p, d, f
+ for index_l, l in enumerate(num_orbs):
+ if l == 0: # no this orbit
+ continue
+ if index_l == 0:
+ f.write(f"{l}{orbitalId[index_l]}")
+ else:
+ f.write(f"{l}{orbitalId[index_l]}")
+ f.write('\n')
+ atomic_basis = {}
+ for atomic_number, orbitals in orbital_types_dict.items():
+ atomic_basis[ase.data.chemical_symbols[atomic_number]] = orbitals
+
+ if get_Ham:
+ U_orbital = OrbAbacus2DeepTB()
+ def parse_matrix(matrix_path, factor, spinful=False):
+ matrix_dict = dict()
+ with open(matrix_path, 'r') as f:
+ line = f.readline() # read "Matrix Dimension of ..."
+ if not "Matrix Dimension of" in line:
+ line = f.readline() # ABACUS >= 3.0
+ assert "Matrix Dimension of" in line
+ f.readline() # read "Matrix number of ..."
+ norbits = int(line.split()[-1])
+ for line in f:
+ line1 = line.split()
+ if len(line1) == 0:
+ break
+ num_element = int(line1[3])
+ if num_element != 0:
+ R_cur = np.array(line1[:3]).astype(int)
+ line2 = f.readline().split()
+ line3 = f.readline().split()
+ line4 = f.readline().split()
+ if not spinful:
+ hamiltonian_cur = csr_matrix((np.array(line2).astype(float), np.array(line3).astype(int),
+ np.array(line4).astype(int)), shape=(norbits, norbits)).toarray()
+ else:
+ line2 = np.char.replace(line2, '(', '')
+ line2 = np.char.replace(line2, ')', 'j')
+ line2 = np.char.replace(line2, ',', '+')
+ line2 = np.char.replace(line2, '+-', '-')
+ hamiltonian_cur = csr_matrix((np.array(line2).astype(np.complex128), np.array(line3).astype(int),
+ np.array(line4).astype(int)), shape=(norbits, norbits)).toarray()
+ for index_site_i in range(nsites):
+ for index_site_j in range(nsites):
+ key_str = f"{index_site_i + 1}_{index_site_j + 1}_{R_cur[0]}_{R_cur[1]}_{R_cur[2]}"
+ mat = hamiltonian_cur[(site_norbits_cumsum[index_site_i]
+ - site_norbits[index_site_i]) * (1 + spinful):
+ site_norbits_cumsum[index_site_i] * (1 + spinful),
+ (site_norbits_cumsum[index_site_j] - site_norbits[index_site_j]) * (1 + spinful):
+ site_norbits_cumsum[index_site_j] * (1 + spinful)]
+ if abs(mat).max() < 1e-10:
+ continue
+ if not spinful:
+ mat = U_orbital.transform(mat, orbital_types_dict[element[index_site_i]],
+ orbital_types_dict[element[index_site_j]])
+ else:
+ mat = mat.reshape((site_norbits[index_site_i], 2, site_norbits[index_site_j], 2))
+ mat = mat.transpose((1, 0, 3, 2)).reshape((2 * site_norbits[index_site_i],
+ 2 * site_norbits[index_site_j]))
+ mat = U_orbital.transform(mat, orbital_types_dict[element[index_site_i]] * 2,
+ orbital_types_dict[element[index_site_j]] * 2)
+ matrix_dict[key_str] = mat * factor
+ return matrix_dict, norbits
+
+ if only_S:
+ overlap_dict, tmp = parse_matrix(os.path.join(input_path, "SR.csr"), 1)
+ assert tmp == norbits
+ else:
+ hamiltonian_dict, tmp = parse_matrix(
+ os.path.join(input_path, data_name, "data-HR-sparse_SPIN0.csr"), 13.605698, # Ryd2eV
+ spinful=spinful)
+ assert tmp == norbits * (1 + spinful)
+ overlap_dict, tmp = parse_matrix(os.path.join(input_path, data_name, "data-SR-sparse_SPIN0.csr"), 1,
+ spinful=spinful)
+ assert tmp == norbits * (1 + spinful)
+ if spinful:
+ overlap_dict_spinless = {}
+ for k, v in overlap_dict.items():
+ overlap_dict_spinless[k] = v[:v.shape[0] // 2, :v.shape[1] // 2].real
+ overlap_dict_spinless, overlap_dict = overlap_dict, overlap_dict_spinless
+
+ if not only_S:
+ with h5py.File(os.path.join(output_path, "hamiltonians.h5"), 'w') as fid:
+ # creating a default group here adapting to the format used in DefaultDataset.
+ # by the way DefaultDataset loading h5 file, the index should be "1" here.
+ default_group = fid.create_group("1")
+ for key_str, value in hamiltonian_dict.items():
+ default_group[key_str] = value
+ if add_overlap:
+ with h5py.File(os.path.join(output_path, "overlaps.h5"), 'w') as fid:
+ default_group = fid.create_group("1")
+ for key_str, value in overlap_dict.items():
+ default_group[key_str] = value
+
+ if get_eigenvalues:
+ kpts = []
+ with open(os.path.join(input_path, data_name, "kpoints"), "r") as f:
+ nkstot = f.readline().strip().split()[-1]
+ f.readline()
+ for _ in range(int(nkstot)):
+ line = f.readline()
+ kpt = []
+ line = line.strip().split()
+ kpt.extend([float(line[1]), float(line[2]), float(line[3])])
+ kpts.append(kpt)
+ kpts = np.array(kpts)
+
+ with open(os.path.join(input_path, data_name, "BANDS_1.dat"), "r") as file:
+ band_lines = file.readlines()
+ band = []
+ for line in band_lines:
+ values = line.strip().split()
+ eigs = [float(value) for value in values[1:]]
+ band.append(eigs)
+ band = np.array(band)
+
+ assert len(band) == len(kpts)
+ np.save(os.path.join(output_path, "kpoints.npy"), kpts)
+ np.save(os.path.join(output_path, "eigenvalues.npy"), band)
+
+ #with h5py.File(os.path.join(output_path, "AtomicData.h5"), "w") as f:
+ # f["cell"] = lattice
+ # f["pos"] = cart_coords
+ # f["atomic_numbers"] = element
+ # basis = f.create_group("basis")
+ # for key, value in atomic_basis.items():
+ # basis[key] = value
+ # if get_Ham:
+ # f["hamiltonian_blocks"] = h5py.ExternalLink("hamiltonians.h5", "/")
+ # if add_overlap:
+ # f["overlap_blocks"] = h5py.ExternalLink("overlaps.h5", "/")
+ # # else:
+ # # f["overlap_blocks"] = False
+ # # else:
+ # # f["hamiltonian_blocks"] = False
+ # if get_eigenvalues:
+ # f["kpoints"] = kpts
+ # f["eigenvalues"] = band
+ # # else:
+ # # f["kpoint"] = False
+ # # f["eigenvalue"] = False
diff --git a/dptb/data/interfaces/ham_to_feature.py b/dptb/data/interfaces/ham_to_feature.py
new file mode 100644
index 00000000..710f8b49
--- /dev/null
+++ b/dptb/data/interfaces/ham_to_feature.py
@@ -0,0 +1,209 @@
+from .. import _keys
+import ase
+import numpy as np
+import torch
+import re
+import e3nn.o3 as o3
+import h5py
+import logging
+from dptb.utils.constants import anglrMId
+
+log = logging.getLogger(__name__)
+
+def ham_block_to_feature(data, idp, Hamiltonian_blocks, overlap_blocks=False):
+ # Hamiltonian_blocks should be a h5 group in the current version
+ onsite_ham = []
+ edge_ham = []
+ if overlap_blocks:
+ edge_overlap = []
+
+ idp.get_orbital_maps()
+ idp.get_orbpair_maps()
+
+ atomic_numbers = data[_keys.ATOMIC_NUMBERS_KEY]
+
+ # onsite features
+ for atom in range(len(atomic_numbers)):
+ block_index = '_'.join(map(str, map(int, [atom+1, atom+1] + list([0, 0, 0]))))
+ try:
+ block = Hamiltonian_blocks[block_index]
+ except:
+ raise IndexError("Hamiltonian block for onsite not found, check Hamiltonian file.")
+
+ symbol = ase.data.chemical_symbols[atomic_numbers[atom]]
+ basis_list = idp.basis[symbol]
+ onsite_out = np.zeros(idp.reduced_matrix_element)
+
+ for index, basis_i in enumerate(basis_list):
+ slice_i = idp.orbital_maps[symbol][basis_i]
+ for basis_j in basis_list[index:]:
+ slice_j = idp.orbital_maps[symbol][basis_j]
+ block_ij = block[slice_i, slice_j]
+ full_basis_i = idp.basis_to_full_basis[symbol][basis_i]
+ full_basis_j = idp.basis_to_full_basis[symbol][basis_j]
+
+ # fill onsite vector
+ pair_ij = full_basis_i + "-" + full_basis_j
+ feature_slice = idp.orbpair_maps[pair_ij]
+ onsite_out[feature_slice] = block_ij.flatten()
+
+ onsite_ham.append(onsite_out)
+ #onsite_ham = np.array(onsite_ham)
+
+ # edge features
+ edge_index = data[_keys.EDGE_INDEX_KEY]
+ edge_cell_shift = data[_keys.EDGE_CELL_SHIFT_KEY]
+
+ for atom_i, atom_j, R_shift in zip(edge_index[0], edge_index[1], edge_cell_shift):
+ block_index = '_'.join(map(str, map(int, [atom_i+1, atom_j+1] + list(R_shift))))
+ symbol_i = ase.data.chemical_symbols[atomic_numbers[atom_i]]
+ symbol_j = ase.data.chemical_symbols[atomic_numbers[atom_j]]
+
+ # try:
+ # block = Hamiltonian_blocks[block_index]
+ # if overlap_blocks:
+ # block_s = overlap_blocks[block_index]
+ # except:
+ # raise IndexError("Hamiltonian block for hopping not found, r_cut may be too big for input R.")
+
+ block = Hamiltonian_blocks.get(block_index, 0)
+ if overlap_blocks:
+ block_s = overlap_blocks.get(block_index, 0)
+ if block == 0:
+ block = torch.zeros(idp.norbs[symbol_i], idp.norbs[symbol_j])
+ log.warning("Hamiltonian block for hopping {} not found, r_cut may be too big for input R.".format(block_index))
+ if overlap_blocks:
+ if block_s == 0:
+ block_s = torch.zeros(idp.norbs[symbol_i], idp.norbs[symbol_j])
+ log.warning("Overlap block for hopping {} not found, r_cut may be too big for input R.".format(block_index))
+
+ assert block.shape == (idp.norbs[symbol_i], idp.norbs[symbol_j])
+
+
+ basis_i_list = idp.basis[symbol_i]
+ basis_j_list = idp.basis[symbol_j]
+ hopping_out = np.zeros(idp.reduced_matrix_element)
+ if overlap_blocks:
+ overlap_out = np.zeros(idp.reduced_matrix_element)
+
+ for basis_i in basis_i_list:
+ slice_i = idp.orbital_maps[symbol_i][basis_i]
+ for basis_j in basis_j_list:
+ slice_j = idp.orbital_maps[symbol_j][basis_j]
+ full_basis_i = idp.basis_to_full_basis[symbol_i][basis_i]
+ full_basis_j = idp.basis_to_full_basis[symbol_j][basis_j]
+ if idp.full_basis.index(full_basis_i) <= idp.full_basis.index(full_basis_j):
+ block_ij = block[slice_i, slice_j]
+ if overlap_blocks:
+ block_s_ij = block_s[slice_i, slice_j]
+
+ # fill hopping vector
+ pair_ij = full_basis_i + "-" + full_basis_j
+ feature_slice = idp.orbpair_maps[pair_ij]
+
+ hopping_out[feature_slice] = block_ij.flatten()
+ if overlap_blocks:
+ overlap_out[feature_slice] = block_s_ij.flatten()
+
+ edge_ham.append(hopping_out)
+ if overlap_blocks:
+ edge_overlap.append(overlap_out)
+
+ data[_keys.NODE_FEATURES_KEY] = torch.as_tensor(np.array(onsite_ham), dtype=torch.get_default_dtype())
+ data[_keys.EDGE_FEATURES_KEY] = torch.as_tensor(np.array(edge_ham), dtype=torch.get_default_dtype())
+ if overlap_blocks:
+ data[_keys.EDGE_OVERLAP_KEY] = torch.as_tensor(np.array(edge_overlap), dtype=torch.get_default_dtype())
+
+
+def openmx_to_deeptb(data, idp, openmx_hpath):
+ # Hamiltonian_blocks should be a h5 group in the current version
+ Us_openmx2wiki = {
+ "s": torch.eye(1).double(),
+ "p": torch.eye(3)[[1, 2, 0]].double(),
+ "d": torch.eye(5)[[2, 4, 0, 3, 1]].double(),
+ "f": torch.eye(7)[[6, 4, 2, 0, 1, 3, 5]].double()
+ }
+ # init_rot_mat
+ rot_blocks = {}
+ for asym, orbs in idp.basis.items():
+ b = [Us_openmx2wiki[re.findall(r"[A-Za-z]", orb)[0]] for orb in orbs]
+ rot_blocks[asym] = torch.block_diag(*b)
+
+
+ Hamiltonian_blocks = h5py.File(openmx_hpath, 'r')
+
+ onsite_ham = []
+ edge_ham = []
+
+ idp.get_orbital_maps()
+ idp.get_orbpair_maps()
+
+ atomic_numbers = data[_keys.ATOMIC_NUMBERS_KEY]
+
+ # onsite features
+ for atom in range(len(atomic_numbers)):
+ block_index = str([0, 0, 0, atom+1, atom+1])
+ try:
+ block = Hamiltonian_blocks[block_index][:]
+ except:
+ raise IndexError("Hamiltonian block for onsite not found, check Hamiltonian file.")
+
+
+ symbol = ase.data.chemical_symbols[atomic_numbers[atom]]
+ block = rot_blocks[symbol] @ block @ rot_blocks[symbol].T
+ basis_list = idp.basis[symbol]
+ onsite_out = np.zeros(idp.reduced_matrix_element)
+
+ for index, basis_i in enumerate(basis_list):
+ slice_i = idp.orbital_maps[symbol][basis_i]
+ for basis_j in basis_list[index:]:
+ slice_j = idp.orbital_maps[symbol][basis_j]
+ block_ij = block[slice_i, slice_j]
+ full_basis_i = idp.basis_to_full_basis[symbol][basis_i]
+ full_basis_j = idp.basis_to_full_basis[symbol][basis_j]
+
+ # fill onsite vector
+ pair_ij = full_basis_i + "-" + full_basis_j
+ feature_slice = idp.orbpair_maps[pair_ij]
+ onsite_out[feature_slice] = block_ij.flatten()
+
+ onsite_ham.append(onsite_out)
+ #onsite_ham = np.array(onsite_ham)
+
+ # edge features
+ edge_index = data[_keys.EDGE_INDEX_KEY]
+ edge_cell_shift = data[_keys.EDGE_CELL_SHIFT_KEY]
+
+ for atom_i, atom_j, R_shift in zip(edge_index[0], edge_index[1], edge_cell_shift):
+ block_index = str(list(R_shift.int().numpy())+[int(atom_i)+1, int(atom_j)+1])
+ try:
+ block = Hamiltonian_blocks[block_index][:]
+ except:
+ raise IndexError("Hamiltonian block for hopping not found, r_cut may be too big for input R.")
+
+ symbol_i = ase.data.chemical_symbols[atomic_numbers[atom_i]]
+ symbol_j = ase.data.chemical_symbols[atomic_numbers[atom_j]]
+ block = rot_blocks[symbol_i] @ block @ rot_blocks[symbol_j].T
+ basis_i_list = idp.basis[symbol_i]
+ basis_j_list = idp.basis[symbol_j]
+ hopping_out = np.zeros(idp.reduced_matrix_element)
+
+ for basis_i in basis_i_list:
+ slice_i = idp.orbital_maps[symbol_i][basis_i]
+ for basis_j in basis_j_list:
+ slice_j = idp.orbital_maps[symbol_j][basis_j]
+ block_ij = block[slice_i, slice_j]
+ full_basis_i = idp.basis_to_full_basis[symbol_i][basis_i]
+ full_basis_j = idp.basis_to_full_basis[symbol_j][basis_j]
+
+ if idp.full_basis.index(full_basis_i) <= idp.full_basis.index(full_basis_j):
+ # fill hopping vector
+ pair_ij = full_basis_i + "-" + full_basis_j
+ feature_slice = idp.orbpair_maps[pair_ij]
+ hopping_out[feature_slice] = block_ij.flatten()
+
+ edge_ham.append(hopping_out)
+
+ data[_keys.NODE_FEATURES_KEY] = torch.as_tensor(np.array(onsite_ham), dtype=torch.get_default_dtype())
+ data[_keys.EDGE_FEATURES_KEY] = torch.as_tensor(np.array(edge_ham), dtype=torch.get_default_dtype())
+ Hamiltonian_blocks.close()
\ No newline at end of file
diff --git a/dptb/data/test_data.py b/dptb/data/test_data.py
new file mode 100644
index 00000000..7427cc4b
--- /dev/null
+++ b/dptb/data/test_data.py
@@ -0,0 +1,83 @@
+from typing import Optional, List, Dict, Any, Tuple
+import copy
+
+import numpy as np
+
+import ase
+import ase.build
+from ase.calculators.emt import EMT
+
+from dptb.data import AtomicInMemoryDataset, AtomicData
+from .transforms import TypeMapper
+
+
+class EMTTestDataset(AtomicInMemoryDataset):
+ """Test dataset with PBC based on the toy EMT potential included in ASE.
+
+ Randomly generates (in a reproducable manner) a basic bulk with added
+ Gaussian noise around equilibrium positions.
+
+ In ASE units (eV/Å).
+ """
+
+ def __init__(
+ self,
+ root: str,
+ supercell: Tuple[int, int, int] = (4, 4, 4),
+ sigma: float = 0.1,
+ element: str = "Cu",
+ num_frames: int = 10,
+ dataset_seed: int = 123456,
+ file_name: Optional[str] = None,
+ url: Optional[str] = None,
+ AtomicData_options: Dict[str, Any] = {},
+ include_frames: Optional[List[int]] = None,
+ type_mapper: TypeMapper = None,
+ ):
+ # Set properties for hashing
+ assert element in ("Cu", "Pd", "Au", "Pt", "Al", "Ni", "Ag")
+ self.element = element
+ self.sigma = sigma
+ self.supercell = tuple(supercell)
+ self.num_frames = num_frames
+ self.dataset_seed = dataset_seed
+
+ super().__init__(
+ file_name=file_name,
+ url=url,
+ root=root,
+ AtomicData_options=AtomicData_options,
+ include_frames=include_frames,
+ type_mapper=type_mapper,
+ )
+
+ @property
+ def raw_file_names(self):
+ return []
+
+ @property
+ def raw_dir(self):
+ return "raw"
+
+ def get_data(self):
+ rng = np.random.default_rng(self.dataset_seed)
+ base_atoms = ase.build.bulk(self.element, "fcc").repeat(self.supercell)
+ base_atoms.calc = EMT()
+ orig_pos = copy.deepcopy(base_atoms.positions)
+ datas = []
+ for _ in range(self.num_frames):
+ base_atoms.positions[:] = orig_pos
+ base_atoms.positions += rng.normal(
+ loc=0.0, scale=self.sigma, size=base_atoms.positions.shape
+ )
+
+ datas.append(
+ AtomicData.from_ase(
+ base_atoms.copy(),
+ forces=base_atoms.get_forces(),
+ total_energy=base_atoms.get_potential_energy(),
+ stress=base_atoms.get_stress(voigt=False),
+ **self.AtomicData_options
+ )
+ )
+ return datas
diff --git a/dptb/data/transforms.py b/dptb/data/transforms.py
new file mode 100644
index 00000000..4af4c4da
--- /dev/null
+++ b/dptb/data/transforms.py
@@ -0,0 +1,705 @@
+from typing import Dict, Optional, Union, List
+from dptb.data.AtomicDataDict import Type
+from dptb.utils.tools import get_uniq_symbol
+from dptb.utils.constants import anglrMId
+import re
+import warnings
+
+import torch
+
+import ase.data
+import e3nn.o3 as o3
+
+from dptb.data import AtomicData, AtomicDataDict
+
+
+class TypeMapper:
+ """Based on a configuration, map atomic numbers to types."""
+
+ num_types: int
+ chemical_symbol_to_type: Optional[Dict[str, int]]
+ type_to_chemical_symbol: Optional[Dict[int, str]]
+ type_names: List[str]
+ _min_Z: int
+
+ def __init__(
+ self,
+ type_names: Optional[List[str]] = None,
+ chemical_symbol_to_type: Optional[Dict[str, int]] = None,
+ type_to_chemical_symbol: Optional[Dict[int, str]] = None,
+ chemical_symbols: Optional[List[str]] = None,
+ device=torch.device("cpu"),
+ ):
+ self.device = device
+ if chemical_symbols is not None:
+ if chemical_symbol_to_type is not None:
+ raise ValueError(
+ "Cannot provide both `chemical_symbols` and `chemical_symbol_to_type`"
+ )
+ # repro old, sane NequIP behaviour
+ # checks also for validity of keys
+ atomic_nums = [ase.data.atomic_numbers[sym] for sym in chemical_symbols]
+ # https://stackoverflow.com/questions/29876580/how-to-sort-a-list-according-to-another-list-python
+ chemical_symbols = [
+ e[1] for e in sorted(zip(atomic_nums, chemical_symbols)) # low to high
+ ]
+ chemical_symbol_to_type = {k: i for i, k in enumerate(chemical_symbols)}
+ del chemical_symbols
+
+ if type_to_chemical_symbol is not None:
+ type_to_chemical_symbol = {
+ int(k): v for k, v in type_to_chemical_symbol.items()
+ }
+ assert all(
+ v in ase.data.chemical_symbols for v in type_to_chemical_symbol.values()
+ )
+
+ # Build from chem->type mapping, if provided
+ self.chemical_symbol_to_type = chemical_symbol_to_type
+ if self.chemical_symbol_to_type is not None:
+ # Validate
+ for sym, type in self.chemical_symbol_to_type.items():
+ assert sym in ase.data.atomic_numbers, f"Invalid chemical symbol {sym}"
+ assert 0 <= type, f"Invalid type number {type}"
+ assert set(self.chemical_symbol_to_type.values()) == set(
+ range(len(self.chemical_symbol_to_type))
+ )
+ if type_names is None:
+ # Make type_names
+ type_names = [None] * len(self.chemical_symbol_to_type)
+ for sym, type in self.chemical_symbol_to_type.items():
+ type_names[type] = sym
+ else:
+ # Make sure they agree on types
+ # We already checked that chem->type is contiguous,
+ # so enough to check length since type_names is a list
+ assert len(type_names) == len(self.chemical_symbol_to_type)
+ # Make mapper array
+ valid_atomic_numbers = [
+ ase.data.atomic_numbers[sym] for sym in self.chemical_symbol_to_type
+ ]
+ self._min_Z = min(valid_atomic_numbers)
+ self._max_Z = max(valid_atomic_numbers)
+ Z_to_index = torch.full(
+ size=(1 + self._max_Z - self._min_Z,), fill_value=-1, dtype=torch.long, device=device
+ )
+ for sym, type in self.chemical_symbol_to_type.items():
+ Z_to_index[ase.data.atomic_numbers[sym] - self._min_Z] = type
+ self._Z_to_index = Z_to_index
+ self._index_to_Z = torch.zeros(
+ size=(len(self.chemical_symbol_to_type),), dtype=torch.long, device=device
+ )
+ for sym, type_idx in self.chemical_symbol_to_type.items():
+ self._index_to_Z[type_idx] = ase.data.atomic_numbers[sym]
+ self._valid_set = set(valid_atomic_numbers)
+ true_type_to_chemical_symbol = {
+ type_id: sym for sym, type_id in self.chemical_symbol_to_type.items()
+ }
+ if type_to_chemical_symbol is not None:
+ assert type_to_chemical_symbol == true_type_to_chemical_symbol
+ else:
+ type_to_chemical_symbol = true_type_to_chemical_symbol
+
+ # check
+ if type_names is None:
+ raise ValueError(
+ "None of chemical_symbols, chemical_symbol_to_type, nor type_names was provided; exactly one is required"
+ )
+ # validate type names
+ assert all(
+ n.isalnum() for n in type_names
+ ), "Type names must contain only alphanumeric characters"
+ # Set to however many maps specified -- we already checked contiguous
+ self.num_types = len(type_names)
+ # Check type_names
+ self.type_names = type_names
+ self.type_to_chemical_symbol = type_to_chemical_symbol
+ if self.type_to_chemical_symbol is not None:
+ assert set(type_to_chemical_symbol.keys()) == set(range(self.num_types))
+
+ def __call__(
+ self, data: Union[AtomicDataDict.Type, AtomicData], types_required: bool = True
+ ) -> Union[AtomicDataDict.Type, AtomicData]:
+ if AtomicDataDict.ATOM_TYPE_KEY in data:
+ if AtomicDataDict.ATOMIC_NUMBERS_KEY in data:
+ warnings.warn(
+ "Data contained both ATOM_TYPE_KEY and ATOMIC_NUMBERS_KEY; ignoring ATOMIC_NUMBERS_KEY"
+ )
+ elif AtomicDataDict.ATOMIC_NUMBERS_KEY in data:
+ assert (
+ self.chemical_symbol_to_type is not None
+ ), "Atomic numbers provided but there is no chemical_symbols/chemical_symbol_to_type mapping!"
+ atomic_numbers = data[AtomicDataDict.ATOMIC_NUMBERS_KEY]
+ del data[AtomicDataDict.ATOMIC_NUMBERS_KEY]
+
+ data[AtomicDataDict.ATOM_TYPE_KEY] = self.transform(atomic_numbers)
+ else:
+ if types_required:
+ raise KeyError(
+ "Data doesn't contain any atom type information (ATOM_TYPE_KEY or ATOMIC_NUMBERS_KEY)"
+ )
+ return data
+
+ def transform(self, atomic_numbers):
+ """core function to transform an array to specie index list"""
+
+ if atomic_numbers.min() < self._min_Z or atomic_numbers.max() > self._max_Z:
+ bad_set = set(torch.unique(atomic_numbers).cpu().tolist()) - self._valid_set
+ raise ValueError(
+ f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
+ )
+
+ return self._Z_to_index.to(device=atomic_numbers.device)[
+ atomic_numbers - self._min_Z
+ ]
+
+ def untransform(self, atom_types):
+ """Transform atom types back into atomic numbers"""
+ return self._index_to_Z[atom_types].to(device=atom_types.device)
+
+ @property
+ def has_chemical_symbols(self) -> bool:
+ return self.chemical_symbol_to_type is not None
+
+ @staticmethod
+ def format(
+ data: list, type_names: List[str], element_formatter: str = ".6f"
+ ) -> str:
+ """
+ Formats a list of data elements along with their type names.
+
+ Parameters:
+ data (list): The data elements to be formatted. This should be a list of numbers.
+ type_names (List[str]): The type names corresponding to the data elements. This should be a list of strings.
+ element_formatter (str, optional): The format in which the data elements should be displayed. Defaults to ".6f".
+
+ Returns:
+ str: A string representation of the data elements along with their type names.
+
+ Raises:
+ ValueError: If `data` is not None, not 0-dimensional, or not 1-dimensional with length equal to the length of `type_names`.
+
+ Example:
+ >>> data = [1.123456789, 2.987654321]
+ >>> type_names = ['Type1', 'Type2']
+ >>> print(TypeMapper.format(data, type_names))
+ [Type1: 1.123457, Type2: 2.987654]
+ """
+ data = torch.as_tensor(data) if data is not None else None
+ if data is None:
+ return f"[{', '.join(type_names)}: None]"
+ elif data.ndim == 0:
+ return (f"[{', '.join(type_names)}: {{:{element_formatter}}}]").format(data)
+ elif data.ndim == 1 and len(data) == len(type_names):
+ return (
+ "["
+ + ", ".join(
+ f"{{{i}[0]}}: {{{i}[1]:{element_formatter}}}"
+ for i in range(len(data))
+ )
+ + "]"
+ ).format(*zip(type_names, data))
+ else:
+ raise ValueError(
+ f"Don't know how to format data=`{data}` for types {type_names} with element_formatter=`{element_formatter}`"
+ )
+
+
+class BondMapper(TypeMapper):
+ def __init__(
+ self,
+ chemical_symbols: Optional[List[str]] = None,
+ chemical_symbols_to_type:Union[Dict[str, int], None]=None,
+ device=torch.device("cpu"),
+ ):
+ super(BondMapper, self).__init__(chemical_symbol_to_type=chemical_symbols_to_type, chemical_symbols=chemical_symbols, device=device)
+
+ self.bond_types = [None] * self.num_types ** 2
+ self.reduced_bond_types = [None] * ((self.num_types * (self.num_types + 1)) // 2)
+ self.bond_to_type = {}
+ self.type_to_bond = {}
+ self.reduced_bond_to_type = {}
+ self.type_to_reduced_bond = {}
+ for asym, ai in self.chemical_symbol_to_type.items():
+ for bsym, bi in self.chemical_symbol_to_type.items():
+ self.bond_types[ai * self.num_types + bi] = asym + "-" + bsym
+ if ai <= bi:
+ self.reduced_bond_types[(2*self.num_types-ai+1) * ai // 2 + bi-ai] = asym + "-" + bsym
+ for i, bt in enumerate(self.bond_types):
+ self.bond_to_type[bt] = i
+ self.type_to_bond[i] = bt
+ for i, bt in enumerate(self.reduced_bond_types):
+ self.reduced_bond_to_type[bt] = i
+ self.type_to_reduced_bond[i] = bt
+
+ ZZ_to_index = torch.full(
+ size=(len(self._Z_to_index), len(self._Z_to_index)), fill_value=-1, device=device, dtype=torch.long
+ )
+ ZZ_to_reduced_index = torch.full(
+ size=(len(self._Z_to_index), len(self._Z_to_index)), fill_value=-1, device=device, dtype=torch.long
+ )
+
+
+ for abond, aidx in self.bond_to_type.items(): # type_names has a ascending order according to atomic number
+ asym, bsym = abond.split("-")
+ ZZ_to_index[ase.data.atomic_numbers[asym]-self._min_Z, ase.data.atomic_numbers[bsym]-self._min_Z] = aidx
+
+ for abond, aidx in self.reduced_bond_to_type.items(): # type_names has a ascending order according to atomic number
+ asym, bsym = abond.split("-")
+ ZZ_to_reduced_index[ase.data.atomic_numbers[asym]-self._min_Z, ase.data.atomic_numbers[bsym]-self._min_Z] = aidx
+
+
+ self._ZZ_to_index = ZZ_to_index
+ self._ZZ_to_reduced_index = ZZ_to_reduced_index
+
+ self._index_to_ZZ = torch.zeros(
+ size=(len(self.bond_to_type),2), dtype=torch.long, device=device
+ )
+ self._reduced_index_to_ZZ = torch.zeros(
+ size=(len(self.reduced_bond_to_type),2), dtype=torch.long, device=device
+ )
+
+ for abond, aidx in self.bond_to_type.items():
+ asym, bsym = abond.split("-")
+ self._index_to_ZZ[aidx] = torch.tensor([ase.data.atomic_numbers[asym], ase.data.atomic_numbers[bsym]], dtype=torch.long, device=device)
+
+ for abond, aidx in self.reduced_bond_to_type.items():
+ asym, bsym = abond.split("-")
+ self._reduced_index_to_ZZ[aidx] = torch.tensor([ase.data.atomic_numbers[asym], ase.data.atomic_numbers[bsym]], dtype=torch.long, device=device)
+
+
+ def transform_atom(self, atomic_numbers):
+ return self.transform(atomic_numbers)
+
+ def transform_bond(self, iatomic_numbers, jatomic_numbers):
+
+ if iatomic_numbers.device != jatomic_numbers.device:
+ raise ValueError("iatomic_numbers and jatomic_numbers should be on the same device!")
+
+ if iatomic_numbers.min() < self._min_Z or iatomic_numbers.max() > self._max_Z:
+ bad_set = set(torch.unique(iatomic_numbers).cpu().tolist()) - self._valid_set
+ raise ValueError(
+ f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
+ )
+
+ if jatomic_numbers.min() < self._min_Z or jatomic_numbers.max() > self._max_Z:
+ bad_set = set(torch.unique(jatomic_numbers).cpu().tolist()) - self._valid_set
+ raise ValueError(
+ f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
+ )
+
+ return self._ZZ_to_index.to(device=iatomic_numbers.device)[
+ iatomic_numbers - self._min_Z, jatomic_numbers - self._min_Z
+ ]
+
+ def transform_reduced_bond(self, iatomic_numbers, jatomic_numbers):
+
+ if iatomic_numbers.device != jatomic_numbers.device:
+ raise ValueError("iatomic_numbers and jatomic_numbers should be on the same device!")
+
+ if iatomic_numbers.min() < self._min_Z or iatomic_numbers.max() > self._max_Z:
+ bad_set = set(torch.unique(iatomic_numbers).cpu().tolist()) - self._valid_set
+ raise ValueError(
+ f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
+ )
+
+ if jatomic_numbers.min() < self._min_Z or jatomic_numbers.max() > self._max_Z:
+ bad_set = set(torch.unique(jatomic_numbers).cpu().tolist()) - self._valid_set
+ raise ValueError(
+ f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
+ )
+
+ return self._ZZ_to_reduced_index.to(device=iatomic_numbers.device)[
+ iatomic_numbers - self._min_Z, jatomic_numbers - self._min_Z
+ ]
+
+ def untransform_atom(self, atom_types):
+ """Transform atom types back into atomic numbers"""
+ return self.untransform(atom_types)
+
+ def untransform_bond(self, bond_types):
+ """Transform bond types back into atomic numbers"""
+ return self._index_to_ZZ[bond_types].to(device=bond_types.device)
+
+ def untransform_reduced_bond(self, bond_types):
+ """Transform reduced bond types back into atomic numbers"""
+ return self._reduced_index_to_ZZ[bond_types].to(device=bond_types.device)
+
+ @property
+ def has_bond(self) -> bool:
+ return self.bond_to_type is not None
+
+ def __call__(
+ self, data: Union[AtomicDataDict.Type, AtomicData], types_required: bool = True
+ ) -> Union[AtomicDataDict.Type, AtomicData]:
+ if AtomicDataDict.EDGE_TYPE_KEY in data:
+ if AtomicDataDict.ATOMIC_NUMBERS_KEY in data:
+ warnings.warn(
+ "Data contained both EDGE_TYPE_KEY and ATOMIC_NUMBERS_KEY; ignoring ATOMIC_NUMBERS_KEY"
+ )
+ elif AtomicDataDict.ATOMIC_NUMBERS_KEY in data:
+ assert (
+ self.bond_to_type is not None
+ ), "Atomic numbers provided but there is no chemical_symbols/chemical_symbol_to_type mapping!"
+ atomic_numbers = data[AtomicDataDict.ATOMIC_NUMBERS_KEY]
+
+ assert (
+ AtomicDataDict.EDGE_INDEX_KEY in data
+ ), "The bond type mapper need a EDGE index as input."
+
+ data[AtomicDataDict.EDGE_TYPE_KEY] = \
+ self.transform_bond(
+ atomic_numbers[data[AtomicDataDict.EDGE_INDEX_KEY][0]],
+ atomic_numbers[data[AtomicDataDict.EDGE_INDEX_KEY][1]]
+ )
+ else:
+ if types_required:
+ raise KeyError(
+ "Data doesn't contain any atom type information (EDGE_TYPE_KEY or ATOMIC_NUMBERS_KEY)"
+ )
+ return super().__call__(data=data, types_required=types_required)
+
+
+
+
+
+class OrbitalMapper(BondMapper):
+ def __init__(
+ self,
+ basis: Dict[str, Union[List[str], str]],
+ chemical_symbol_to_type: Optional[Dict[str, int]] = None,
+ method: str ="e3tb",
+ device: Union[str, torch.device] = torch.device("cpu")
+ ):
+
+ """
+ This class is used to map the orbital pair index to the index of the reduced matrix element (or sk integrals when method is sktb). To construct a reduced matrix element features in each edge/node with equal sizes as well as their mappings, the following steps will be conducted:
+
+ 1. The basis of each atom will be sorted according to their names. For example, The basis ["2s", "1s", "s*", "2p"] of atom A will be sorted as ["s*", "1s", "2s", "2p"].
+
+ 2. The sorted basis will be transformed into a general basis, dubbed as full_basis. It is the least required set covering all the basis number and types of each atom. The basis will be renamed according to their angular momentum and the order after sorting. Take s orbital as a example, the first s* will be named as "1s", the second s* will be named as "2s", and so on. Same for p, d, f orbitals.
+
+ Then the mappings and masks used to guide the construction of hamiltonian will be constructed. The mappings includes:
+
+ Mappings:
+ fullbasis_to_basis, basis_to_fullbasis: which function as their names
+ orbpair_maps: the mapping from orbital pairs of full basis to the reduced matrix element (or sk integrals) index.
+ orbpairtype_maps: the mapping from the types of orbital pair (e.g. "s-s", "s-p", "p-p") to the reduced matrix element (or sk integrals) index.
+ skonsite_maps: the mapping from the orbital to the sk onsite energies index.
+ skonsitetype_maps: the mapping from the orbital type (e.g. "s", "p", "d", "f") to the sk onsite energies index.
+ orbital_maps: the mapping from the orbital to the index of the corresponding lines/column in hamiltonian blocks.
+ orbpair_irreps: the e3nn irreducible representations of the full reduced matrix element edge/node features.
+
+ Masks:
+ mask_to_basis: the mask used to map the (line/column of) hamiltonian of full basis to the (line/column of) block of original basis of each atom.
+ mask_to_erme: the mask used to map the hopping block's flattened reduced matrix element (up tri-diagonal block of hamiltonian) of full basis to it of the original basis.
+ mask_to_nrme: the mask used to map the onsite block's flattened reduced matrix element (diagonal block of hamiltonian) of full basis to it of the original basis.
+
+ Parameters
+ ----------
+ basis : dict
+ the definition of the basis set, should be like:
+ {"A":"2s2p3d1f", "B":"1s2f3d1f"} or
+ {"A":["2s", "2p"], "B":["2s", "2p"]}
+ when list, "2s" indicate a "s" orbital in the second shell.
+ when str, "2s" indicates two s orbitals,
+ "2s2p3d4f" is equivilent to ["1s","2s", "1p", "2p", "1d", "2d", "3d", "1f"]
+ """
+
+ #TODO: use OrderedDict to fix the order of the dict used as index map
+ if chemical_symbol_to_type is not None:
+ assert set(basis.keys()) == set(chemical_symbol_to_type.keys())
+ super(OrbitalMapper, self).__init__(chemical_symbol_to_type=chemical_symbol_to_type, device=device)
+ else:
+ super(OrbitalMapper, self).__init__(chemical_symbols=list(basis.keys()), device=device)
+
+ self.basis = basis
+ self.method = method
+ self.device = device
+
+ if self.method not in ["e3tb", "sktb"]:
+ raise ValueError
+
+ if isinstance(self.basis[self.type_names[0]], str):
+ orbtype_count = {"s":0, "p":0, "d":0, "f":0}
+ orbs = map(lambda bs: re.findall(r'[1-9]+[A-Za-z]', bs), self.basis.values())
+ for ib in orbs:
+ for io in ib:
+ if int(io[0]) > orbtype_count[io[1]]:
+ orbtype_count[io[1]] = int(io[0])
+ # split into list basis
+ basis = {k:[] for k in self.type_names}
+ for ib in self.basis.keys():
+ for io in ["s", "p", "d", "f"]:
+ if io in self.basis[ib]:
+ basis[ib].extend([str(i)+io for i in range(1, int(re.findall(r'[1-9]+'+io, self.basis[ib])[0][0])+1)])
+ self.basis = basis
+
+ elif isinstance(self.basis[self.type_names[0]], list):
+ nb = len(self.type_names)
+ orbtype_count = {"s":[0]*nb, "p":[0]*nb, "d":[0]*nb, "f":[0]*nb}
+ for ib, bt in enumerate(self.type_names):
+ for io in self.basis[bt]:
+ orb = re.findall(r'[A-Za-z]', io)[0]
+ orbtype_count[orb][ib] += 1
+
+ for ko in orbtype_count.keys():
+ orbtype_count[ko] = max(orbtype_count[ko])
+
+ self.orbtype_count = orbtype_count
+ self.full_basis_norb = 1 * orbtype_count["s"] + 3 * orbtype_count["p"] + 5 * orbtype_count["d"] + 7 * orbtype_count["f"]
+
+
+ if self.method == "e3tb":
+ self.reduced_matrix_element = int(((orbtype_count["s"] + 9 * orbtype_count["p"] + 25 * orbtype_count["d"] + 49 * orbtype_count["f"]) + \
+ self.full_basis_norb ** 2)/2) # reduce onsite elements by blocks. we cannot reduce it by element since the rme will pass into CG basis to form the whole block
+ else:
+ # two factor: this outside one is the number of min(l,l')+1, ie. the number of sk integrals for each orbital pair.
+ # the inside one the type of bond considering the interaction between different orbitals. s-p -> p-s. there are 2 types of bond. and 1 type of s-s.
+ self.reduced_matrix_element = (
+ 1 * orbtype_count["s"] * orbtype_count["s"] + \
+ 2 * orbtype_count["s"] * orbtype_count["p"] + \
+ 2 * orbtype_count["s"] * orbtype_count["d"] + \
+ 2 * orbtype_count["s"] * orbtype_count["f"]
+ ) + \
+ 2 * (
+ 1 * orbtype_count["p"] * orbtype_count["p"] + \
+ 2 * orbtype_count["p"] * orbtype_count["d"] + \
+ 2 * orbtype_count["p"] * orbtype_count["f"]
+ ) + \
+ 3 * (
+ 1 * orbtype_count["d"] * orbtype_count["d"] + \
+ 2 * orbtype_count["d"] * orbtype_count["f"]
+ ) + \
+ 4 * (orbtype_count["f"] * orbtype_count["f"])
+
+ self.reduced_matrix_element = self.reduced_matrix_element + orbtype_count["s"] + 2*orbtype_count["p"] + 3*orbtype_count["d"] + 4*orbtype_count["f"]
+ self.reduced_matrix_element = int(self.reduced_matrix_element / 2)
+ self.n_onsite_Es = orbtype_count["s"] + orbtype_count["p"] + orbtype_count["d"] + orbtype_count["f"]
+
+ # sort the basis
+ for ib in self.basis.keys():
+ self.basis[ib] = sorted(
+ self.basis[ib],
+ key=lambda s: (anglrMId[re.findall(r"[a-z]",s)[0]], re.findall(r"[1-9*]",s)[0])
+ )
+
+ # TODO: get full basis set
+ full_basis = []
+ for io in ["s", "p", "d", "f"]:
+ full_basis = full_basis + [str(i)+io for i in range(1, orbtype_count[io]+1)]
+ self.full_basis = full_basis
+
+ # TODO: get the mapping from list basis to full basis
+ self.basis_to_full_basis = {}
+ self.atom_norb = torch.zeros(len(self.type_names), dtype=torch.long, device=self.device)
+ for ib in self.basis.keys():
+ count_dict = {"s":0, "p":0, "d":0, "f":0}
+ self.basis_to_full_basis.setdefault(ib, {})
+ for o in self.basis[ib]:
+ io = re.findall(r"[a-z]", o)[0]
+ l = anglrMId[io]
+ count_dict[io] += 1
+ self.atom_norb[self.chemical_symbol_to_type[ib]] += 2*l+1
+
+ self.basis_to_full_basis[ib][o] = str(count_dict[io])+io
+
+ # get the mapping from full basis to list basis
+ self.full_basis_to_basis = {}
+ for at, maps in self.basis_to_full_basis.items():
+ self.full_basis_to_basis[at] = {}
+ for k,v in maps.items():
+ self.full_basis_to_basis[at].update({v:k})
+
+ # Get the mask for mapping from full basis to atom specific basis
+ self.mask_to_basis = torch.zeros(len(self.type_names), self.full_basis_norb, device=self.device, dtype=torch.bool)
+
+ for ib in self.basis.keys():
+ ibasis = list(self.basis_to_full_basis[ib].values())
+ ist = 0
+ for io in self.full_basis:
+ l = anglrMId[io[1]]
+ if io in ibasis:
+ self.mask_to_basis[self.chemical_symbol_to_type[ib]][ist:ist+2*l+1] = True
+
+ ist += 2*l+1
+
+ assert (self.mask_to_basis.sum(dim=1).int()-self.atom_norb).abs().sum() <= 1e-6
+
+ self.get_orbpair_maps()
+ # the mask to map the full basis reduced matrix element to the original basis reduced matrix element
+ self.mask_to_erme = torch.zeros(len(self.bond_types), self.reduced_matrix_element, dtype=torch.bool, device=self.device)
+ self.mask_to_nrme = torch.zeros(len(self.type_names), self.reduced_matrix_element, dtype=torch.bool, device=self.device)
+ for ib, bb in self.basis.items():
+ for io in bb:
+ iof = self.basis_to_full_basis[ib][io]
+ for jo in bb:
+ jof = self.basis_to_full_basis[ib][jo]
+ if self.orbpair_maps.get(iof+"-"+jof) is not None:
+ self.mask_to_nrme[self.chemical_symbol_to_type[ib]][self.orbpair_maps[iof+"-"+jof]] = True
+
+ for ib in self.bond_to_type.keys():
+ a,b = ib.split("-")
+ for io in self.basis[a]:
+ iof = self.basis_to_full_basis[a][io]
+ for jo in self.basis[b]:
+ jof = self.basis_to_full_basis[b][jo]
+ if self.orbpair_maps.get(iof+"-"+jof) is not None:
+ self.mask_to_erme[self.bond_to_type[ib]][self.orbpair_maps[iof+"-"+jof]] = True
+ elif self.orbpair_maps.get(jof+"-"+iof) is not None:
+ self.mask_to_erme[self.bond_to_type[ib]][self.orbpair_maps[jof+"-"+iof]] = True
+
+ def get_orbpairtype_maps(self):
+ """
+ The function `get_orbpairtype_maps` creates a mapping of orbital pair types, such as s-s, "s-p",
+ to slices based on the number of hops between them.
+ :return: a dictionary called `pairtype_map`.
+ """
+
+ self.orbpairtype_maps = {}
+ ist = 0
+ for i, io in enumerate(["s", "p", "d", "f"]):
+ if self.orbtype_count[io] != 0:
+ for jo in ["s", "p", "d", "f"][i:]:
+ if self.orbtype_count[jo] != 0:
+ orb_pair = io+"-"+jo
+ il, jl = anglrMId[io], anglrMId[jo]
+ if self.method == "e3tb":
+ n_rme = (2*il+1) * (2*jl+1)
+ else:
+ n_rme = min(il, jl)+1
+ numhops = self.orbtype_count[io] * self.orbtype_count[jo] * n_rme
+ if io == jo:
+ numhops += self.orbtype_count[jo] * n_rme
+ numhops = int(numhops / 2)
+ self.orbpairtype_maps[orb_pair] = slice(ist, ist+numhops)
+
+ ist += numhops
+
+ return self.orbpairtype_maps
+
+ def get_orbpair_maps(self):
+
+ if hasattr(self, "orbpair_maps"):
+ return self.orbpair_maps
+
+ if not hasattr(self, "orbpairtype_maps"):
+ self.get_orbpairtype_maps()
+
+ self.orbpair_maps = {}
+ for i, io in enumerate(self.full_basis):
+ for jo in self.full_basis[i:]:
+ full_basis_pair = io+"-"+jo
+ ir, jr = int(full_basis_pair[0]), int(full_basis_pair[3])
+ iio, jjo = full_basis_pair[1], full_basis_pair[4]
+ il, jl = anglrMId[iio], anglrMId[jjo]
+
+ if self.method == "e3tb":
+ n_feature = (2*il+1) * (2*jl+1)
+ else:
+ n_feature = min(il, jl)+1
+ if iio == jjo:
+ start = self.orbpairtype_maps[iio+"-"+jjo].start + \
+ n_feature * ((2*self.orbtype_count[jjo]+2-ir) * (ir-1) / 2 + (jr - ir))
+ else:
+ start = self.orbpairtype_maps[iio+"-"+jjo].start + \
+ n_feature * ((ir-1)*self.orbtype_count[jjo]+(jr-1))
+ start = int(start)
+ self.orbpair_maps[io+"-"+jo] = slice(start, start+n_feature)
+
+ return self.orbpair_maps
+
+ def get_skonsite_maps(self):
+
+ assert self.method == "sktb", "Only sktb orbitalmapper have skonsite maps."
+
+ if hasattr(self, "skonsite_maps"):
+ return self.skonsite_maps
+
+ if not hasattr(self, "skonsitetype_maps"):
+ self.get_skonsitetype_maps()
+
+ self.skonsite_maps = {}
+ for i, io in enumerate(self.full_basis):
+ ir= int(io[0])
+ iio = io[1]
+
+ start = int(self.skonsitetype_maps[iio].start + (ir-1))
+ self.skonsite_maps[io] = slice(start, start+1)
+
+ return self.skonsite_maps
+
+ def get_skonsitetype_maps(self):
+ self.skonsitetype_maps = {}
+ ist = 0
+
+ assert self.method == "sktb", "Only sktb orbitalmapper have skonsite maps."
+ for i, io in enumerate(["s", "p", "d", "f"]):
+ if self.orbtype_count[io] != 0:
+ il = anglrMId[io]
+ numonsites = self.orbtype_count[io]
+
+ self.skonsitetype_maps[io] = slice(ist, ist+numonsites)
+
+ ist += numonsites
+
+ return self.skonsitetype_maps
+
+ # also need to think if we modify as this, how can we add extra basis when fitting.
+
+ def get_orbital_maps(self):
+ # simply get a 1-d slice for each atom species.
+
+ self.orbital_maps = {}
+ self.norbs = {}
+
+ for ib in self.basis.keys():
+ orbital_list = self.basis[ib]
+ slices = {}
+ start_index = 0
+
+ self.norbs.setdefault(ib, 0)
+ for orb in orbital_list:
+ orb_l = re.findall(r'[A-Za-z]', orb)[0]
+ increment = (2*anglrMId[orb_l]+1)
+ self.norbs[ib] += increment
+ end_index = start_index + increment
+
+ slices[orb] = slice(start_index, end_index)
+ start_index = end_index
+
+ self.orbital_maps[ib] = slices
+
+ return self.orbital_maps
+
+ def get_irreps(self, no_parity=False):
+ assert self.method == "e3tb", "Only support e3tb method for now."
+
+ if hasattr(self, "orbpair_irreps"):
+ if self.no_parity == no_parity:
+ return self.orbpair_irreps
+
+ self.no_parity = no_parity
+
+ if not hasattr(self, "orbpairtype_maps"):
+ self.get_orbpairtype_maps()
+
+ irs = []
+ if no_parity:
+ factor = 1
+ else:
+ factor = -1
+
+ irs = []
+ for pair, sli in self.orbpairtype_maps.items():
+ l1, l2 = anglrMId[pair[0]], anglrMId[pair[2]]
+ p = factor**(l1+l2)
+ required_ls = range(abs(l1 - l2), l1 + l2 + 1)
+ required_irreps = [(1,(l, p)) for l in required_ls]
+ irs += required_irreps*int((sli.stop-sli.start)/(2*l1+1)/(2*l2+1))
+
+ self.orbpair_irreps = o3.Irreps(irs)
+ return self.orbpair_irreps
+
+ def __eq__(self, other):
+ return self.basis == other.basis and self.method == other.method
\ No newline at end of file
diff --git a/dptb/data/use_data.ipynb b/dptb/data/use_data.ipynb
new file mode 100644
index 00000000..1d9c6584
--- /dev/null
+++ b/dptb/data/use_data.ipynb
@@ -0,0 +1,499 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from build import dataset_from_config\n",
+ "from dptb.utils.config import Config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = {\n",
+ " \"root\": \"/root/nequip_data/\",\n",
+ " \"dataset\": \"npz\",\n",
+ " \"dataset_file_name\": \"/root/nequip_data/Si8-100K.npz\",\n",
+ " \"key_mapping\":{\n",
+ " \"pos\":\"pos\",\n",
+ " \"atomic_numbers\":\"atomic_numbers\",\n",
+ " \"kpoints\": \"kpoint\",\n",
+ " \"pbc\": \"pbc\",\n",
+ " \"cell\": \"cell\",\n",
+ " \"eigenvalues\": \"eigenvalue\"\n",
+ " },\n",
+ " \"npz_fixed_field_keys\": [\"kpoint\", \"pbc\"],\n",
+ " \"graph_field\":[\"eigenvalues\"],\n",
+ " \"chemical_symbols\": [\"Si\", \"C\"],\n",
+ " \"r_max\": 6.0\n",
+ "}\n",
+ "\n",
+ "config = Config(config=config)\n",
+ "# dataset: npz # type of data set, can be npz or ase\n",
+ "# dataset_url: http://quantum-machine.org/gdml/data/npz/toluene_ccsd_t.zip # url to download the npz. optional\n",
+ "# dataset_file_name: ./benchmark_data/toluene_ccsd_t-train.npz # path to data set file\n",
+ "# key_mapping:\n",
+ "# z: atomic_numbers # atomic species, integers\n",
+ "# E: total_energy # total potential eneriges to train to\n",
+ "# F: forces # atomic forces to train to\n",
+ "# R: pos # raw atomic positions\n",
+ "# npz_fixed_field_keys: # fields that are repeated across different examples\n",
+ "# - atomic_numbers\n",
+ "\n",
+ "# chemical_symbols:\n",
+ "# - H\n",
+ "# - C"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Processing dataset...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = dataset_from_config(config=config, prefix=\"dataset\")\n",
+ "\n",
+ "from dptb.data.dataloader import DataLoader\n",
+ "\n",
+ "dl = DataLoader(dataset, 3)\n",
+ "\n",
+ "data = next(iter(dl))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 1., 1., -1.],\n",
+ " [ 1., 1., 1.],\n",
+ " [ 0., 1., -1.],\n",
+ " [ 0., 1., 1.],\n",
+ " [ 1., 0., -1.],\n",
+ " [ 0., 0., -1.],\n",
+ " [ 1., 0., 1.],\n",
+ " [ 0., 0., 1.],\n",
+ " [ 0., 1., 0.],\n",
+ " [ 1., 0., 0.],\n",
+ " [ 0., 0., 0.],\n",
+ " [ 1., 1., 0.]])"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\n",
+ "dataset[0].edge_cell_shift[dataset[0].edge_index[0].eq(1)&dataset[0].edge_index[1].eq(2)], dataset[0].edge_cell_shift[dataset[0].edge_index[0].eq(1)&dataset[0].edge_index[1].eq(2)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([False, False, False, False, False, False, False, False, False, False,\n",
+ " True, False, False, False, False, False, False, True, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, True, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, True, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, True,\n",
+ " False, False, False, False, False, False, False, True, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, True, False, True,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " True, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, True, False, False,\n",
+ " False, True, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, True, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, True, False, False, False, True, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, True,\n",
+ " False, False, False, False, False, False, False, False, False, True,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " True, False, False, True, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, True, False, False,\n",
+ " False, False, False, True, False, False, False, False, False, True,\n",
+ " False, True, True, True])"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset[0].edge_index[0].eq(dataset[0].edge_index[1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'C-C': 0, 'C-Si': 1, 'Si-C': 2, 'Si-Si': 3}"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset.type_mapper.bond_to_type"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dptb.nn._sktb import SKTB\n",
+ "sktb = SKTB(\n",
+ " basis={\"Si\":[\"3s\", \"3p\", \"p*\", \"s*\"], \"C\":[\"2s\",\"2p\"]},\n",
+ " onsite=\"uniform\",\n",
+ " hopping=\"powerlaw\",\n",
+ " overlap=True\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dptb.data.AtomicDataDict import with_edge_vectors, with_onsitenv_vectors\n",
+ "\n",
+ "data = with_edge_vectors(data.to_dict())\n",
+ "data = with_onsitenv_vectors(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "data[\"atomic_numbers\"] = dataset.type_mapper.untransform(data[\"atom_types\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = sktb(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "20"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sktb.idp.edge_reduced_matrix_element"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([24, 4])"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data[\"node_features\"].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dptb.nn._hamiltonian import SKHamiltonian\n",
+ "\n",
+ "skh = SKHamiltonian(basis={\"Si\":[\"3s\", \"3p\", \"p*\", \"s*\"], \"C\":[\"2s\",\"2p\"]})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = skh(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([24, 42])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data[\"node_features\"].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dptb.nn._hamiltonian import E3Hamiltonian\n",
+ "e3h = E3Hamiltonian(basis={\"Si\":[\"3s\", \"3p\", \"p*\", \"s*\"], \"C\":[\"2s\",\"2p\"]}, decompose=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = e3h(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ True, True, True, True, False, True, False, False, True, False,\n",
+ " False, True, False, False, True, False, False, True, False, False,\n",
+ " True, False, False, True, False, False, True, False, True, False,\n",
+ " False, False, False, False, True, False, False, True, False, False,\n",
+ " False, False, False, True, False, False, True, False, False, False,\n",
+ " False, False, True, False, False, True, False, False, False, False,\n",
+ " False, True, False, False])"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data[\"edge_features\"][0].abs().gt(1e-5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dptb.data.AtomicData import AtomicData\n",
+ "from dptb.utils.torch_geometric import Batch\n",
+ "\n",
+ "bdata = Batch.from_dict(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "RuntimeError",
+ "evalue": "Cannot reconstruct data list from batch because the batch object was not created using `Batch.from_data_list()`.",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[1;32m/root/deeptb/dptb/data/use_data.ipynb Cell 13\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m bdata\u001b[39m.\u001b[39;49mget_example(\u001b[39m0\u001b[39;49m)\n",
+ "File \u001b[0;32m/opt/miniconda/envs/deeptb/lib/python3.8/site-packages/dptb/utils/torch_geometric/batch.py:176\u001b[0m, in \u001b[0;36mBatch.get_example\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"Reconstructs the :class:`torch_geometric.data.Data` object at index\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[39m:obj:`idx` from the batch object.\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[39mThe batch object must have been created via :meth:`from_data_list` in\u001b[39;00m\n\u001b[1;32m 173\u001b[0m \u001b[39morder to be able to reconstruct the initial objects.\"\"\"\u001b[39;00m\n\u001b[1;32m 175\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__slices__ \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 176\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m 177\u001b[0m (\n\u001b[1;32m 178\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mCannot reconstruct data list from batch because the batch \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 179\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mobject was not created using `Batch.from_data_list()`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 180\u001b[0m )\n\u001b[1;32m 181\u001b[0m )\n\u001b[1;32m 183\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__data_class__()\n\u001b[1;32m 184\u001b[0m idx \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_graphs \u001b[39m+\u001b[39m idx \u001b[39mif\u001b[39;00m idx \u001b[39m<\u001b[39m \u001b[39m0\u001b[39m \u001b[39melse\u001b[39;00m idx\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: Cannot reconstruct data list from batch because the batch object was not created using `Batch.from_data_list()`."
+ ]
+ }
+ ],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dptb.data.transforms import OrbitalMapper\n",
+ "\n",
+ "idp = OrbitalMapper(basis={\"Si\": \"2s2p1d\", \"C\":\"1s1p1d\"})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'1s-1s': slice(0, 1, None),\n",
+ " '1s-2s': slice(1, 2, None),\n",
+ " '1s-1p': slice(3, 6, None),\n",
+ " '1s-2p': slice(6, 9, None),\n",
+ " '1s-1d': slice(15, 20, None),\n",
+ " '2s-2s': slice(2, 3, None),\n",
+ " '2s-1p': slice(9, 12, None),\n",
+ " '2s-2p': slice(12, 15, None),\n",
+ " '2s-1d': slice(20, 25, None),\n",
+ " '1p-1p': slice(25, 34, None),\n",
+ " '1p-2p': slice(34, 43, None),\n",
+ " '1p-1d': slice(52, 67, None),\n",
+ " '2p-2p': slice(43, 52, None),\n",
+ " '2p-1d': slice(67, 82, None),\n",
+ " '1d-1d': slice(82, 107, None)}"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "idp.get_node_maps()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'1s-1s': slice(0, 1, None),\n",
+ " '1s-2s': slice(1, 2, None),\n",
+ " '1s-1p': slice(3, 6, None),\n",
+ " '1s-2p': slice(6, 9, None),\n",
+ " '1s-1d': slice(15, 20, None),\n",
+ " '2s-2s': slice(2, 3, None),\n",
+ " '2s-1p': slice(9, 12, None),\n",
+ " '2s-2p': slice(12, 15, None),\n",
+ " '2s-1d': slice(20, 25, None),\n",
+ " '1p-1p': slice(25, 34, None),\n",
+ " '1p-2p': slice(34, 43, None),\n",
+ " '1p-1d': slice(52, 67, None),\n",
+ " '2p-2p': slice(43, 52, None),\n",
+ " '2p-1d': slice(67, 82, None),\n",
+ " '1d-1d': slice(82, 107, None)}"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "idp.node_maps"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "deeptb",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dptb/data/util.py b/dptb/data/util.py
new file mode 100644
index 00000000..494c2f8f
--- /dev/null
+++ b/dptb/data/util.py
@@ -0,0 +1,4 @@
+import torch
+
+# There is no built-in way to check if a Tensor is of an integer type
+_TORCH_INTEGER_DTYPES = (torch.int, torch.long)
diff --git a/dptb/entrypoints/__init__.py b/dptb/entrypoints/__init__.py
index 4e3c5d99..b921c542 100644
--- a/dptb/entrypoints/__init__.py
+++ b/dptb/entrypoints/__init__.py
@@ -1,5 +1,5 @@
from dptb.entrypoints.train import train
from dptb.entrypoints.config import config
-from dptb.entrypoints.run import run
+# from dptb.entrypoints.run import run
from dptb.entrypoints.test import _test as test
from dptb.entrypoints.bond import bond
\ No newline at end of file
diff --git a/dptb/entrypoints/data.py b/dptb/entrypoints/data.py
new file mode 100644
index 00000000..d19f25fb
--- /dev/null
+++ b/dptb/entrypoints/data.py
@@ -0,0 +1,39 @@
+import os
+from typing import Dict, List, Optional, Any
+from dptb.utils.tools import j_loader
+from dptb.utils.argcheck import normalize
+from dptb.data.interfaces.abacus import recursive_parse
+
+def data(
+ INPUT: str,
+ log_level: int,
+ log_path: Optional[str],
+ **kwargs
+):
+ jdata = j_loader(INPUT)
+
+ # ABACUS parsing input like:
+ # { "type": "ABACUS",
+ # "parse_arguments": {
+ # "input_path": "alice_*/*_bob/system_No_*",
+ # "preprocess_dir": "charlie/david",
+ # "only_overlap": false,
+ # "get_Hamiltonian": true,
+ # "add_overlap": true,
+ # "get_eigenvalues": true } }
+ if jdata["type"] == "ABACUS":
+ abacus_args = jdata["parse_arguments"]
+ assert abacus_args.get("input_path") is not None, "ABACUS calculation results MUST be provided."
+ assert abacus_args.get("preprocess_dir") is not None, "Please assign a dictionary to store preprocess files."
+
+ print("Begin parsing ABACUS output...")
+ recursive_parse(**abacus_args)
+ print("Finished parsing ABACUS output.")
+
+ ## write all h5 files to be used in building AtomicData
+ #with open(os.path.join(abacus_args["preprocess_dir"], "AtomicData_file.txt"), "w") as f:
+ # for filename in h5_filenames:
+ # f.write(filename + "\n")
+
+ else:
+ raise Exception("Not supported software output.")
\ No newline at end of file
diff --git a/dptb/entrypoints/main.py b/dptb/entrypoints/main.py
index 226bf685..20def5a4 100644
--- a/dptb/entrypoints/main.py
+++ b/dptb/entrypoints/main.py
@@ -8,6 +8,7 @@
from dptb.entrypoints.run import run
from dptb.entrypoints.bond import bond
from dptb.entrypoints.nrl2json import nrl2json
+from dptb.entrypoints.data import data
from dptb.utils.loggers import set_log_handles
def get_ll(log_level: str) -> int:
@@ -168,13 +169,6 @@ def main_parser() -> argparse.ArgumentParser:
help="Restart the training from the provided checkpoint.",
)
- parser_train.add_argument(
- "-sk",
- "--train-sk",
- action="store_true",
- help="Trainging NNSKTB parameters.",
- )
-
parser_train.add_argument(
"-crt",
"--use-correction",
@@ -190,13 +184,6 @@ def main_parser() -> argparse.ArgumentParser:
help="Initialize the training from the frozen model.",
)
- parser_train.add_argument(
- "-f",
- "--freeze",
- action="store_true",
- help="Initialize the training from the frozen model.",
- )
-
parser_train.add_argument(
"-o",
"--output",
@@ -226,13 +213,6 @@ def main_parser() -> argparse.ArgumentParser:
help="Initialize the model by the provided checkpoint.",
)
- parser_test.add_argument(
- "-sk",
- "--test-sk",
- action="store_true",
- help="Test NNSKTB parameters.",
- )
-
parser_test.add_argument(
"-crt",
"--use-correction",
@@ -287,13 +267,6 @@ def main_parser() -> argparse.ArgumentParser:
help="The output files in postprocess run."
)
- parser_run.add_argument(
- "-sk",
- "--run_sk",
- action="store_true",
- help="using NNSKTB parameters TB models for post-run."
- )
-
parser_run.add_argument(
"-crt",
"--use-correction",
@@ -302,6 +275,20 @@ def main_parser() -> argparse.ArgumentParser:
help="Use nnsktb correction when training dptb",
)
+ # preprocess data
+ parser_data = subparsers.add_parser(
+ "data",
+ parents=[parser_log],
+ help="preprocess software output",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser_data.add_argument(
+ "INPUT", help="the input parameter file in json or yaml format",
+ type=str,
+ default=None
+ )
+
return parser
def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
@@ -352,3 +339,5 @@ def main():
elif args.command == 'n2j':
nrl2json(**dict_args)
+ elif args.command == 'data':
+ data(**dict_args)
diff --git a/dptb/entrypoints/run.py b/dptb/entrypoints/run.py
index afc7ccf5..f09d3ee2 100644
--- a/dptb/entrypoints/run.py
+++ b/dptb/entrypoints/run.py
@@ -1,232 +1,156 @@
-import logging
-import json
-import os
-import struct
-import time
-import torch
-from pathlib import Path
-from typing import Dict, List, Optional, Any
-from dptb.plugins.train_logger import Logger
-from dptb.plugins.init_nnsk import InitSKModel
-from dptb.plugins.init_dptb import InitDPTBModel
-from dptb.utils.argcheck import normalize, normalize_run
-from dptb.utils.tools import j_loader
-from dptb.utils.loggers import set_log_handles
-from dptb.utils.tools import j_must_have
-from dptb.utils.constants import dtype_dict
-from dptb.nnops.apihost import NNSKHost, DPTBHost
-from dptb.nnops.NN2HRK import NN2HRK
-from ase.io import read,write
-from dptb.postprocess.bandstructure.band import bandcalc
-from dptb.postprocess.bandstructure.dos import doscalc, pdoscalc
-from dptb.postprocess.bandstructure.fermisurface import fs2dcalc, fs3dcalc
-from dptb.postprocess.bandstructure.ifermi_api import ifermiapi, ifermi_installed, pymatgen_installed
-from dptb.postprocess.write_skparam import WriteNNSKParam
-from dptb.postprocess.NEGF import NEGF
-from dptb.postprocess.tbtrans_init import TBTransInputSet,sisl_installed
-
-__all__ = ["run"]
-
-log = logging.getLogger(__name__)
-
-def run(
- INPUT: str,
- init_model: str,
- output: str,
- run_sk: bool,
- structure: str,
- log_level: int,
- log_path: Optional[str],
- use_correction: Optional[str],
- **kwargs
- ):
+# import logging
+# import json
+# import os
+# import time
+# import torch
+# from pathlib import Path
+# from typing import Optional
+# from dptb.plugins.train_logger import Logger
+# from dptb.utils.argcheck import normalize_run
+# from dptb.utils.tools import j_loader
+# from dptb.utils.loggers import set_log_handles
+# from dptb.utils.tools import j_must_have
+# from dptb.nn.build import build_model
+# from dptb.postprocess.bandstructure.band import bandcalc
+# from dptb.postprocess.bandstructure.dos import doscalc, pdoscalc
+# from dptb.postprocess.bandstructure.fermisurface import fs2dcalc, fs3dcalc
+# from dptb.postprocess.bandstructure.ifermi_api import ifermiapi, ifermi_installed, pymatgen_installed
+# from dptb.postprocess.write_skparam import WriteNNSKParam
+# from dptb.postprocess.NEGF import NEGF
+# from dptb.postprocess.tbtrans_init import TBTransInputSet,sisl_installed
+
+# __all__ = ["run"]
+
+# log = logging.getLogger(__name__)
+
+# def run(
+# INPUT: str,
+# init_model: str,
+# output: str,
+# structure: str,
+# log_level: int,
+# log_path: Optional[str],
+# **kwargs
+# ):
- run_opt = {
- "run_sk": run_sk,
- "init_model":init_model,
- "structure":structure,
- "log_path": log_path,
- "log_level": log_level,
- "use_correction":use_correction
- }
-
- if all((use_correction, run_sk)):
- log.error(msg="--use-correction and --train_sk should not be set at the same time")
- raise RuntimeError
+# run_opt = {
+# "init_model":init_model,
+# "structure":structure,
+# "log_path": log_path,
+# "log_level": log_level,
+# }
+
+# # output folder.
+# if output:
+
+# Path(output).parent.mkdir(exist_ok=True, parents=True)
+# Path(output).mkdir(exist_ok=True, parents=True)
+# results_path = os.path.join(str(output), "results")
+# Path(results_path).mkdir(exist_ok=True, parents=True)
+# if not log_path:
+# log_path = os.path.join(str(output), "log/log.txt")
+# Path(log_path).parent.mkdir(exist_ok=True, parents=True)
+
+# run_opt.update({
+# "output": str(Path(output).absolute()),
+# "results_path": str(Path(results_path).absolute()),
+# "log_path": str(Path(log_path).absolute())
+# })
- jdata = j_loader(INPUT)
- jdata = normalize_run(jdata)
+# jdata = j_loader(INPUT)
+# jdata = normalize_run(jdata)
- if all((jdata["init_model"]["path"], run_opt["init_model"])):
- raise RuntimeError(
- "init-model in config and command line is in conflict, turn off one of then to avoid this error !"
- )
-
- if run_opt["init_model"] is None:
- log.info(msg="model_ckpt is not set in command line, read from input config file.")
-
- if run_sk:
- if jdata["init_model"]["path"] is not None:
- run_opt["init_model"] = jdata["init_model"]
- else:
- log.error(msg="Error! init_model is not set in config file and command line.")
- raise RuntimeError
- if isinstance(run_opt["init_model"]["path"], list):
- if len(run_opt["init_model"]["path"])==0:
- log.error("Error! list mode init_model in config file cannot be empty!")
- raise RuntimeError
- else:
- if jdata["init_model"]["path"] is not None:
- run_opt["init_model"] = jdata["init_model"]
- else:
- log.error(msg="Error! init_model is not set in config file and command line.")
- raise RuntimeError
- if isinstance(run_opt["init_model"]["path"], list):
- raise RuntimeError(
- "loading lists of checkpoints is only supported in init_nnsk!"
- )
- if isinstance(run_opt["init_model"]["path"], list):
- if len(run_opt["init_model"]["path"]) == 1:
- run_opt["init_model"]["path"] = run_opt["init_model"]["path"][0]
- else:
- path = run_opt["init_model"]
- run_opt["init_model"] = jdata["init_model"]
- run_opt["init_model"]["path"] = path
+# set_log_handles(log_level, Path(log_path) if log_path else None)
+
+# f = torch.load(run_opt["init_model"])
+# jdata["model_options"] = f["config"]["model_options"]
+# del f
- task_options = j_must_have(jdata, "task_options")
- task = task_options["task"]
- use_gui = jdata.get("use_gui", False)
- task_options.update({"use_gui": use_gui})
-
- model_ckpt = run_opt["init_model"]["path"]
- # init_type = model_ckpt.split(".")[-1]
- # if init_type not in ["json", "pth"]:
- # log.error(msg="Error! the model file should be a json or pth file.")
- # raise RuntimeError
-
- # if init_type == "json":
- # jdata = host_normalize(jdata)
- # if run_sk:
- # jdata.update({"init_model": {"path": model_ckpt,"interpolate": False}})
- # else:
- # jdata.update({"init_model": model_ckpt})
-
- if run_opt['structure'] is None:
- log.warning(msg="Warning! structure is not set in run option, read from input config file.")
- structure = j_must_have(jdata, "structure")
- run_opt.update({"structure":structure})
-
- print(run_opt["structure"])
-
- if not run_sk:
- if run_opt['use_correction'] is None and jdata.get('use_correction',None) != None:
- use_correction = jdata['use_correction']
- run_opt.update({"use_correction":use_correction})
- log.info(msg="use_correction is not set in run option, read from input config file.")
-
- # output folder.
- if output:
-
- Path(output).parent.mkdir(exist_ok=True, parents=True)
- Path(output).mkdir(exist_ok=True, parents=True)
- results_path = os.path.join(str(output), "results")
- Path(results_path).mkdir(exist_ok=True, parents=True)
- if not log_path:
- log_path = os.path.join(str(output), "log/log.txt")
- Path(log_path).parent.mkdir(exist_ok=True, parents=True)
-
- run_opt.update({
- "output": str(Path(output).absolute()),
- "results_path": str(Path(results_path).absolute()),
- "log_path": str(Path(log_path).absolute())
- })
-
- set_log_handles(log_level, Path(log_path) if log_path else None)
-
- if jdata.get("common_options", None):
- # in this case jdata must have common options
- str_dtype = jdata["common_options"]["dtype"]
- # jdata["common_options"]["dtype"] = dtype_dict[jdata["common_options"]["dtype"]]
-
- if run_sk:
- apihost = NNSKHost(checkpoint=model_ckpt, config=jdata)
- apihost.register_plugin(InitSKModel())
- apihost.build()
- apiHrk = NN2HRK(apihost=apihost, mode='nnsk')
- else:
- apihost = DPTBHost(dptbmodel=model_ckpt,use_correction=use_correction)
- apihost.register_plugin(InitDPTBModel())
- apihost.build()
- apiHrk = NN2HRK(apihost=apihost, mode='dptb')
+# task_options = j_must_have(jdata, "task_options")
+# task = task_options["task"]
+# use_gui = jdata.get("use_gui", False)
+# task_options.update({"use_gui": use_gui})
+
+# if run_opt['structure'] is None:
+# log.warning(msg="Warning! structure is not set in run option, read from input config file.")
+# structure = j_must_have(jdata, "structure")
+# run_opt.update({"structure":structure})
+# else:
+# structure = run_opt["structure"]
+
+# print(run_opt["structure"])
+
+# if jdata.get("common_options", None):
+# # in this case jdata must have common options
+# str_dtype = jdata["common_options"]["dtype"]
+# # jdata["common_options"]["dtype"] = dtype_dict[jdata["common_options"]["dtype"]]
+
+# model = build_model(run_options=run_opt, model_options=jdata["model_options"], common_options=jdata["common_options"])
-
- # one can just add his own function to calculate properties by add a task, and its code to calculate.
-
- if task=='band':
- # TODO: add argcheck for bandstructure, with different options. see, kline_mode: ase, vasp, abacus, etc.
- bcal = bandcalc(apiHrk, run_opt, task_options)
- bcal.get_bands()
- bcal.band_plot()
- log.info(msg='band calculation successfully completed.')
-
- if task=='dos':
- bcal = doscalc(apiHrk, run_opt, task_options)
- bcal.get_dos()
- bcal.dos_plot()
- log.info(msg='dos calculation successfully completed.')
-
- if task=='pdos':
- bcal = pdoscalc(apiHrk, run_opt, task_options)
- bcal.get_pdos()
- bcal.pdos_plot()
- log.info(msg='pdos calculation successfully completed.')
+# # one can just add his own function to calculate properties by add a task, and its code to calculate.
+# if task=='band':
+# # TODO: add argcheck for bandstructure, with different options. see, kline_mode: ase, vasp, abacus, etc.
+# bcal = bandcalc(model, structure, task_options)
+# bcal.get_bands()
+# bcal.band_plot()
+# log.info(msg='band calculation successfully completed.')
+
+# if task=='dos':
+# bcal = doscalc(model, structure, task_options)
+# bcal.get_dos()
+# bcal.dos_plot()
+# log.info(msg='dos calculation successfully completed.')
+
+# if task=='pdos':
+# bcal = pdoscalc(model, structure, task_options)
+# bcal.get_pdos()
+# bcal.pdos_plot()
+# log.info(msg='pdos calculation successfully completed.')
- if task=='FS2D':
- fs2dcal = fs2dcalc(apiHrk, run_opt, task_options)
- fs2dcal.get_fs()
- fs2dcal.fs2d_plot()
- log.info(msg='2dFS calculation successfully completed.')
+# if task=='FS2D':
+# fs2dcal = fs2dcalc(model, structure, task_options)
+# fs2dcal.get_fs()
+# fs2dcal.fs2d_plot()
+# log.info(msg='2dFS calculation successfully completed.')
- if task == 'FS3D':
- fs3dcal = fs3dcalc(apiHrk, run_opt, task_options)
- fs3dcal.get_fs()
- fs3dcal.fs_plot()
- log.info(msg='3dFS calculation successfully completed.')
+# if task == 'FS3D':
+# fs3dcal = fs3dcalc(model, structure, task_options)
+# fs3dcal.get_fs()
+# fs3dcal.fs_plot()
+# log.info(msg='3dFS calculation successfully completed.')
- if task == 'ifermi':
- if not(ifermi_installed and pymatgen_installed):
- log.error(msg="ifermi and pymatgen are required to perform ifermi calculation !")
- raise RuntimeError
-
- ifermi = ifermiapi(apiHrk, run_opt, task_options)
- bs = ifermi.get_band_structure()
- fs = ifermi.get_fs(bs)
- ifermi.fs_plot(fs)
- log.info(msg='Ifermi calculation successfully completed.')
- if task == 'write_sk':
- if not run_sk:
- raise RuntimeError("write_sk can only perform on nnsk model !")
- write_sk = WriteNNSKParam(apiHrk, run_opt, task_options)
- write_sk.write()
- log.info(msg='write_sk calculation successfully completed.')
-
- if task == 'negf':
- negf = NEGF(apiHrk, run_opt, task_options)
- negf.compute()
- log.info(msg='NEGF calculation successfully completed.')
-
- if task == 'tbtrans_negf':
- if not(sisl_installed):
- log.error(msg="sisl is required to perform tbtrans calculation !")
- raise RuntimeError
-
- tbtrans_init = TBTransInputSet(apiHrk, run_opt, task_options)
- tbtrans_init.hamil_get_write(write_nc=True)
- log.info(msg='TBtrans input files are successfully generated.')
-
- if output:
- with open(os.path.join(output, "run_config.json"), "w") as fp:
- if jdata.get("common_options", None):
- jdata["common_options"]["dtype"] = str_dtype
- json.dump(jdata, fp, indent=4)
+# if task == 'ifermi':
+# if not(ifermi_installed and pymatgen_installed):
+# log.error(msg="ifermi and pymatgen are required to perform ifermi calculation !")
+# raise RuntimeError
+
+# ifermi = ifermiapi(model, structure, task_options)
+# bs = ifermi.get_band_structure()
+# fs = ifermi.get_fs(bs)
+# ifermi.fs_plot(fs)
+# log.info(msg='Ifermi calculation successfully completed.')
+# if task == 'write_sk':
+# if not jdata["model_options"].keys()[0] == "nnsk" or len(jdata["model_options"].keys()) > 1:
+# raise RuntimeError("write_sk can only perform on nnsk model !")
+# write_sk = WriteNNSKParam(model, structure, task_options)
+# write_sk.write()
+# log.info(msg='write_sk calculation successfully completed.')
+
+# if task == 'negf':
+# negf = NEGF(model, structure, task_options)
+# negf.compute()
+# log.info(msg='NEGF calculation successfully completed.')
+
+# if task == 'tbtrans_negf':
+# if not(sisl_installed):
+# log.error(msg="sisl is required to perform tbtrans calculation !")
+# raise RuntimeError
+
+# tbtrans_init = TBTransInputSet(apiHrk, run_opt, task_options)
+# tbtrans_init.hamil_get_write(write_nc=True)
+# log.info(msg='TBtrans input files are successfully generated.')
+
+# if output:
+# with open(os.path.join(output, "run_config.json"), "w") as fp:
+# json.dump(jdata, fp, indent=4)
diff --git a/dptb/entrypoints/test.py b/dptb/entrypoints/test.py
index 7380ef8c..9874f7b3 100644
--- a/dptb/entrypoints/test.py
+++ b/dptb/entrypoints/test.py
@@ -1,21 +1,16 @@
import heapq
import logging
import torch
-import random
import json
import os
import time
-import numpy as np
from pathlib import Path
-from dptb.nnops.tester_dptb import DPTBTester
-from dptb.nnops.tester_nnsk import NNSKTester
-from typing import Dict, List, Optional, Any
+from dptb.nn.build import build_model
+from dptb.data.build import build_dataset
+from typing import Optional
from dptb.utils.loggers import set_log_handles
from dptb.utils.tools import j_loader, setup_seed
-from dptb.utils.constants import dtype_dict
-from dptb.plugins.init_nnsk import InitSKModel
-from dptb.plugins.init_dptb import InitDPTBModel
-from dptb.plugins.init_data import InitTestData
+from dptb.nnops.tester import Tester
from dptb.utils.argcheck import normalize_test
from dptb.plugins.monitor import TestLossMonitor
from dptb.plugins.train_logger import Logger
@@ -30,7 +25,6 @@ def _test(
output: str,
log_level: int,
log_path: Optional[str],
- test_sk: bool,
use_correction: Optional[str],
**kwargs
):
@@ -39,98 +33,11 @@ def _test(
"init_model": init_model,
"log_path": log_path,
"log_level": log_level,
- "test_sk": test_sk,
"use_correction": use_correction,
"freeze":True,
"train_soc":False
}
- if all((use_correction, test_sk)):
- log.error(msg="--use-correction and --train_sk should not be set at the same time")
- raise RuntimeError
-
- # setup INPUT path
- if test_sk:
- if init_model:
- skconfig_path = os.path.join(str(Path(init_model).parent.absolute()), "config_nnsktb.json")
- mode = "init_model"
- else:
- log.error("ValueError: Missing init_model file path.")
- raise ValueError
- jdata = j_loader(INPUT)
- jdata = normalize_test(jdata)
-
- if all((jdata["init_model"]["path"], run_opt["init_model"])):
- raise RuntimeError(
- "init-model in config and command line is in conflict, turn off one of then to avoid this error !"
- )
- else:
- if jdata["init_model"]["path"] is not None:
- assert mode == "from_scratch"
- run_opt["init_model"] = jdata["init_model"]
- mode = "init_model"
- if isinstance(run_opt["init_model"]["path"], str):
- skconfig_path = os.path.join(str(Path(run_opt["init_model"]["path"]).parent.absolute()), "config_nnsktb.json")
- else: # list
- skconfig_path = [os.path.join(str(Path(path).parent.absolute()), "config_nnsktb.json") for path in run_opt["init_model"]["path"]]
- else:
- if run_opt["init_model"] is not None:
- assert mode == "init_model"
- path = run_opt["init_model"]
- run_opt["init_model"] = jdata["init_model"]
- run_opt["init_model"]["path"] = path
- else:
- if init_model:
- dptbconfig_path = os.path.join(str(Path(init_model).parent.absolute()), "config_dptbtb.json")
- mode = "init_model"
- else:
- log.error("ValueError: Missing init_model file path.")
- raise ValueError
-
- if use_correction:
- skconfig_path = os.path.join(str(Path(use_correction).parent.absolute()), "config_nnsktb.json")
- else:
- skconfig_path = None
-
- jdata = j_loader(INPUT)
- jdata = normalize_test(jdata)
-
- if all((jdata["init_model"]["path"], run_opt["init_model"])):
- raise RuntimeError(
- "init-model in config and command line is in conflict, turn off one of then to avoid this error !"
- )
-
- if jdata["init_model"]["path"] is not None:
- assert mode == "from_scratch"
- log.info(msg="Init model is read from config rile.")
- run_opt["init_model"] = jdata["init_model"]
- mode = "init_model"
- if isinstance(run_opt["init_model"]["path"], str):
- dptbconfig_path = os.path.join(str(Path(run_opt["init_model"]["path"]).parent.absolute()), "config_dptb.json")
- else: # list
- raise RuntimeError(
- "loading lists of checkpoints is only supported in init_nnsk!"
- )
- elif run_opt["init_model"] is not None:
- assert mode == "init_model"
- path = run_opt["init_model"]
- run_opt["init_model"] = jdata["init_model"]
- run_opt["init_model"]["path"] = path
-
- if mode == "init_model":
- if isinstance(run_opt["init_model"]["path"], list):
- if len(run_opt["init_model"]["path"])==0:
- log.error(msg="Error, no checkpoint supplied!")
- raise RuntimeError
- elif len(run_opt["init_model"]["path"])>1:
- log.error(msg="Error! list mode init_model in config only support single file in DPTB!")
- raise RuntimeError
-
- if mode == "init_model":
- if isinstance(run_opt["init_model"]["path"], list):
- if len(run_opt["init_model"]["path"]) == 1:
- run_opt["init_model"]["path"] = run_opt["init_model"]["path"][0]
-
# setup output path
if output:
Path(output).parent.mkdir(exist_ok=True, parents=True)
@@ -146,42 +53,38 @@ def _test(
"results_path": str(Path(results_path).absolute()),
"log_path": str(Path(log_path).absolute())
})
- run_opt.update({"mode": mode})
- if test_sk:
- run_opt.update({
- "skconfig_path": skconfig_path,
- })
- else:
- if use_correction:
- run_opt.update({
- "skconfig_path": skconfig_path
- })
- run_opt.update({
- "dptbconfig_path": dptbconfig_path
- })
-
- set_log_handles(log_level, Path(log_path) if log_path else None)
+
+ jdata = j_loader(INPUT)
+ jdata = normalize_test(jdata)
+ # setup seed
+ setup_seed(seed=jdata["common_options"]["seed"])
- # setup_seed(seed=jdata["train_options"]["seed"])
+ f = torch.load(init_model)
+ # update basis
+ basis = f["config"]["common_options"]["basis"]
+ for asym, orb in jdata["common_options"]["basis"].items():
+ assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
+ assert orb == basis[asym], f"Orbital {orb} of Atom {asym} not consistent with the model's basis"
+ jdata["common_options"]["basis"] = basis # use the old basis, because it will be used to build the orbital mapper for dataset
- # with open(os.path.join(output, "test_config.json"), "w") as fp:
- # json.dump(jdata, fp, indent=4)
-
+ set_log_handles(log_level, Path(log_path) if log_path else None)
- str_dtype = jdata["common_options"]["dtype"]
- jdata["common_options"]["dtype"] = dtype_dict[jdata["common_options"]["dtype"]]
+ f = torch.load(run_opt["init_model"])
+ jdata["model_options"] = f["config"]["model_options"]
+ del f
+ test_datasets = build_dataset(set_options=jdata["data_options"]["test"], common_options=jdata["common_options"])
+ model = build_model(run_options=run_opt, model_options=jdata["model_options"], common_options=jdata["common_options"])
+ model.eval()
+ tester = Tester(
+ test_options=jdata["test_options"],
+ common_options=jdata["common_options"],
+ model = model,
+ test_datasets=test_datasets,
+ )
- if test_sk:
- tester = NNSKTester(run_opt, jdata)
- tester.register_plugin(InitSKModel())
- else:
- tester = DPTBTester(run_opt, jdata)
- tester.register_plugin(InitDPTBModel())
-
# register the plugin in tester, to tract training info
- tester.register_plugin(InitTestData())
tester.register_plugin(TestLossMonitor())
tester.register_plugin(Logger(["test_loss"],
interval=[(1, 'iteration'), (1, 'epoch')]))
@@ -194,20 +97,12 @@ def _test(
if output:
# output training configurations:
with open(os.path.join(output, "test_config.json"), "w") as fp:
- jdata["common_options"]["dtype"] = str_dtype
json.dump(jdata, fp, indent=4)
- #tester.register_plugin(Saver(
- #interval=[(jdata["train_options"].get("save_freq"), 'epoch'), (1, 'iteration')] if jdata["train_options"].get(
- # "save_freq") else None))
- # interval=[(jdata["train_options"].get("save_freq"), 'iteration'), (1, 'epoch')] if jdata["train_options"].get(
- # "save_freq") else None))
- # add a plugin to save the training parameters of the model, with model_output as given path
-
start_time = time.time()
- tester.run(epochs=1)
+ tester.run()
end_time = time.time()
log.info("finished testing")
- log.info(f"wall time: {(end_time - start_time):.3f} s")
+ log.info(f"wall time: {(end_time - start_time):.3f} s")
\ No newline at end of file
diff --git a/dptb/entrypoints/train.py b/dptb/entrypoints/train.py
index b1be08a3..d58b6710 100644
--- a/dptb/entrypoints/train.py
+++ b/dptb/entrypoints/train.py
@@ -1,14 +1,12 @@
-from dptb.nnops.train_dptb import DPTBTrainer
-from dptb.nnops.train_nnsk import NNSKTrainer
+from dptb.nnops.trainer import Trainer
+from dptb.nn.build import build_model
+from dptb.data.build import build_dataset
from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer
-from dptb.plugins.init_nnsk import InitSKModel
-from dptb.plugins.init_dptb import InitDPTBModel
-from dptb.plugins.init_data import InitData
from dptb.plugins.train_logger import Logger
from dptb.utils.argcheck import normalize
from dptb.plugins.plugins import Saver
from typing import Dict, List, Optional, Any
-from dptb.utils.tools import j_loader, setup_seed
+from dptb.utils.tools import j_loader, setup_seed, j_must_have
from dptb.utils.constants import dtype_dict
from dptb.utils.loggers import set_log_handles
import heapq
@@ -30,30 +28,24 @@ def train(
INPUT: str,
init_model: Optional[str],
restart: Optional[str],
- freeze:bool,
train_soc:bool,
output: str,
log_level: int,
log_path: Optional[str],
- train_sk: bool,
- use_correction: Optional[str],
**kwargs
):
run_opt = {
"init_model": init_model,
"restart": restart,
- "freeze": freeze,
"train_soc": train_soc,
"log_path": log_path,
- "log_level": log_level,
- "train_sk": train_sk,
- "use_correction": use_correction
+ "log_level": log_level
}
+ assert train_soc is False, "train_soc is not supported yet"
+
'''
-1- set up input and output directories
- noticed that, the checkpoint of sktb and dptb should be in different directory, and in train_dptb,
- there should be a workflow to load correction model from nnsktb.
-2- parse configuration file and start training
output directories has following structure:
@@ -65,132 +57,18 @@ def train(
...
- log/
- log.log
- - config_nnsktb.json
- - config_dptb.json
+ - config.json
'''
# init all paths
# if init_model, restart or init_frez, findout the input configure file
-
- if all((use_correction, train_sk)):
- raise RuntimeError(
- "--use-correction and --train_sk should not be set at the same time"
- )
# setup INPUT path
- if train_sk:
- if init_model:
- skconfig_path = os.path.join(str(Path(init_model).parent.absolute()), "config_nnsktb.json")
- mode = "init_model"
- elif restart:
- skconfig_path = os.path.join(str(Path(restart).parent.absolute()), "config_nnsktb.json")
- mode = "restart"
- elif INPUT is not None:
- log.info(msg="Haven't assign a initializing mode, training from scratch as default.")
- mode = "from_scratch"
- skconfig_path = INPUT
- else:
- log.error("ValueError: Missing Input configuration file path.")
- raise ValueError
-
- # switch the init model mode from command line to config file
- jdata = j_loader(INPUT)
- jdata = normalize(jdata)
-
- # check if init_model in commandline and input json are in conflict.
-
- if all((jdata["init_model"]["path"], run_opt["init_model"])) or \
- all((jdata["init_model"]["path"], run_opt["restart"])):
- raise RuntimeError(
- "init-model in config and command line is in conflict, turn off one of then to avoid this error !"
- )
-
- if jdata["init_model"]["path"] is not None:
- assert mode == "from_scratch"
- run_opt["init_model"] = jdata["init_model"]
- mode = "init_model"
- if isinstance(run_opt["init_model"]["path"], str):
- skconfig_path = os.path.join(str(Path(run_opt["init_model"]["path"]).parent.absolute()), "config_nnsktb.json")
- else: # list
- skconfig_path = [os.path.join(str(Path(path).parent.absolute()), "config_nnsktb.json") for path in run_opt["init_model"]["path"]]
- elif run_opt["init_model"] is not None:
- # format run_opt's init model to the format of jdata
- assert mode == "init_model"
- path = run_opt["init_model"]
- run_opt["init_model"] = jdata["init_model"]
- run_opt["init_model"]["path"] = path
-
- # handling exceptions when init_model path in config file is [] and [single file]
- if mode == "init_model":
- if isinstance(run_opt["init_model"]["path"], list):
- if len(run_opt["init_model"]["path"])==0:
- raise RuntimeError("Error! list mode init_model in config file cannot be empty!")
-
- else:
- if init_model:
- dptbconfig_path = os.path.join(str(Path(init_model).parent.absolute()), "config_dptbtb.json")
- mode = "init_model"
- elif restart:
- dptbconfig_path = os.path.join(str(Path(restart).parent.absolute()), "config_dptbtb.json")
- mode = "restart"
- elif INPUT is not None:
- log.info(msg="Haven't assign a initializing mode, training from scratch as default.")
- dptbconfig_path = INPUT
- mode = "from_scratch"
- else:
- log.error("ValueError: Missing Input configuration file path.")
- raise ValueError
-
- if use_correction:
- skconfig_path = os.path.join(str(Path(use_correction).parent.absolute()), "config_nnsktb.json")
- # skcheckpoint_path = str(Path(str(input(f"Enter skcheckpoint_path (default ./checkpoint/best_nnsk.pth): \n"))).absolute())
- else:
- skconfig_path = None
-
- # parse INPUT file
- jdata = j_loader(INPUT)
- jdata = normalize(jdata)
-
- if all((jdata["init_model"]["path"], run_opt["init_model"])) or \
- all((jdata["init_model"]["path"], run_opt["restart"])):
- raise RuntimeError(
- "init-model in config and command line is in conflict, turn off one of then to avoid this error !"
- )
-
- if jdata["init_model"]["path"] is not None:
- assert mode == "from_scratch"
- log.info(msg="Init model is read from config rile.")
- run_opt["init_model"] = jdata["init_model"]
- mode = "init_model"
- if isinstance(run_opt["init_model"]["path"], str):
- dptbconfig_path = os.path.join(str(Path(run_opt["init_model"]["path"]).parent.absolute()), "config_dptb.json")
- else: # list
- raise RuntimeError(
- "loading lists of checkpoints is only supported in init_nnsk!"
- )
- elif run_opt["init_model"] is not None:
- assert mode == "init_model"
- path = run_opt["init_model"]
- run_opt["init_model"] = jdata["init_model"]
- run_opt["init_model"]["path"] = path
-
- if mode == "init_model":
- if isinstance(run_opt["init_model"]["path"], list):
- if len(run_opt["init_model"]["path"])==0:
- log.error(msg="Error, no checkpoint supplied!")
- raise RuntimeError
- elif len(run_opt["init_model"]["path"])>1:
- log.error(msg="Error! list mode init_model in config only support single file in DPTB!")
- raise RuntimeError
if all((run_opt["init_model"], restart)):
raise RuntimeError(
"--init-model and --restart should not be set at the same time"
)
- if mode == "init_model":
- if isinstance(run_opt["init_model"]["path"], list):
- if len(run_opt["init_model"]["path"]) == 1:
- run_opt["init_model"]["path"] = run_opt["init_model"]["path"][0]
# setup output path
if output:
Path(output).parent.mkdir(exist_ok=True, parents=True)
@@ -209,71 +87,137 @@ def train(
"log_path": str(Path(log_path).absolute())
})
- run_opt.update({"mode": mode})
- if train_sk:
- run_opt.update({
- "skconfig_path": skconfig_path,
- })
- else:
- if use_correction:
- run_opt.update({
- "skconfig_path": skconfig_path
- })
- run_opt.update({
- "dptbconfig_path": dptbconfig_path
- })
-
set_log_handles(log_level, Path(log_path) if log_path else None)
# parse the config. Since if use init, config file may not equals to current
+ jdata = j_loader(INPUT)
+ jdata = normalize(jdata)
+ # update basis if init_model or restart
+ # update jdata
+ # this is not necessary, because if we init model from checkpoint, the build_model will load the model_options from checkpoints if not provided
+ # since here we want to output jdata as a config file to inform the user what model options are used, we need to update the jdata
+ torch.set_default_dtype(getattr(torch, jdata["common_options"]["dtype"]))
+
+ if restart or init_model:
+ f = restart if restart else init_model
+ f = torch.load(f)
+
+ if jdata.get("model_options", None) is None:
+ jdata["model_options"] = f["config"]["model_options"]
+
+ # update basis
+ basis = f["config"]["common_options"]["basis"]
+ # nnsk
+ if len(f["config"]["model_options"])==1 and f["config"]["model_options"].get("nnsk") != None:
+ for asym, orb in jdata["common_options"]["basis"].items():
+ assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
+ if orb != basis[asym]:
+ log.info(f"Initializing Orbital {orb} of Atom {asym} from {basis[asym]}")
+ # we have the orbitals in jdata basis correct, now we need to make sure all atom in basis are also contained in jdata basis
+ for asym, orb in basis.items():
+ if asym not in jdata["common_options"]["basis"].keys():
+ jdata["common_options"]["basis"][asym] = orb # add the atomtype in the checkpoint but not in the jdata basis, because it will be used to build the orbital mapper for dataset
+ else: # not nnsk
+ for asym, orb in jdata["common_options"]["basis"].items():
+ assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
+ assert orb == basis[asym], f"Orbital {orb} of Atom {asym} not consistent with the model's basis, which is only allowed in nnsk training"
+
+ jdata["common_options"]["basis"] = basis
+
+ # update model options and train_options
+ if restart:
+ #
+ if jdata.get("train_options", None) is not None:
+ for obj in Trainer.object_keys:
+ if jdata["train_options"].get(obj) != f["config"]["train_options"].get(obj):
+ log.warning(f"{obj} in config file is not consistent with the checkpoint, using the one in checkpoint")
+ jdata["train_options"][obj] = f["config"]["train_options"][obj]
+ else:
+ jdata["train_options"] = f["config"]["train_options"]
+
+ if jdata.get("model_options", None) is None or jdata["model_options"] != f["config"]["model_options"]:
+ log.warning("model_options in config file is not consistent with the checkpoint, using the one in checkpoint")
+ jdata["model_options"] = f["config"]["model_options"] # restart does not allow to change model options
+ else:
+ # init model mode, allow model_options change
+ if jdata.get("train_options", None) is None:
+ jdata["train_options"] = f["config"]["train_options"]
+ if jdata.get("model_options") is None:
+ jdata["model_options"] = f["config"]["model_options"]
+ del f
+ else:
+ j_must_have(jdata, "model_options")
+ j_must_have(jdata, "train_options")
+
+
# setup seed
- setup_seed(seed=jdata["train_options"]["seed"])
+ setup_seed(seed=jdata["common_options"]["seed"])
# with open(os.path.join(output, "train_config.json"), "w") as fp:
# json.dump(jdata, fp, indent=4)
-
- str_dtype = jdata["common_options"]["dtype"]
- jdata["common_options"]["dtype"] = dtype_dict[jdata["common_options"]["dtype"]]
- if train_sk:
- trainer = NNSKTrainer(run_opt, jdata)
- trainer.register_plugin(InitSKModel())
+
+ # build dataset
+ train_datasets = build_dataset(set_options=jdata["data_options"]["train"], common_options=jdata["common_options"])
+ if jdata["data_options"].get("validation"):
+ validation_datasets = build_dataset(set_options=jdata["data_options"]["validation"], common_options=jdata["common_options"])
else:
- trainer = DPTBTrainer(run_opt, jdata)
- trainer.register_plugin(InitDPTBModel())
-
-
+ validation_datasets = None
+ if jdata["data_options"].get("reference"):
+ reference_datasets = build_dataset(set_options=jdata["data_options"]["reference"], common_options=jdata["common_options"])
+ else:
+ reference_datasets = None
+
+ if restart:
+ trainer = Trainer.restart(
+ train_options=jdata["train_options"],
+ common_options=jdata["common_options"],
+ checkpoint=restart,
+ train_datasets=train_datasets,
+ reference_datasets=reference_datasets,
+ validation_datasets=validation_datasets,
+ )
+ else:
+ # include the init model and from scratch
+ # build model will handle the init model cases where the model options provided is not equals to the ones in checkpoint.
+ model = build_model(run_options=run_opt, model_options=jdata["model_options"], common_options=jdata["common_options"], statistics=train_datasets.E3statistics())
+ trainer = Trainer(
+ train_options=jdata["train_options"],
+ common_options=jdata["common_options"],
+ model = model,
+ train_datasets=train_datasets,
+ validation_datasets=validation_datasets,
+ reference_datasets=reference_datasets,
+ )
# register the plugin in trainer, to tract training info
- trainer.register_plugin(InitData())
- trainer.register_plugin(Validationer())
+ log_field = ["train_loss", "lr"]
+ if validation_datasets:
+ trainer.register_plugin(Validationer())
+ log_field.append("validation_loss")
trainer.register_plugin(TrainLossMonitor())
trainer.register_plugin(LearningRateMonitor())
- trainer.register_plugin(Logger(["train_loss", "validation_loss", "lr"],
+ trainer.register_plugin(Logger(log_field,
interval=[(jdata["train_options"]["display_freq"], 'iteration'), (1, 'epoch')]))
for q in trainer.plugin_queues.values():
heapq.heapify(q)
-
- trainer.build()
-
if output:
# output training configurations:
with open(os.path.join(output, "train_config.json"), "w") as fp:
- jdata["common_options"]["dtype"] = str_dtype
json.dump(jdata, fp, indent=4)
trainer.register_plugin(Saver(
#interval=[(jdata["train_options"].get("save_freq"), 'epoch'), (1, 'iteration')] if jdata["train_options"].get(
# "save_freq") else None))
interval=[(jdata["train_options"].get("save_freq"), 'iteration'), (1, 'epoch')] if jdata["train_options"].get(
- "save_freq") else None))
+ "save_freq") else None), checkpoint_path=checkpoint_path)
# add a plugin to save the training parameters of the model, with model_output as given path
start_time = time.time()
- trainer.run(trainer.num_epoch)
+ trainer.run(trainer.train_options["num_epoch"])
end_time = time.time()
log.info("finished training")
diff --git a/dptb/hamiltonian/hamil_eig_sk.py b/dptb/hamiltonian/hamil_eig_sk.py
deleted file mode 100644
index d73e7435..00000000
--- a/dptb/hamiltonian/hamil_eig_sk.py
+++ /dev/null
@@ -1,278 +0,0 @@
-import torch
-import torch as th
-import numpy as np
-import logging
-import re
-from dptb.hamiltonian.transform_sk import RotationSK
-from dptb.utils.constants import anglrMId
-
-''' Over use of different index system cause the symbols and type and index kind of object need to be recalculated in different
-Class, this makes entanglement of classes difficult. Need to design an consistent index system to resolve.'''
-
-log = logging.getLogger(__name__)
-
-class HamilEig(RotationSK):
- """ This module is to build the Hamiltonian from the SK-type bond integral.
- """
- def __init__(self, dtype='tensor') -> None:
- super().__init__(rot_type=dtype)
- self.dtype = dtype
- self.use_orthogonal_basis = False
- self.hamil_blocks = None
- self.overlap_blocks = None
-
- def update_hs_list(self, struct, hoppings, onsiteEs, overlaps=None, onsiteSs=None, **options):
- '''It updates the bond structure, bond type, bond type id, bond hopping, bond onsite, hopping, onsite
- energy, overlap, and onsite spin
-
- Parameters
- ----------
- hoppings
- a list bond integral for hoppings.
- onsiteEs
- a list of onsite energy for each atom and each orbital.
- overlaps
- a list bond integral for overlaps.
- onsiteSs
- a list of onsite overlaps for each atom and each orbital.
- '''
- self.__struct__ = struct
- self.hoppings = hoppings
- self.onsiteEs = onsiteEs
- self.use_orthogonal_basis = False
- if overlaps is None:
- self.use_orthogonal_basis = True
- else:
- self.overlaps = overlaps
- self.onsiteSs = onsiteSs
- self.use_orthogonal_basis = False
-
- self.num_orbs_per_atom = []
- for itype in self.__struct__.proj_atom_symbols:
- norbs = self.__struct__.proj_atomtype_norbs[itype]
- self.num_orbs_per_atom.append(norbs)
-
- def get_hs_blocks(self, bonds_onsite = None, bonds_hoppings=None):
- """using the SK type bond integral to build the hamiltonian matrix and overlap matrix in the real space.
-
- The hamiltonian and overlap matrix block are stored in the order of bond list. for ecah bond ij, with lattice
- vecto R, the matrix stored in [norbsi, norbsj]. norsbi and norbsj are the total number of orbtals on i and j sites.
- e.g. for C-atom with both s and p orbital on each site. norbi is 4.
- """
- if bonds_onsite is None:
- bonds_onsite = self.__struct__.__bonds_onsite__
- assert len(bonds_onsite) == len(self.__struct__.__bonds_onsite__)
- if bonds_hoppings is None:
- bonds_hoppings = self.__struct__.__bonds__
- assert len(bonds_hoppings) == len(self.__struct__.__bonds__)
-
- hamil_blocks = []
- if not self.use_orthogonal_basis:
- overlap_blocks = []
- for ib in range(len(bonds_onsite)):
- ibond = bonds_onsite[ib].astype(int)
- iatype = self.__struct__.proj_atom_symbols[ibond[1]]
- jatype = self.__struct__.proj_atom_symbols[ibond[3]]
- assert iatype == jatype, "i type should equal j type."
-
- if self.dtype == 'tensor':
- sub_hamil_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]])
- if not self.use_orthogonal_basis:
- sub_over_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]])
- else:
- sub_hamil_block = np.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]])
- if not self.use_orthogonal_basis:
- sub_over_block = np.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]])
-
- # ToDo: adding onsite correction
- ist = 0
- # replace sub_hamil_block from now the block diagonal formula to corrected ones.
- for ish in self.__struct__.proj_atom_anglr_m[iatype]: # ['s','p',..]
- ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish))
- shidi = anglrMId[ishsymbol] # 0,1,2,...
- norbi = 2*shidi + 1
-
- indx = self.__struct__.onsite_index_map[iatype][ish] # change onsite index map from {N:{s:}} to {N:{ss:, sp:}}
- # this already satisfy for uniform onsite or splited onsite energy.
- # e.g. for p orbital, index may be 1, or 1,2,3, stands for uniform or splited energy for px py pz.
- # and then self.onsiteEs[ib][indx] can be scalar or torch.Size([1]) or torch.Size([3]).
- # both of them can be transfer into a [3x3] diagonal matrix in this code.
- if self.dtype == 'tensor':
- sub_hamil_block[ist:ist+norbi, ist:ist+norbi] = th.eye(norbi) * self.onsiteEs[ib][indx]
- if not self.use_orthogonal_basis:
- sub_over_block[ist:ist+norbi, ist:ist+norbi] = th.eye(norbi) * self.onsiteSs[ib][indx]
- else:
- sub_hamil_block[ist:ist+norbi, ist:ist+norbi] = np.eye(norbi) * self.onsiteEs[ib][indx]
- if not self.use_orthogonal_basis:
- sub_over_block[ist:ist+norbi, ist:ist+norbi] = np.eye(norbi) * self.onsiteSs[ib][indx]
- ist = ist +norbi
-
- hamil_blocks.append(sub_hamil_block)
- if not self.use_orthogonal_basis:
- overlap_blocks.append(sub_over_block)
-
- for ib in range(len(bonds_hoppings)):
-
- ibond = bonds_hoppings[ib,0:7].astype(int)
- #direction_vec = (self.__struct__.projected_struct.positions[ibond[3]]
- # - self.__struct__.projected_struct.positions[ibond[1]]
- # + np.dot(ibond[4:], self.__struct__.projected_struct.cell))
- #dist = np.linalg.norm(direction_vec)
- #direction_vec = direction_vec/dist
- direction_vec = bonds_hoppings[ib,8:11].astype(np.float32)
- iatype = self.__struct__.proj_atom_symbols[ibond[1]]
- jatype = self.__struct__.proj_atom_symbols[ibond[3]]
-
- if self.dtype == 'tensor':
- sub_hamil_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]])
- if not self.use_orthogonal_basis:
- sub_over_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]])
- else:
- sub_hamil_block = np.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]])
- if not self.use_orthogonal_basis:
- sub_over_block = np.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]])
-
- bondatomtype = iatype + '-' + jatype
-
- ist = 0
- for ish in self.__struct__.proj_atom_anglr_m[iatype]:
- ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish))
- shidi = anglrMId[ishsymbol]
- norbi = 2*shidi+1
-
- jst = 0
- for jsh in self.__struct__.proj_atom_anglr_m[jatype]:
- jshsymbol = ''.join(re.findall(r'[A-Za-z]',jsh))
- shidj = anglrMId[jshsymbol]
- norbj = 2 * shidj + 1
-
- idx = self.__struct__.bond_index_map[bondatomtype][ish+'-'+jsh]
- if shidi < shidj:
- tmpH = self.rot_HS(Htype=ishsymbol+jshsymbol, Hvalue=self.hoppings[ib][idx], Angvec=direction_vec)
- # Hamilblock[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpH,dim0=0,dim1=1)
- if self.dtype == 'tensor':
- sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * th.transpose(tmpH,dim0=0,dim1=1)
- else:
- sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * np.transpose(tmpH,(1,0))
- if not self.use_orthogonal_basis:
- tmpS = self.rot_HS(Htype=ishsymbol+jshsymbol, Hvalue=self.overlaps[ib][idx], Angvec=direction_vec)
- # Soverblock[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpS,dim0=0,dim1=1)
- if self.dtype == 'tensor':
- sub_over_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * th.transpose(tmpS,dim0=0,dim1=1)
- else:
- sub_over_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * np.transpose(tmpS,(1,0))
- else:
- tmpH = self.rot_HS(Htype=jshsymbol+ishsymbol, Hvalue=self.hoppings[ib][idx], Angvec=direction_vec)
- sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = tmpH
- if not self.use_orthogonal_basis:
- tmpS = self.rot_HS(Htype=jshsymbol+ishsymbol, Hvalue = self.overlaps[ib][idx], Angvec = direction_vec)
- sub_over_block[ist:ist+norbi, jst:jst+norbj] = tmpS
-
- jst = jst + norbj
- ist = ist + norbi
- hamil_blocks.append(sub_hamil_block)
- if not self.use_orthogonal_basis:
- overlap_blocks.append(sub_over_block)
- self.all_bonds = np.concatenate([bonds_onsite[:,0:7],bonds_hoppings[:,0:7]],axis=0)
- self.all_bonds = self.all_bonds.astype(int)
- self.hamil_blocks = hamil_blocks
- if not self.use_orthogonal_basis:
- self.overlap_blocks = overlap_blocks
-
-
- def hs_block_R2k(self, kpoints, HorS='H', time_symm=True, dtype='tensor'):
- '''The function takes in a list of Hamiltonian matrices for each bond, and a list of k-points, and
- returns a list of Hamiltonian matrices for each k-point
-
- Parameters
- ----------
- HorS
- string, 'H' or 'S' to indicate for Hk or Sk calculation.
- kpoints
- the k-points in the path.
- time_symm, optional
- if True, the Hamiltonian is time-reversal symmetric, defaults to True (optional)
- dtype, optional
- 'tensor' or 'numpy', defaults to tensor (optional)
-
- Returns
- -------
- A list of Hamiltonian or Overlap matrices for each k-point.
- '''
- numOrbs = np.array(self.num_orbs_per_atom)
- totalOrbs = np.sum(numOrbs)
- if HorS == 'H':
- hijAll = self.hamil_blocks
- elif HorS == 'S':
- hijAll = self.overlap_blocks
- else:
- print("HorS should be 'H' or 'S' !")
-
- if dtype == 'tensor':
- Hk = th.zeros([len(kpoints), totalOrbs, totalOrbs], dtype = th.complex64)
- else:
- Hk = np.zeros([len(kpoints), totalOrbs, totalOrbs], dtype = np.complex64)
-
- for ik in range(len(kpoints)):
- k = kpoints[ik]
- if dtype == 'tensor':
- hk = th.zeros([totalOrbs,totalOrbs],dtype = th.complex64)
- else:
- hk = np.zeros([totalOrbs,totalOrbs],dtype = np.complex64)
- for ib in range(len(self.all_bonds)):
- Rlatt = self.all_bonds[ib,4:7].astype(int)
- i = self.all_bonds[ib,1].astype(int)
- j = self.all_bonds[ib,3].astype(int)
- ist = int(np.sum(numOrbs[0:i]))
- ied = int(np.sum(numOrbs[0:i+1]))
- jst = int(np.sum(numOrbs[0:j]))
- jed = int(np.sum(numOrbs[0:j+1]))
- if ib < len(numOrbs):
- """
- len(numOrbs)= numatoms. the first numatoms are onsite energies.
- if turn on timeSymm when generating the bond list . only i>= or <= j are included.
- if turn off timeSymm when generating the bond list . all the i j are included.
- for case 1, H = H+H^\dagger to get the full matrix, the the onsite one is doubled.
- for case 2. no need to do H = H+H^dagger. since the matrix is already full.
- """
- if time_symm:
- hk[ist:ied,jst:jed] += 0.5 * hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt))
- else:
- hk[ist:ied,jst:jed] += hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt))
- else:
- hk[ist:ied,jst:jed] += hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt))
- if time_symm:
- hk = hk + hk.T.conj()
- Hk[ik] = hk
- return Hk
-
- def Eigenvalues(self, kpoints, time_symm=True,dtype='tensor'):
- """ using the tight-binding H and S matrix calculate eigenvalues at kpoints.
-
- Args:
- kpoints: the k-kpoints used to calculate the eigenvalues.
- Note: must have the BondHBlock and BondSBlock
- """
- hkmat = self.hs_block_R2k(kpoints=kpoints, HorS='H', time_symm=time_symm, dtype=dtype)
- if not self.use_orthogonal_basis:
- skmat = self.hs_block_R2k(kpoints=kpoints, HorS='S', time_symm=time_symm, dtype=dtype)
- else:
- skmat = torch.eye(hkmat.shape[1], dtype=torch.complex64).unsqueeze(0).repeat(hkmat.shape[0], 1, 1)
-
- if self.dtype == 'tensor':
- chklowt = th.linalg.cholesky(skmat)
- chklowtinv = th.linalg.inv(chklowt)
- Heff = (chklowtinv @ hkmat @ th.transpose(chklowtinv,dim0=1,dim1=2).conj())
- # the factor 13.605662285137 * 2 from Hartree to eV.
- # eigks = th.linalg.eigvalsh(Heff) * 13.605662285137 * 2
- eigks, Q = th.linalg.eigh(Heff)
- eigks = eigks * 13.605662285137 * 2
- Qres = Q.detach()
- else:
- chklowt = np.linalg.cholesky(skmat)
- chklowtinv = np.linalg.inv(chklowt)
- Heff = (chklowtinv @ hkmat @ np.transpose(chklowtinv,(0,2,1)).conj())
- eigks = np.linalg.eigvalsh(Heff) * 13.605662285137 * 2
- Qres = 0
-
- return eigks, Qres
\ No newline at end of file
diff --git a/dptb/hamiltonian/hamil_eig_sk_crt_soc.py b/dptb/hamiltonian/hamil_eig_sk_crt_soc.py
deleted file mode 100644
index 19493c18..00000000
--- a/dptb/hamiltonian/hamil_eig_sk_crt_soc.py
+++ /dev/null
@@ -1,395 +0,0 @@
-import torch
-import torch as th
-import numpy as np
-import logging
-import re
-from dptb.hamiltonian.transform_sk import RotationSK
-from dptb.nnsktb.formula import SKFormula
-from dptb.utils.constants import anglrMId
-from dptb.hamiltonian.soc import creat_basis_lm, get_soc_matrix_cubic_basis
-
-''' Over use of different index system cause the symbols and type and index kind of object need to be recalculated in different
-Class, this makes entanglement of classes difficult. Need to design an consistent index system to resolve.'''
-
-log = logging.getLogger(__name__)
-
-class HamilEig(RotationSK):
- """ This module is to build the Hamiltonian from the SK-type bond integral.
- """
- def __init__(self, dtype=torch.float32, device='cpu') -> None:
- super().__init__(rot_type=dtype, device=device)
- self.dtype = dtype
- if self.dtype is th.float32:
- self.cdtype = th.complex64
- elif self.dtype is th.float64:
- self.cdtype = th.complex128
- self.use_orthogonal_basis = False
- self.hamil_blocks = None
- self.overlap_blocks = None
- self.device = device
-
- def update_hs_list(self, struct, hoppings, onsiteEs, onsiteVs=None, overlaps=None, onsiteSs=None, soc_lambdas=None, **options):
- '''It updates the bond structure, bond type, bond type id, bond hopping, bond onsite, hopping, onsite
- energy, overlap, and onsite spin
-
- Parameters
- ----------
- hoppings
- a list bond integral for hoppings.
- onsiteEs
- a list of onsite energy for each atom and each orbital.
- overlaps
- a list bond integral for overlaps.
- onsiteSs
- a list of onsite overlaps for each atom and each orbital.
- '''
- self.__struct__ = struct
- self.hoppings = hoppings
- self.onsiteEs = onsiteEs
- self.onsiteVs = onsiteVs
- self.soc_lambdas = soc_lambdas
- self.use_orthogonal_basis = False
- if overlaps is None:
- self.use_orthogonal_basis = True
- else:
- self.overlaps = overlaps
- self.onsiteSs = onsiteSs
- self.use_orthogonal_basis = False
-
- if soc_lambdas is None:
- self.soc = False
- else:
- self.soc = True
-
- self.num_orbs_per_atom = []
- for itype in self.__struct__.proj_atom_symbols:
- norbs = self.__struct__.proj_atomtype_norbs[itype]
- self.num_orbs_per_atom.append(norbs)
-
- def get_soc_block(self, bonds_onsite = None):
- numOrbs = np.array(self.num_orbs_per_atom)
- totalOrbs = np.sum(numOrbs)
- if bonds_onsite is None:
- _, bonds_onsite = self.__struct__.get_bond()
-
- soc_upup = torch.zeros_like((totalOrbs, totalOrbs), device=self.device, dtype=self.cdtype)
- soc_updown = torch.zeros_like((totalOrbs, totalOrbs), device=self.device, dtype=self.cdtype)
-
- # compute soc mat for each atom:
- soc_atom_upup = self.__struct__.get("soc_atom_diag", {})
- soc_atom_updown = self.__struct__.get("soc_atom_up", {})
- if not soc_atom_upup or not soc_atom_updown:
- for iatype in self.__struct__.proj_atomtype:
- total_num_orbs_iatom= self.__struct__.proj_atomtype_norbs[iatype]
- tmp_upup = torch.zeros([total_num_orbs_iatom, total_num_orbs_iatom], dtype=self.cdtype, device=self.device)
- tmp_updown = torch.zeros([total_num_orbs_iatom, total_num_orbs_iatom], dtype=self.cdtype, device=self.device)
-
- ist = 0
- for ish in self.__struct__.proj_atom_anglr_m[iatype]:
- ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish))
- shidi = anglrMId[ishsymbol] # 0,1,2,...
- norbi = 2*shidi + 1
-
- soc_orb = get_soc_matrix_cubic_basis(orbital=ishsymbol, device=self.device, dtype=self.dtype)
- if len(soc_orb) != 2*norbi:
- log.error(msg='The dimension of the soc_orb is not correct!')
- tmp_upup[ist:ist+norbi, ist:ist+norbi] = soc_orb[:norbi,:norbi]
- tmp_updown[ist:ist+norbi, ist:ist+norbi] = soc_orb[:norbi, norbi:]
- ist = ist + norbi
-
- soc_atom_upup.update({iatype:tmp_upup})
- soc_atom_updown.update({iatype:tmp_updown})
- self.__struct__.soc_atom_upup = soc_atom_upup
- self.__struct__.soc_atom_updown = soc_atom_updown
-
- for ib in range(len(bonds_onsite)):
- ibond = bonds_onsite[ib].astype(int)
- iatom = ibond[1]
- ist = int(np.sum(numOrbs[0:iatom]))
- ied = int(np.sum(numOrbs[0:iatom+1]))
- iatype = self.__struct__.proj_atom_symbols[iatom]
-
- # get lambdas
- ist = 0
- lambdas = torch.zeros((ied-ist,), device=self.device, dtype=self.dtype)
- for ish in self.__struct__.proj_atom_anglr_m[iatype]:
- indx = self.__struct__.onsite_index_map[iatype][ish]
- ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish))
- shidi = anglrMId[ishsymbol] # 0,1,2,...
- norbi = 2*shidi + 1
- lambdas[ist:ist+norbi] = self.soc_lambdas[ib][indx]
- ist = ist + norbi
-
- soc_upup[ist:ied,ist:ied] = soc_atom_upup[iatype] * torch.diag(lambdas)
- soc_updown[ist:ied, ist:ied] = soc_atom_updown[iatype] * torch.diag(lambdas)
-
- soc_upup.contiguous()
- soc_updown.contiguous()
-
- return soc_upup, soc_updown
-
- def get_hs_onsite(self, bonds_onsite = None, onsite_envs=None):
- if bonds_onsite is None:
- _, bonds_onsite = self.__struct__.get_bond()
- onsiteH_blocks = []
- if not self.use_orthogonal_basis:
- onsiteS_blocks = []
- else:
- onsiteS_blocks = None
-
- iatom_to_onsite_index = {}
- for ib in range(len(bonds_onsite)):
- ibond = bonds_onsite[ib].astype(int)
- iatom = ibond[1]
- iatom_to_onsite_index.update({iatom:ib})
- jatom = ibond[3]
- iatype = self.__struct__.proj_atom_symbols[iatom]
- jatype = self.__struct__.proj_atom_symbols[jatom]
- assert iatype == jatype, "i type should equal j type."
-
- sub_hamil_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]], dtype=self.dtype, device=self.device)
- if not self.use_orthogonal_basis:
- sub_over_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]], dtype=self.dtype, device=self.device)
-
- ist = 0
- for ish in self.__struct__.proj_atom_anglr_m[iatype]: # ['s','p',..]
- ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish))
- shidi = anglrMId[ishsymbol] # 0,1,2,...
- norbi = 2*shidi + 1
-
- indx = self.__struct__.onsite_index_map[iatype][ish] # change onsite index map from {N:{s:}} to {N:{ss:, sp:}}
- sub_hamil_block[ist:ist+norbi, ist:ist+norbi] = th.eye(norbi, dtype=self.dtype, device=self.device) * self.onsiteEs[ib][indx]
- if not self.use_orthogonal_basis:
- sub_over_block[ist:ist+norbi, ist:ist+norbi] = th.eye(norbi, dtype=self.dtype, device=self.device) * self.onsiteSs[ib][indx]
- ist = ist + norbi
-
- onsiteH_blocks.append(sub_hamil_block)
- if not self.use_orthogonal_basis:
- onsiteS_blocks.append(sub_over_block)
-
- # onsite strain
- if onsite_envs is not None:
- assert self.onsiteVs is not None
- for ib, env in enumerate(onsite_envs):
-
- iatype, iatom, jatype, jatom = self.__struct__.proj_atom_symbols[int(env[1])], env[1], self.__struct__.atom_symbols[int(env[3])], env[3]
- direction_vec = env[8:11].astype(np.float32)
-
- sub_hamil_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[iatype]], dtype=self.dtype, device=self.device)
-
- envtype = iatype + '-' + jatype
-
- ist = 0
- for ish in self.__struct__.proj_atom_anglr_m[iatype]:
- ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish))
- shidi = anglrMId[ishsymbol]
- norbi = 2*shidi+1
-
- jst = 0
- for jsh in self.__struct__.proj_atom_anglr_m[iatype]:
- jshsymbol = ''.join(re.findall(r'[A-Za-z]',jsh))
- shidj = anglrMId[jshsymbol]
- norbj = 2 * shidj + 1
-
- idx = self.__struct__.onsite_strain_index_map[envtype][ish+'-'+jsh]
-
- if shidi < shidj:
-
- tmpH = self.rot_HS(Htype=ishsymbol+jshsymbol, Hvalue=self.onsiteVs[ib][idx], Angvec=direction_vec)
- # Hamilblock[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpH,dim0=0,dim1=1)
- sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpH,dim0=0,dim1=1)
- else:
- tmpH = self.rot_HS(Htype=jshsymbol+ishsymbol, Hvalue=self.onsiteVs[ib][idx], Angvec=direction_vec)
- sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = tmpH
-
- jst = jst + norbj
- ist = ist + norbi
- onsiteH_blocks[iatom_to_onsite_index[iatom]] += sub_hamil_block
-
- return onsiteH_blocks, onsiteS_blocks, bonds_onsite
-
- def get_hs_hopping(self, bonds_hoppings = None):
- if bonds_hoppings is None:
- bonds_hoppings, _ = self.__struct__.get_bond()
-
- hoppingH_blocks = []
- if not self.use_orthogonal_basis:
- hoppingS_blocks = []
- else:
- hoppingS_blocks = None
-
- for ib in range(len(bonds_hoppings)):
-
- ibond = bonds_hoppings[ib,0:7].astype(int)
- #direction_vec = (self.__struct__.projected_struct.positions[ibond[3]]
- # - self.__struct__.projected_struct.positions[ibond[1]]
- # + np.dot(ibond[4:], self.__struct__.projected_struct.cell))
- #dist = np.linalg.norm(direction_vec)
- #direction_vec = direction_vec/dist
- direction_vec = bonds_hoppings[ib,8:11].astype(np.float32)
- iatype = self.__struct__.proj_atom_symbols[ibond[1]]
- jatype = self.__struct__.proj_atom_symbols[ibond[3]]
-
- sub_hamil_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]], dtype=self.dtype, device=self.device)
- if not self.use_orthogonal_basis:
- sub_over_block = th.zeros([self.__struct__.proj_atomtype_norbs[iatype], self.__struct__.proj_atomtype_norbs[jatype]], dtype=self.dtype, device=self.device)
-
- bondatomtype = iatype + '-' + jatype
-
- ist = 0
- for ish in self.__struct__.proj_atom_anglr_m[iatype]:
- ishsymbol = ''.join(re.findall(r'[A-Za-z]',ish))
- shidi = anglrMId[ishsymbol]
- norbi = 2*shidi+1
-
- jst = 0
- for jsh in self.__struct__.proj_atom_anglr_m[jatype]:
- jshsymbol = ''.join(re.findall(r'[A-Za-z]',jsh))
- shidj = anglrMId[jshsymbol]
- norbj = 2 * shidj + 1
-
- idx = self.__struct__.bond_index_map[bondatomtype][ish+'-'+jsh]
- if shidi < shidj:
- tmpH = self.rot_HS(Htype=ishsymbol+jshsymbol, Hvalue=self.hoppings[ib][idx], Angvec=direction_vec)
- # Hamilblock[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpH,dim0=0,dim1=1)
- sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * th.transpose(tmpH,dim0=0,dim1=1)
- if not self.use_orthogonal_basis:
- tmpS = self.rot_HS(Htype=ishsymbol+jshsymbol, Hvalue=self.overlaps[ib][idx], Angvec=direction_vec)
- # Soverblock[ist:ist+norbi, jst:jst+norbj] = th.transpose(tmpS,dim0=0,dim1=1)
- sub_over_block[ist:ist+norbi, jst:jst+norbj] = (-1.0)**(shidi + shidj) * th.transpose(tmpS,dim0=0,dim1=1)
- else:
- tmpH = self.rot_HS(Htype=jshsymbol+ishsymbol, Hvalue=self.hoppings[ib][idx], Angvec=direction_vec)
- sub_hamil_block[ist:ist+norbi, jst:jst+norbj] = tmpH
- if not self.use_orthogonal_basis:
- tmpS = self.rot_HS(Htype=jshsymbol+ishsymbol, Hvalue = self.overlaps[ib][idx], Angvec = direction_vec)
- sub_over_block[ist:ist+norbi, jst:jst+norbj] = tmpS
-
- jst = jst + norbj
- ist = ist + norbi
-
- hoppingH_blocks.append(sub_hamil_block)
- if not self.use_orthogonal_basis:
- hoppingS_blocks.append(sub_over_block)
-
- return hoppingH_blocks, hoppingS_blocks, bonds_hoppings
-
- def get_hs_blocks(self, bonds_onsite = None, bonds_hoppings=None, onsite_envs=None):
- onsiteH, onsiteS, bonds_onsite = self.get_hs_onsite(bonds_onsite=bonds_onsite, onsite_envs=onsite_envs)
- hoppingH, hoppingS, bonds_hoppings = self.get_hs_hopping(bonds_hoppings=bonds_hoppings)
-
- self.all_bonds = np.concatenate([bonds_onsite[:,0:7],bonds_hoppings[:,0:7]],axis=0)
- self.all_bonds = self.all_bonds.astype(int)
- onsiteH.extend(hoppingH)
- self.hamil_blocks = onsiteH
- if not self.use_orthogonal_basis:
- onsiteS.extend(hoppingS)
- self.overlap_blocks = onsiteS
- if self.soc:
- self.soc_upup, self.soc_updown = self.get_soc_block(bonds_onsite=bonds_onsite)
-
- return True
-
- def hs_block_R2k(self, kpoints, HorS='H', time_symm=True):
- '''The function takes in a list of Hamiltonian matrices for each bond, and a list of k-points, and
- returns a list of Hamiltonian matrices for each k-point
-
- Parameters
- ----------
- HorS
- string, 'H' or 'S' to indicate for Hk or Sk calculation.
- kpoints
- the k-points in the path.
- time_symm, optional
- if True, the Hamiltonian is time-reversal symmetric, defaults to True (optional)
- dtype, optional
- 'tensor' or 'numpy', defaults to tensor (optional)
-
- Returns
- -------
- A list of Hamiltonian or Overlap matrices for each k-point.
- '''
-
- numOrbs = np.array(self.num_orbs_per_atom)
- totalOrbs = np.sum(numOrbs)
- if HorS == 'H':
- hijAll = self.hamil_blocks
- elif HorS == 'S':
- hijAll = self.overlap_blocks
- else:
- print("HorS should be 'H' or 'S' !")
-
- if self.soc:
- Hk = th.zeros([len(kpoints), 2*totalOrbs, 2*totalOrbs], dtype = self.cdtype, device=self.device)
- else:
- Hk = th.zeros([len(kpoints), totalOrbs, totalOrbs], dtype = self.cdtype, device=self.device)
-
- for ik in range(len(kpoints)):
- k = kpoints[ik]
- hk = th.zeros([totalOrbs,totalOrbs],dtype = self.cdtype, device=self.device)
- for ib in range(len(self.all_bonds)):
- Rlatt = self.all_bonds[ib,4:7].astype(int)
- i = self.all_bonds[ib,1].astype(int)
- j = self.all_bonds[ib,3].astype(int)
- ist = int(np.sum(numOrbs[0:i]))
- ied = int(np.sum(numOrbs[0:i+1]))
- jst = int(np.sum(numOrbs[0:j]))
- jed = int(np.sum(numOrbs[0:j+1]))
- if ib < len(numOrbs):
- """
- len(numOrbs)= numatoms. the first numatoms are onsite energies.
- if turn on timeSymm when generating the bond list . only i>= or <= j are included.
- if turn off timeSymm when generating the bond list . all the i j are included.
- for case 1, H = H+H^\dagger to get the full matrix, the the onsite one is doubled.
- for case 2. no need to do H = H+H^dagger. since the matrix is already full.
- """
- if time_symm:
- hk[ist:ied,jst:jed] += 0.5 * hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt))
- else:
- hk[ist:ied,jst:jed] += hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt))
- else:
- hk[ist:ied,jst:jed] += hijAll[ib] * np.exp(-1j * 2 * np.pi* np.dot(k,Rlatt))
- if time_symm:
- hk = hk + hk.T.conj()
- if self.soc:
- hk = torch.kron(A=torch.eye(2, device=self.device, dtype=self.dtype), B=hk)
- Hk[ik] = hk
-
- if self.soc:
- Hk[:, :totalOrbs, :totalOrbs] += self.soc_upup.unsqueeze(0)
- Hk[:, totalOrbs:, totalOrbs:] += self.soc_upup.conj().unsqueeze(0)
- Hk[:, :totalOrbs, totalOrbs:] += self.soc_updown.unsqueeze(0)
- Hk[:, totalOrbs:, :totalOrbs] += self.soc_updown.conj().unsqueeze(0)
-
- Hk.contiguous()
-
- return Hk
-
- def Eigenvalues(self, kpoints, time_symm=True):
- """ using the tight-binding H and S matrix calculate eigenvalues at kpoints.
-
- Args:
- kpoints: the k-kpoints used to calculate the eigenvalues.
- Note: must have the BondHBlock and BondSBlock
- """
- hkmat = self.hs_block_R2k(kpoints=kpoints, HorS='H', time_symm=time_symm)
- if not self.use_orthogonal_basis:
- skmat = self.hs_block_R2k(kpoints=kpoints, HorS='S', time_symm=time_symm)
- else:
- skmat = torch.eye(hkmat.shape[1], dtype=self.cdtype).unsqueeze(0).repeat(hkmat.shape[0], 1, 1)
-
- chklowt = th.linalg.cholesky(skmat)
- chklowtinv = th.linalg.inv(chklowt)
- Heff = (chklowtinv @ hkmat @ th.transpose(chklowtinv,dim0=1,dim1=2).conj())
- # the factor 13.605662285137 * 2 from Hartree to eV.
- # eigks = th.linalg.eigvalsh(Heff) * 13.605662285137 * 2
- eigks, Q = th.linalg.eigh(Heff)
- eigks = eigks * 13.605662285137 * 2
- Qres = Q.detach()
- # else:
- # chklowt = np.linalg.cholesky(skmat)
- # chklowtinv = np.linalg.inv(chklowt)
- # Heff = (chklowtinv @ hkmat @ np.transpose(chklowtinv,(0,2,1)).conj())
- # eigks = np.linalg.eigvalsh(Heff) * 13.605662285137 * 2
- # Qres = 0
-
- return eigks, Qres
\ No newline at end of file
diff --git a/dptb/nn/__init__.py b/dptb/nn/__init__.py
new file mode 100644
index 00000000..4aa0820a
--- /dev/null
+++ b/dptb/nn/__init__.py
@@ -0,0 +1,140 @@
+from .build import build_model
+from .deeptb import DPTB
+from .nnsk import NNSK
+
+__all__ = [
+ build_model,
+ DPTB,
+ NNSK,
+]
+"""
+
+nn module is the model class for the graph neural network model, which is the core of the deeptb package.
+It provide two interfaces which is used to interact with other module:
+1. The build_model method, which is used to construct a model based on the model_options, common_options and run_options.
+ - the model options decides the structure of the model, such as the number of layers, the activation function, the number of neurons in each layer, etc.
+ - the common options contains some common parameters, such as the dtype, device, and the basis, which also related to the model
+ - the run options decide how to initialize the model. Whether it is from scratch, init from checkpoint, freeze or not, or whether to deploy it.
+ The build model method will return a model class and a config dict.
+
+2. A config dict of the model, which contains the essential information of the model to be initialized again.
+
+The build model method should composed of the following steps:
+1. process the configs from user input and the config from the checkpoint (if any).
+2. construct the model based on the configs.
+3. process the config dict for the output dict.
+
+The deeptb model can be constructed by the following steps (which have been conpacted in deeptb.py):
+1. choose the way to construct edge and atom embedding, either a descriptor, GNN or both.
+ - in: data with env and edge vectors, and atomic numbers
+ - out: data with edge and atom embedding
+ - user view: this can be defined as a tag in model_options
+2. constructing the prediction layer, which named as sktb layer or e3tb layer, it is either a linear layer or a neural network
+ - in: data with edge and atom embedding
+ - out: data with the e3 irreducible matrix element, or the sk parameters
+3. constructing hamiltonian model, either a SKTB or E3TB
+ - in: data with properties/parameters predicted
+ - out data with SK/E3 hamiltonian
+
+model_options = {
+ "embedding": {
+ "mode":"se2/gnn/se3...",
+ # mode specific
+ # se2
+ "env_cutoff": 3.5,
+ "rs": float,
+ "rc": float,
+ "n_axis": int,
+ "radial_embedding": {
+ "neurons": [int],
+ "activation": str,
+ "if_batch_normalized": bool
+ }
+ # gnn
+ # se3
+ },
+ "prediction": {
+ "mode": "linear/nn",
+ # linear
+ # nn
+ "neurons": [int],
+ "activation": str,
+ "if_batch_normalized": bool,
+ "hamiltonian" = {
+ "method": "sktb/e3tb",
+ "rmax": 3.5,
+ "precision": float, # use to check if rmax is large enough
+ "soc": bool,
+ "overlap": bool,
+ # sktb
+ # e3tb
+ },
+ },
+ "nnsk":{
+ "hopping_function": {
+ "formula": "varTang96/powerlaw/NRL",
+ ...
+ },
+ "onsite_function": {
+ "formula": "strain/uniform/NRL",
+ # strain
+ "strain_cutoff": float,
+ # NRL
+ "cutoff": float,
+ "decay_w": float,
+ "lambda": float
+ }
+ }
+}
+"""
+
+common_options = {
+ "basis": {
+ "B": "2s2p1d",
+ "N": "2s2p1d",
+ },
+ "device": "cpu",
+ "dtype": "float32",
+ "r_max": 2.0,
+ "er_max": 4.0,
+ "oer_max": 6.0,
+}
+
+
+data_options = {
+ "train": {
+
+ }
+}
+
+
+dptb_model_options = {
+ "embedding": {
+ "method": "se2",
+ "rs": 2.0,
+ "rc": 7.0,
+ "n_axis": 10,
+ "radial_embedding": {
+ "neurons": [128,128,20],
+ "activation": "tanh",
+ "if_batch_normalized": False,
+ },
+ },
+ "prediction":{
+ "method": "nn",
+ "neurons": [256,256,256],
+ "activation": "tanh",
+ "if_batch_normalized": False,
+ "quantities": ["hamiltonian"],
+ "hamiltonian":{
+ "method": "e3tb",
+ "precision": 1e-5,
+ "overlap": False,
+ },
+ },
+ "nnsk": {
+ "onsite": {"method": "strain", "rs":6.0, "w":0.1},
+ "hopping": {"method": "powerlaw", "rs":3.2, "w": 0.15},
+ "overlap": False
+ }
+}
\ No newline at end of file
diff --git a/dptb/nn/base.py b/dptb/nn/base.py
new file mode 100644
index 00000000..a11e0172
--- /dev/null
+++ b/dptb/nn/base.py
@@ -0,0 +1,464 @@
+from torch.nn import Linear
+import torch
+from dptb.data import AtomicDataDict
+from typing import Optional, Any, Union, Callable, OrderedDict, List
+from torch import Tensor
+from dptb.utils.constants import dtype_dict
+from dptb.utils.tools import _get_activation_fn
+import torch.nn.functional as F
+import torch.nn as nn
+
+class AtomicLinear(torch.nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ in_field = AtomicDataDict.NODE_FEATURES_KEY,
+ out_field = AtomicDataDict.NODE_FEATURES_KEY,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu")
+ ):
+ super(AtomicLinear, self).__init__()
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+ self.linear = Linear(in_features, out_features, dtype=dtype, device=device)
+ self.in_field = in_field
+ self.out_field = out_field
+
+ def forward(self, data: AtomicDataDict.Type):
+ data[self.out_field] = self.linear(data[self.in_field])
+ return data
+
+class Identity(torch.nn.Module):
+ def __init__(
+ self,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ):
+ super(Identity, self).__init__()
+
+ def forward(self, data: AtomicDataDict) -> AtomicDataDict:
+ return data
+
+
+class AtomicMLP(torch.nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features,
+ out_features,
+ in_field = AtomicDataDict.NODE_FEATURES_KEY,
+ out_field = AtomicDataDict.NODE_FEATURES_KEY,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ if_batch_normalized: bool = False,
+ device: Union[str, torch.device] = torch.device('cpu'),
+ dtype: Union[str, torch.dtype] = torch.float32
+ ):
+ super(AtomicMLP, self).__init__()
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+ self.in_layer = Linear(
+ in_features=in_features,
+ out_features=hidden_features,
+ device=device,
+ dtype=dtype)
+
+ self.out_layer = Linear(
+ in_features=hidden_features,
+ out_features=out_features,
+ device=device,
+ dtype=dtype)
+
+ if if_batch_normalized:
+ self.bn1 = torch.nn.BatchNorm1d(hidden_features)
+ self.bn2 = torch.nn.BatchNorm1d(out_features)
+ self.if_batch_normalized = if_batch_normalized
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ else:
+ self.activation = activation
+
+ self.in_field = in_field
+ self.out_field = out_field
+
+ def __setstate__(self, state):
+ if 'activation' not in state:
+ state['activation'] = F.relu
+ super(AtomicMLP, self).__setstate__(state)
+
+ def forward(self, data: AtomicDataDict.Type):
+ x = self.in_layer(data[self.in_field])
+ if self.if_batch_normalized:
+ x = self.bn1(x)
+ x = self.activation(x)
+ x = self.out_layer(x)
+ if self.if_batch_normalized:
+ x = self.bn2(x)
+ data[self.out_field] = x
+
+ return data
+
+class AtomicFFN(torch.nn.Module):
+ def __init__(
+ self,
+ config: List[dict],
+ in_field: AtomicDataDict.NODE_FEATURES_KEY,
+ out_field: AtomicDataDict.NODE_FEATURES_KEY,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ if_batch_normalized: bool = False,
+ device: Union[str, torch.device] = torch.device('cpu'),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ **kwargs
+ ):
+ super(AtomicFFN, self).__init__()
+ self.layers = torch.nn.ModuleList([])
+ for kk in range(len(config)-1):
+ if kk == 0:
+ self.layers.append(
+ AtomicMLP(
+ **config[kk],
+ in_field=in_field,
+ out_field=out_field,
+ if_batch_normalized=if_batch_normalized,
+ activation=activation,
+ device=device,
+ dtype=dtype
+ )
+ )
+ else:
+ self.layers.append(
+ AtomicMLP(
+ **config[kk],
+ in_field=out_field,
+ out_field=out_field,
+ if_batch_normalized=if_batch_normalized,
+ activation=activation,
+ device=device,
+ dtype=dtype
+ )
+ )
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ else:
+ self.activation = activation
+
+ if config[-1].get('hidden_features') is None:
+ self.out_layer = AtomicLinear(in_features=config[-1]['in_features'], out_features=config[-1]['out_features'], in_field=out_field, out_field=out_field, device=device, dtype=dtype)
+ else:
+ self.out_layer = AtomicMLP(**config[-1], in_field=out_field, out_field=out_field, if_batch_normalized=False, activation=activation, device=device, dtype=dtype)
+ self.out_field = out_field
+ self.in_field = in_field
+ # self.out_norm = nn.LayerNorm(config[-1]['out_features'], elementwise_affine=True)
+
+ def forward(self, data: AtomicDataDict.Type):
+ out_scale = self.out_scale(data[self.in_field])
+ out_shift = self.out_shift(data[self.in_field])
+ for layer in self.layers:
+ data = layer(data)
+ data[self.out_field] = self.activation(data[self.out_field])
+
+ data = self.out_layer(data)
+ # data[self.out_field] = self.out_norm(data[self.out_field])
+ return data
+
+
+class AtomicResBlock(torch.nn.Module):
+ def __init__(self,
+ in_features: int,
+ hidden_features: int,
+ out_features: int,
+ in_field = AtomicDataDict.NODE_FEATURES_KEY,
+ out_field = AtomicDataDict.NODE_FEATURES_KEY,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ if_batch_normalized: bool=False,
+ device: Union[str, torch.device] = torch.device('cpu'),
+ dtype: Union[str, torch.dtype] = torch.float32
+ ):
+
+ super(AtomicResBlock, self).__init__()
+ self.in_field = in_field
+ self.out_field = out_field
+ self.layer = AtomicMLP(in_features, hidden_features, out_features, in_field=in_field, out_field=out_field, if_batch_normalized=if_batch_normalized, device=device, dtype=dtype, activation=activation)
+ self.out_features = out_features
+ self.in_features = in_features
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ else:
+ self.activation = activation
+
+ def __setstate__(self, state):
+ if 'activation' not in state:
+ state['activation'] = F.relu
+ super(AtomicResBlock, self).__setstate__(state)
+
+ def forward(self, data: AtomicDataDict.Type):
+ if self.in_features < self.out_features:
+ res = F.interpolate(data[self.in_field].unsqueeze(1), size=[self.out_features]).squeeze(1)
+ elif self.in_features == self.out_features:
+ res = data[self.in_field]
+ else:
+ res = F.adaptive_avg_pool1d(input=data[self.in_field], output_size=self.out_features)
+
+ data = self.layer(data)
+ data[self.out_field] = data[self.out_field] + res
+
+ data[self.out_field] = self.activation(data[self.out_field])
+
+ return data
+
+# The ResNet class is a neural network model that consists of multiple residual blocks and a final
+# output layer, with options for activation functions and batch normalization.
+
+class AtomicResNet(torch.nn.Module):
+ def __init__(
+ self,
+ config: List[dict],
+ in_field: AtomicDataDict.NODE_FEATURES_KEY,
+ out_field: AtomicDataDict.NODE_FEATURES_KEY,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ if_batch_normalized: bool = False,
+ device: Union[str, torch.device] = torch.device('cpu'),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ **kwargs,
+ ):
+ """_summary_
+
+ Parameters
+ ----------
+ config : list
+ ep: config = [
+ {'in_features': 3, 'hidden_features': 4, 'out_features': 8},
+ {'in_features': 8, 'hidden_features': 6, 'out_features': 4}
+ ]
+ activation : _type_
+ _description_
+ if_batch_normalized : bool, optional
+ _description_, by default False
+ device : str, optional
+ _description_, by default 'cpu'
+ dtype : _type_, optional
+ _description_, by default torch.float32
+ """
+ super(AtomicResNet, self).__init__()
+ self.in_field = in_field
+ self.out_field = out_field
+ self.layers = torch.nn.ModuleList([])
+ for kk in range(len(config)-1):
+ # the first layer will take the in_field as key to take `data[in_field]` and output the out_field, data[out_field] = layer(data[in_field])
+ # the rest of the layers will take the out_field as key to take `data[out_field]` and output the out_field, data[out_field] = layer(data[out_field])
+ # That why we need to set the in_field and out_field for 1st layer and the rest of the layers.
+ if kk == 0:
+ self.layers.append(
+ AtomicResBlock(
+ **config[kk],
+ in_field=in_field,
+ out_field=out_field,
+ if_batch_normalized=if_batch_normalized,
+ activation=activation,
+ device=device,
+ dtype=dtype
+ )
+ )
+ else:
+ self.layers.append(
+ AtomicResBlock(
+ **config[kk],
+ in_field=out_field,
+ out_field=out_field,
+ if_batch_normalized=if_batch_normalized,
+ activation=activation,
+ device=device,
+ dtype=dtype
+ )
+ )
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ else:
+ self.activation = activation
+
+
+ if config[-1].get('hidden_feature') is None:
+ self.out_layer = AtomicLinear(in_features=config[-1]['in_features'], out_features=config[-1]['out_features'], in_field=out_field, out_field=out_field, device=device, dtype=dtype)
+ else:
+ self.out_layer = AtomicMLP(**config[-1], if_batch_normalized=False, in_field=in_field, out_field=out_field, activation=activation, device=device, dtype=dtype)
+ # self.out_norm = nn.LayerNorm(config[-1]['out_features'], elementwise_affine=True)
+
+ def forward(self, data: AtomicDataDict.Type):
+
+ for layer in self.layers:
+ data = layer(data)
+ data[self.out_field] = self.activation(data[self.out_field])
+ data = self.out_layer(data)
+ # data[self.out_field] = self.out_norm(data[self.out_field])
+ return data
+
+class MLP(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features,
+ out_features,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ if_batch_normalized=False,
+ device: Union[str, torch.device]=torch.device('cpu'),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ ):
+ super(MLP, self).__init__()
+
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+
+ self.in_layer = nn.Linear(in_features=in_features, out_features=hidden_features, device=device, dtype=dtype)
+ self.out_layer = nn.Linear(in_features=hidden_features, out_features=out_features, device=device, dtype=dtype)
+
+ if if_batch_normalized:
+ self.bn1 = nn.BatchNorm1d(hidden_features)
+ self.bn2 = nn.BatchNorm1d(out_features)
+ self.if_batch_normalized = if_batch_normalized
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ else:
+ self.activation = activation
+
+ def __setstate__(self, state):
+ if 'activation' not in state:
+ state['activation'] = F.relu
+ super(MLP, self).__setstate__(state)
+
+ def forward(self, x):
+ x = self.in_layer(x)
+ if self.if_batch_normalized:
+ x = self.bn1(x)
+ x = self.activation(x)
+ x = self.out_layer(x)
+ if self.if_batch_normalized:
+ x = self.bn2(x)
+
+ return x
+
+class FFN(nn.Module):
+ def __init__(
+ self,
+ config,
+ activation,
+ if_batch_normalized=False,
+ device: Union[str, torch.device]=torch.device('cpu'),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ **kwargs
+ ):
+ super(FFN, self).__init__()
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+
+ self.layers = nn.ModuleList([])
+ for kk in range(len(config)-1):
+ self.layers.append(MLP(**config[kk], if_batch_normalized=if_batch_normalized, activation=activation, device=device, dtype=dtype))
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ else:
+ self.activation = activation
+
+ if config[-1].get('hidden_features') is None:
+ self.out_layer = nn.Linear(in_features=config[-1]['in_features'], out_features=config[-1]['out_features'], device=device, dtype=dtype)
+ # nn.init.normal_(self.out_layer.weight, mean=0, std=1e-3)
+ # nn.init.normal_(self.out_layer.bias, mean=0, std=1e-3)
+ else:
+ self.out_layer = MLP(**config[-1], if_batch_normalized=False, activation=activation, device=device, dtype=dtype)
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = layer(x)
+ x = self.activation(x)
+
+ return self.out_layer(x)
+
+
+class ResBlock(torch.nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features,
+ out_features,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ if_batch_normalized=False,
+ device: Union[str, torch.device]=torch.device('cpu'),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ ):
+ super(ResBlock, self).__init__()
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+
+ self.layer = MLP(in_features, hidden_features, out_features, if_batch_normalized=if_batch_normalized, device=device, dtype=dtype, activation=activation)
+ self.out_features = out_features
+ self.in_features = in_features
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ else:
+ self.activation = activation
+
+ def __setstate__(self, state):
+ pass
+ # super(ResBlock, self).__setstate__(state)
+
+ def forward(self, x):
+ out = self.layer(x)
+ if self.in_features < self.out_features:
+ out = nn.functional.interpolate(x.unsqueeze(1), size=[self.out_features]).squeeze(1) + out
+ elif self.in_features == self.out_features:
+ out = x + out
+ else:
+ out = nn.functional.adaptive_avg_pool1d(input=x, output_size=self.out_features) + out
+
+ out = self.activation(out)
+
+ return out
+
+class ResNet(torch.nn.Module):
+ def __init__(
+ self,
+ config,
+ activation,
+ if_batch_normalized=False,
+ device: Union[str, torch.device]=torch.device('cpu'),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ **kwargs
+ ):
+ super(ResNet, self).__init__()
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+
+ self.layers = torch.nn.ModuleList([])
+ for kk in range(len(config)-1):
+ self.layers.append(ResBlock(**config[kk], if_batch_normalized=if_batch_normalized, activation=activation, device=device, dtype=dtype))
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ else:
+ self.activation = activation
+
+
+ if config[-1].get('hidden_features') is None:
+ self.out_layer = nn.Linear(in_features=config[-1]['in_features'], out_features=config[-1]['out_features'], device=device, dtype=dtype)
+ # nn.init.normal_(self.out_layer.weight, mean=0, std=1e-3)
+ # nn.init.normal_(self.out_layer.bias, mean=0, std=1e-3)
+ else:
+ self.out_layer = MLP(**config[-1], if_batch_normalized=False, activation=activation, device=device, dtype=dtype)
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = layer(x)
+ x = self.activation(x)
+
+ return self.out_layer(x)
\ No newline at end of file
diff --git a/dptb/nn/build.py b/dptb/nn/build.py
new file mode 100644
index 00000000..3ab9fa46
--- /dev/null
+++ b/dptb/nn/build.py
@@ -0,0 +1,101 @@
+from dptb.nn.deeptb import DPTB, MIX
+import logging
+from dptb.nn.nnsk import NNSK
+import torch
+from dptb.utils.tools import j_must_have
+
+log = logging.getLogger(__name__)
+
+def build_model(run_options, model_options: dict={}, common_options: dict={}, statistics: dict=None):
+ """
+ The build model method should composed of the following steps:
+ 1. process the configs from user input and the config from the checkpoint (if any).
+ 2. construct the model based on the configs.
+ 3. process the config dict for the output dict.
+ run_opt = {
+ "init_model": init_model,
+ "restart": restart,
+ "freeze": freeze,
+ "log_path": log_path,
+ "log_level": log_level,
+ "use_correction": use_correction
+ }
+ """
+ # this is the
+ # process the model_options
+ assert not all((run_options.get("init_model"), run_options.get("restart"))), "You can only choose one of the init_model and restart options."
+ if any((run_options.get("init_model"), run_options.get("restart"))):
+ from_scratch = False
+ checkpoint = run_options.get("init_model") or run_options.get("restart")
+ else:
+ from_scratch = True
+ if not all((model_options, common_options)):
+ logging.error("You need to provide model_options and common_options when you are initializing a model from scratch.")
+ raise ValueError
+
+ # decide whether to initialize a mixed model, or a deeptb model, or a nnsk model
+ init_deeptb = False
+ init_nnsk = False
+ init_mixed = False
+
+ # load the model_options and common_options from checkpoint if not provided
+ if not from_scratch:
+ # init model from checkpoint
+ if len(model_options) == 0:
+ f = torch.load(checkpoint)
+ model_options = f["config"]["model_options"]
+ del f
+
+ if len(common_options) == 0:
+ f = torch.load(checkpoint)
+ common_options = f["config"]["common_options"]
+ del f
+
+ if all((all((model_options.get("embedding"), model_options.get("prediction"))), model_options.get("nnsk"))):
+ init_mixed = True
+ elif all((model_options.get("embedding"), model_options.get("prediction"))):
+ init_deeptb = True
+ elif model_options.get("nnsk"):
+ init_nnsk = True
+ else:
+ log.error("Model cannot be built without either one of the terms in model_options (embedding+prediction/nnsk).")
+ raise ValueError
+
+ assert int(init_mixed) + int(init_deeptb) + int(init_nnsk) == 1, "You can only choose one of the mixed, deeptb, and nnsk options."
+ # check if the model is deeptb or nnsk
+
+ # init deeptb
+ if from_scratch:
+ if init_deeptb:
+ model = DPTB(**model_options, **common_options)
+
+ # do initialization from statistics if DPTB is e3tb and statistics is provided
+ if model.method == "e3tb" and statistics is not None:
+ scalar_mask = torch.BoolTensor([ir.dim==1 for ir in model.idp.orbpair_irreps])
+ node_shifts = statistics["node"]["scalar_ave"]
+ node_scales = statistics["node"]["norm_ave"]
+ node_scales[:,scalar_mask] = statistics["node"]["scalar_std"]
+
+ edge_shifts = statistics["edge"]["scalar_ave"]
+ edge_scales = statistics["edge"]["norm_ave"]
+ edge_scales[:,scalar_mask] = statistics["edge"]["scalar_std"]
+ model.node_prediction_h.set_scale_shift(scales=node_scales, shifts=node_shifts)
+ model.edge_prediction_h.set_scale_shift(scales=edge_scales, shifts=edge_shifts)
+
+ if init_nnsk:
+ model = NNSK(**model_options["nnsk"], **common_options)
+
+ if init_mixed:
+ model = MIX(**model_options, **common_options)
+
+ else:
+ # load the model from the checkpoint
+ if init_deeptb:
+ model = DPTB.from_reference(checkpoint, **model_options, **common_options)
+ if init_nnsk:
+ model = NNSK.from_reference(checkpoint, **model_options["nnsk"], **common_options)
+ if init_mixed:
+ # mix model can be initilized with a mixed reference model or a nnsk model.
+ model = MIX.from_reference(checkpoint, **model_options, **common_options)
+
+ return model
diff --git a/dptb/nn/cutoff.py b/dptb/nn/cutoff.py
new file mode 100644
index 00000000..76b78014
--- /dev/null
+++ b/dptb/nn/cutoff.py
@@ -0,0 +1,55 @@
+import math
+import torch
+
+
+@torch.jit.script
+def cosine_cutoff(x: torch.Tensor, r_max: torch.Tensor, r_start_cos_ratio: float = 0.8):
+ """A piecewise cosine cutoff starting the cosine decay at r_decay_factor*r_max.
+
+ Broadcasts over r_max.
+ """
+ r_max, x = torch.broadcast_tensors(r_max.unsqueeze(-1), x.unsqueeze(0))
+ r_decay: torch.Tensor = r_start_cos_ratio * r_max
+ # for x < r_decay, clamps to 1, for x > r_max, clamps to 0
+ x = x.clamp(r_decay, r_max)
+ return 0.5 * (torch.cos((math.pi / (r_max - r_decay)) * (x - r_decay)) + 1.0)
+
+
+@torch.jit.script
+def polynomial_cutoff(
+ x: torch.Tensor, r_max: torch.Tensor, p: float = 6.0
+) -> torch.Tensor:
+ """Polynomial cutoff, as proposed in DimeNet: https://arxiv.org/abs/2003.03123
+
+
+ Parameters
+ ----------
+ r_max : tensor
+ Broadcasts over r_max.
+
+ p : int
+ Power used in envelope function
+ """
+ assert p >= 2.0
+ r_max, x = torch.broadcast_tensors(r_max.unsqueeze(-1), x.unsqueeze(0))
+ x = x / r_max
+
+ out = 1.0
+ out = out - (((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p))
+ out = out + (p * (p + 2.0) * torch.pow(x, p + 1.0))
+ out = out - ((p * (p + 1.0) / 2) * torch.pow(x, p + 2.0))
+
+ return out * (x < 1.0)
+
+@torch.jit.script
+def polynomial_cutoff2(
+ r: torch.Tensor, rc: torch.Tensor, rs: torch.Tensor,
+) -> torch.Tensor:
+
+ r_ = torch.zeros_like(r)
+ r_[r [{'in_features': 1, 'hidden_features': 2, 'out_features': 3},
+ {'in_features': 3, 'hidden_features': 4, 'out_features': 5},
+ {'in_features': 5, 'out_features': 6}]
+ [1, 2, 3, 4, 5] -> [{'in_features': 1, 'hidden_features': 2, 'out_features': 3},
+ {'in_features': 3, 'hidden_features': 4, 'out_features': 5}]
+ """
+
+ n = len(nl)
+ assert n > 1, "The neuron config should have at least 2 layers."
+ if n % 2 == 0:
+ d_out = nl[-1]
+ nl = nl[:-1]
+ config = []
+ for i in range(1,len(nl)-1, 2):
+ config.append({'in_features': nl[i-1], 'hidden_features': nl[i], 'out_features': nl[i+1]})
+
+ if n % 2 == 0:
+ config.append({'in_features': nl[-1], 'out_features': d_out})
+
+ return config
+
+class DPTB(nn.Module):
+ quantities = ["hamiltonian", "energy"]
+ name = "dptb"
+ def __init__(
+ self,
+ embedding: dict,
+ prediction: dict,
+ overlap: bool = False,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ transform: bool = True,
+ **kwargs,
+ ):
+
+ """The top level DeePTB model class.
+
+ Parameters
+ ----------
+ embedding_config : dict
+ _description_
+ prediction_config : dict
+ _description_
+ basis : Dict[str, Union[str, list], None], optional
+ _description_, by default None
+ idp : Union[OrbitalMapper, None], optional
+ _description_, by default None
+ transform : bool, optional
+ _description_, decide whether to transform the irreducible matrix element to the hamiltonians
+ dtype : Union[str, torch.dtype], optional
+ _description_, by default torch.float32
+ device : Union[str, torch.device], optional
+ _description_, by default torch.device("cpu")
+
+ Raises
+ ------
+ NotImplementedError
+ _description_
+ """
+ super(DPTB, self).__init__()
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ self.dtype = dtype
+ self.device = device
+ self.model_options = {"embedding": embedding.copy(), "prediction": prediction.copy()}
+ self.transform = transform
+
+
+ self.method = prediction.get("method", "e3tb")
+ # self.soc = prediction.get("soc", False)
+ self.prediction = prediction
+
+ prediction_copy = prediction.copy()
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method=self.method, device=self.device)
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ self.basis = self.idp.basis
+ self.idp.get_orbpair_maps()
+
+ n_species = len(self.basis.keys())
+ # initialize the embedding layer
+ self.embedding = Embedding(**embedding, dtype=dtype, device=device, idp=self.idp, n_atom=n_species)
+
+ # initialize the prediction layer
+
+ if self.method == "sktb":
+ prediction_copy["neurons"] = [self.embedding.out_node_dim] + prediction_copy["neurons"] + [self.idp.n_onsite_Es]
+ prediction_copy["config"] = get_neuron_config(prediction_copy["neurons"])
+
+ self.node_prediction_h = AtomicResNet(
+ **prediction_copy,
+ in_field=AtomicDataDict.NODE_FEATURES_KEY,
+ out_field=AtomicDataDict.NODE_FEATURES_KEY,
+ device=device,
+ dtype=dtype
+ )
+
+ prediction_copy["neurons"][0] = self.embedding.out_edge_dim
+ prediction_copy["neurons"][-1] = self.idp.reduced_matrix_element
+ prediction_copy["config"] = get_neuron_config(prediction_copy["neurons"])
+ self.edge_prediction_h = AtomicResNet(
+ **prediction_copy,
+ in_field=AtomicDataDict.EDGE_FEATURES_KEY,
+ out_field=AtomicDataDict.EDGE_FEATURES_KEY,
+ device=device,
+ dtype=dtype
+ )
+
+ if overlap:
+ self.edge_prediction_s = AtomicResNet(
+ **prediction_copy,
+ in_field=AtomicDataDict.EDGE_OVERLAP_KEY,
+ out_field=AtomicDataDict.EDGE_OVERLAP_KEY,
+ device=device,
+ dtype=dtype
+ )
+
+ elif prediction_copy.get("method") == "e3tb":
+ self.node_prediction_h = E3PerSpeciesScaleShift(
+ field=AtomicDataDict.NODE_FEATURES_KEY,
+ num_types=n_species,
+ irreps_in=self.embedding.out_node_irreps,
+ out_field = AtomicDataDict.NODE_FEATURES_KEY,
+ shifts=0.,
+ scales=1.,
+ dtype=self.dtype,
+ device=self.device,
+ **prediction_copy,
+ )
+
+ self.edge_prediction_h = E3PerEdgeSpeciesScaleShift(
+ field=AtomicDataDict.EDGE_FEATURES_KEY,
+ num_types=n_species,
+ irreps_in=self.embedding.out_edge_irreps,
+ out_field = AtomicDataDict.EDGE_FEATURES_KEY,
+ shifts=0.,
+ scales=1.,
+ dtype=self.dtype,
+ device=self.device,
+ **prediction_copy,
+ )
+ if overlap:
+ raise NotImplementedError("The overlap prediction is not implemented for e3tb method.")
+
+ else:
+ raise NotImplementedError("The prediction model {} is not implemented.".format(prediction_copy["method"]))
+
+
+ if self.method == "sktb":
+ self.hamiltonian = SKHamiltonian(
+ edge_field=AtomicDataDict.EDGE_FEATURES_KEY,
+ node_field=AtomicDataDict.NODE_FEATURES_KEY,
+ idp_sk=self.idp,
+ dtype=self.dtype,
+ device=self.device,
+ onsite=True,
+ )
+ if overlap:
+ self.overlap = SKHamiltonian(
+ idp_sk=self.idp,
+ edge_field=AtomicDataDict.EDGE_OVERLAP_KEY,
+ node_field=AtomicDataDict.NODE_OVERLAP_KEY,
+ dtype=self.dtype,
+ device=self.device,
+ onsite=False,
+ )
+
+ elif self.method == "e3tb":
+ self.hamiltonian = E3Hamiltonian(
+ edge_field=AtomicDataDict.EDGE_FEATURES_KEY,
+ node_field=AtomicDataDict.NODE_FEATURES_KEY,
+ idp=self.idp,
+ dtype=self.dtype,
+ device=self.device
+ )
+ if overlap:
+ self.overlap = E3Hamiltonian(
+ idp=self.idp,
+ edge_field=AtomicDataDict.EDGE_OVERLAP_KEY,
+ node_field=AtomicDataDict.NODE_OVERLAP_KEY,
+ dtype=self.dtype,
+ device=self.device,
+ overlap=True,
+ )
+
+
+ def forward(self, data: AtomicDataDict.Type):
+
+ data = self.embedding(data)
+ if hasattr(self, "overlap"):
+ data[AtomicDataDict.EDGE_OVERLAP_KEY] = data[AtomicDataDict.EDGE_FEATURES_KEY]
+
+ data = self.node_prediction_h(data)
+ data = self.edge_prediction_h(data)
+ if hasattr(self, "overlap"):
+ data = self.edge_prediction_s(data)
+
+ if self.transform:
+ data = self.hamiltonian(data)
+ if hasattr(self, "overlap"):
+ data = self.overlap(data)
+
+ return data
+
+ @classmethod
+ def from_reference(
+ cls,
+ checkpoint,
+ embedding: dict={},
+ prediction: dict={},
+ overlap: bool=None,
+ basis: Dict[str, Union[str, list]]=None,
+ dtype: Union[str, torch.dtype]=None,
+ device: Union[str, torch.device]=None,
+ transform: bool = True,
+ **kwargs
+ ):
+
+ ckpt = torch.load(checkpoint)
+ common_options = {
+ "dtype": dtype,
+ "device": device,
+ "basis": basis,
+ "overlap": overlap,
+ }
+
+ model_options = {
+ "embedding": embedding,
+ "prediction": prediction,
+ }
+
+ if len(embedding) == 0 or len(prediction) == 0:
+ model_options.update(ckpt["config"]["model_options"])
+
+ for k,v in common_options.items():
+ if v is None:
+ common_options[k] = ckpt["config"]["common_options"][k]
+
+ model = cls(**model_options, **common_options, transform=transform)
+ model.load_state_dict(ckpt["model_state_dict"])
+
+ del ckpt
+
+ return model
+
+class MIX(nn.Module):
+ name = "mix"
+ def __init__(
+ self,
+ embedding: dict,
+ prediction: dict,
+ nnsk: dict,
+ basis: Dict[str, Union[str, list]]=None,
+ overlap: bool = False,
+ idp_sk: Union[OrbitalMapper, None]=None,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ):
+ super(MIX, self).__init__()
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+
+ self.dtype = dtype
+ self.device = device
+
+ self.dptb = DPTB(
+ embedding=embedding,
+ prediction=prediction,
+ basis=basis,
+ idp=idp_sk,
+ overlap=overlap,
+ dtype=dtype,
+ device=device,
+ transform=False,
+ )
+
+ self.nnsk = NNSK(
+ basis=basis,
+ idp_sk=idp_sk,
+ **nnsk,
+ overlap=overlap,
+ dtype=dtype,
+ device=device,
+ transform=False,
+ )
+ self.idp = self.nnsk.idp
+
+ self.model_options = self.nnsk.model_options
+ self.model_options.update(self.dptb.model_options)
+
+ self.hamiltonian = self.nnsk.hamiltonian
+ if overlap:
+ self.overlap = self.nnsk.overlap
+
+
+
+ def forward(self, data: AtomicDataDict.Type):
+ data_dptb = self.dptb(data)
+ data_nnsk = self.nnsk(data)
+
+ data_nnsk[AtomicDataDict.EDGE_FEATURES_KEY] = data_nnsk[AtomicDataDict.EDGE_FEATURES_KEY] * (1 + data_dptb[AtomicDataDict.EDGE_FEATURES_KEY])
+ data_nnsk[AtomicDataDict.NODE_FEATURES_KEY] = data_nnsk[AtomicDataDict.NODE_FEATURES_KEY] * (1 + data_dptb[AtomicDataDict.NODE_FEATURES_KEY])
+
+ data_nnsk = self.hamiltonian(data_nnsk)
+ if hasattr(self, "overlap"):
+ data_nnsk = self.overlap(data_nnsk)
+
+ return data_nnsk
+
+ @classmethod
+ def from_reference(
+ cls,
+ checkpoint,
+ embedding: dict=None,
+ prediction: dict=None,
+ nnsk: dict=None,
+ basis: Dict[str, Union[str, list]]=None,
+ overlap: bool = None,
+ dtype: Union[str, torch.dtype] = None,
+ device: Union[str, torch.device] = None,
+ **kwargs,
+ ):
+ # the mapping from the parameters of the ref_model and the current model can be found using
+ # reference model's idp and current idp
+
+ ckpt = torch.load(checkpoint)
+ common_options = {
+ "dtype": dtype,
+ "device": device,
+ "basis": basis,
+ "overlap": overlap,
+ }
+ model_options = {
+ "embedding": embedding,
+ "prediction": prediction,
+ "nnsk": nnsk,
+ }
+
+ if len(nnsk) == 0:
+ model_options["nnsk"] = ckpt["config"]["model_options"]["nnsk"]
+
+ if len(embedding) == 0 or len(prediction) == 0:
+ assert ckpt["config"]["model_options"].get("embedding") is not None and ckpt["config"]["model_options"].get("prediction") is not None, \
+ "The reference model checkpoint should come from a mixed model if dptb info is not provided."
+
+ model_options["embedding"] = ckpt["config"]["model_options"]["embedding"]
+ model_options["prediction"] = ckpt["config"]["model_options"]["prediction"]
+
+ for k,v in common_options.items():
+ if v is None:
+ common_options[k] = ckpt["config"]["common_options"][k]
+
+ if ckpt["config"]["model_options"].get("embedding") is not None and ckpt["config"]["model_options"].get("prediction") is not None:
+ # read from mixed model
+ model = cls(**model_options, **common_options)
+ model.load_state_dict(ckpt["model_state_dict"])
+
+ else:
+ assert ckpt["config"]["model_options"].get("nnsk") is not None, "The referenced checkpoint should provide at least the nnsk model info."
+ # read from nnsk model
+
+ model = cls(**model_options, **common_options)
+ model.nnsk.load_state_dict(ckpt["model_state_dict"])
+
+ del ckpt
+
+ return model
\ No newline at end of file
diff --git a/dptb/nn/embedding/__init__.py b/dptb/nn/embedding/__init__.py
new file mode 100644
index 00000000..f271c7b6
--- /dev/null
+++ b/dptb/nn/embedding/__init__.py
@@ -0,0 +1,21 @@
+from .emb import Embedding
+from .se2 import SE2Descriptor
+from .baseline import BASELINE
+from .mpnn import MPNN
+from .deephe3 import E3DeePH
+from .e3baseline import E3BaseLineModel
+from .e3baseline_local import E3BaseLineModelLocal
+from .e3baseline_nonlocal import E3BaseLineModelNonLocal
+from .e3baseline_local1 import E3BaseLineModelLocal1
+from .e3baseline_nonlocal_wnode import E3BaseLineModelNonLocalWNODE
+
+__all__ = [
+ "Descriptor",
+ "SE2Descriptor",
+ "Identity",
+ "E3DeePH",
+ "E3BaseLineModelLocal",
+ "E3BaseLineModelLocal1",
+ "E3BaseLineModelNonLocal",
+ "E3BaseLineModelNonLocalWNODE",
+]
\ No newline at end of file
diff --git a/dptb/nn/embedding/allegro.py b/dptb/nn/embedding/allegro.py
new file mode 100644
index 00000000..05ea2ee6
--- /dev/null
+++ b/dptb/nn/embedding/allegro.py
@@ -0,0 +1,637 @@
+from typing import Optional, List, Union
+import math
+import functools
+
+import torch
+from torch_runstats.scatter import scatter
+
+from torch import fx
+from e3nn.util.codegen import CodeGenMixin
+from e3nn import o3
+from e3nn.o3 import TensorProduct, Linear
+from e3nn.math import normalize2mom
+from e3nn.util.jit import compile_mode
+
+from dptb.data import AtomicDataDict
+from ..radial_basis import BesselBasis
+from dptb.nn.graph_mixin import GraphModuleMixin
+from dptb.nn.embedding.from_deephe3.deephe3 import tp_path_exists
+from dptb.data import _keys
+from dptb.nn.cutoff import cosine_cutoff, polynomial_cutoff
+import math
+
+from math import ceil
+
+@compile_mode("script")
+class Allegro_Module(GraphModuleMixin, torch.nn.Module):
+ # saved params
+ num_layers: int
+
+ field: str
+ out_field: str
+ num_types: int
+ env_embed_mul: int
+ weight_numel: int
+ latent_resnet: bool
+ embed_initial_edge: bool
+
+ # internal values
+ _env_builder_w_index: List[int]
+ _env_builder_n_irreps: int
+ _input_pad: int
+
+ def __init__(
+ self,
+ # required params
+ num_layers: int,
+ num_types: int,
+ r_max: float,
+ avg_num_neighbors: Optional[float] = None,
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ per_layer_cutoffs: Optional[List[float]] = None,
+ cutoff_type: str = "polynomial",
+ # general hyperparameters:
+ field: str = AtomicDataDict.EDGE_ATTRS_KEY,
+ edge_invariant_field: str = AtomicDataDict.EDGE_EMBEDDING_KEY,
+ node_invariant_field: str = AtomicDataDict.NODE_ATTRS_KEY,
+ env_embed_multiplicity: int = 32,
+ embed_initial_edge: bool = True,
+ linear_after_env_embed: bool = False,
+ nonscalars_include_parity: bool = True,
+ # MLP parameters:
+ two_body_latent=ScalarMLPFunction,
+ two_body_latent_kwargs={},
+ env_embed=ScalarMLPFunction,
+ env_embed_kwargs={},
+ latent=ScalarMLPFunction,
+ latent_kwargs={},
+ latent_resnet: bool = True,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ latent_out_field: Optional[str] = _keys.EDGE_FEATURES,
+ # Performance parameters:
+ pad_to_alignment: int = 1,
+ sparse_mode: Optional[str] = None,
+ # Other:
+ irreps_in=None,
+ ):
+ super().__init__()
+ SCALAR = o3.Irrep("0e") # define for convinience
+
+ # save parameters
+ assert (
+ num_layers >= 1
+ ) # zero layers is "two body", but we don't need to support that fallback case
+ self.num_layers = num_layers
+ self.nonscalars_include_parity = nonscalars_include_parity
+ self.field = field
+ self.latent_out_field = latent_out_field
+ self.edge_invariant_field = edge_invariant_field
+ self.node_invariant_field = node_invariant_field
+ self.latent_resnet = latent_resnet
+ self.env_embed_mul = env_embed_multiplicity
+ self.r_start_cos_ratio = r_start_cos_ratio
+ self.polynomial_cutoff_p = float(PolynomialCutoff_p)
+ self.cutoff_type = cutoff_type
+ assert cutoff_type in ("cosine", "polynomial")
+ self.embed_initial_edge = embed_initial_edge
+ self.avg_num_neighbors = avg_num_neighbors
+ self.linear_after_env_embed = linear_after_env_embed
+ self.num_types = num_types
+
+ # set up irreps
+ self._init_irreps(
+ irreps_in=irreps_in,
+ required_irreps_in=[
+ self.field,
+ self.edge_invariant_field,
+ self.node_invariant_field,
+ ],
+ )
+
+ # for normalization of env embed sums
+ # one per layer
+ self.register_buffer(
+ "env_sum_normalizations",
+ # dividing by sqrt(N)
+ torch.as_tensor([avg_num_neighbors] * num_layers).rsqrt(),
+ )
+
+ latent = functools.partial(latent, **latent_kwargs)
+ env_embed = functools.partial(env_embed, **env_embed_kwargs)
+
+ self.latents = torch.nn.ModuleList([])
+ self.env_embed_mlps = torch.nn.ModuleList([])
+ self.tps = torch.nn.ModuleList([])
+ self.linears = torch.nn.ModuleList([])
+ self.env_linears = torch.nn.ModuleList([])
+
+ # Embed to the spharm * it as mul
+ input_irreps = self.irreps_in[self.field]
+ # this is not inherant, but no reason to fix right now:
+ assert all(mul == 1 for mul, ir in input_irreps)
+ env_embed_irreps = o3.Irreps([(1, ir) for _, ir in input_irreps])
+ assert (
+ env_embed_irreps[0].ir == SCALAR
+ ), "env_embed_irreps must start with scalars"
+ self._input_pad = (
+ int(math.ceil(env_embed_irreps.dim / pad_to_alignment)) * pad_to_alignment
+ ) - env_embed_irreps.dim
+ self.register_buffer("_zero", torch.zeros(1, 1))
+
+ # Initially, we have the B(r)Y(\vec{r})-projection of the edges
+ # (possibly embedded)
+ if self.embed_initial_edge:
+ arg_irreps = env_embed_irreps
+ else:
+ arg_irreps = input_irreps
+
+ # - begin irreps -
+ # start to build up the irreps for the iterated TPs
+ tps_irreps = [arg_irreps]
+
+ for layer_idx in range(num_layers):
+ # Create higher order terms cause there are more TPs coming
+ if layer_idx == 0:
+ # Add parity irreps
+ ir_out = []
+ for (mul, ir) in env_embed_irreps:
+ if self.nonscalars_include_parity: # make all irreps except 0e have o and e
+ # add both parity options
+ ir_out.append((1, (ir.l, 1)))
+ ir_out.append((1, (ir.l, -1)))
+ else:
+ # add only the parity option seen in the inputs
+ ir_out.append((1, ir))
+
+ ir_out = o3.Irreps(ir_out)
+
+ if layer_idx == self.num_layers - 1:
+ # ^ means we're doing the last layer
+ # No more TPs follow this, so only need scalars
+ ir_out = o3.Irreps([(1, (0, 1))])
+
+ # Prune impossible paths
+ ir_out = o3.Irreps(
+ [
+ (mul, ir)
+ for mul, ir in ir_out
+ if tp_path_exists(arg_irreps, env_embed_irreps, ir)
+ ]
+ )
+
+ # the argument to the next tensor product is the output of this one
+ arg_irreps = ir_out
+ tps_irreps.append(ir_out)
+ # - end build irreps -
+
+ # == Remove unneeded paths ==
+ out_irreps = tps_irreps[-1]
+ new_tps_irreps = [out_irreps]
+ for arg_irreps in reversed(tps_irreps[:-1]):
+ new_arg_irreps = []
+ for mul, arg_ir in arg_irreps:
+ for _, env_ir in env_embed_irreps:
+ if any(i in out_irreps for i in arg_ir * env_ir):
+ # arg_ir is useful: arg_ir * env_ir has a path to something we want
+ new_arg_irreps.append((mul, arg_ir))
+ # once its useful once, we keep it no matter what
+ break
+ new_arg_irreps = o3.Irreps(new_arg_irreps)
+ new_tps_irreps.append(new_arg_irreps)
+ out_irreps = new_arg_irreps
+
+ assert len(new_tps_irreps) == len(tps_irreps)
+ tps_irreps = list(reversed(new_tps_irreps))
+ del new_tps_irreps
+
+ assert tps_irreps[-1].lmax == 0
+
+ tps_irreps_in = tps_irreps[:-1]
+ tps_irreps_out = tps_irreps[1:]
+ del tps_irreps
+
+ # Environment builder:
+ self._env_weighter = MakeWeightedChannels(
+ irreps_in=input_irreps,
+ multiplicity_out=env_embed_multiplicity,
+ pad_to_alignment=pad_to_alignment,
+ )
+
+ self._n_scalar_outs = []
+
+ # == Build TPs ==
+ for layer_idx, (arg_irreps, out_irreps) in enumerate(
+ zip(tps_irreps_in, tps_irreps_out)
+ ):
+ # Make the env embed linear
+ if self.linear_after_env_embed:
+ self.env_linears.append(
+ Linear(
+ [(env_embed_multiplicity, ir) for _, ir in env_embed_irreps],
+ [(env_embed_multiplicity, ir) for _, ir in env_embed_irreps],
+ shared_weights=True,
+ internal_weights=True,
+ )
+ )
+ else:
+ self.env_linears.append(torch.nn.Identity())
+ # Make TP
+ tmp_i_out: int = 0
+ instr = []
+ n_scalar_outs: int = 0
+ full_out_irreps = []
+ for i_out, (_, ir_out) in enumerate(out_irreps):
+ for i_1, (_, ir_1) in enumerate(arg_irreps):
+ for i_2, (_, ir_2) in enumerate(env_embed_irreps):
+ if ir_out in ir_1 * ir_2:
+ if ir_out == SCALAR:
+ n_scalar_outs += 1
+ instr.append((i_1, i_2, tmp_i_out))
+ full_out_irreps.append((env_embed_multiplicity, ir_out))
+ tmp_i_out += 1
+ full_out_irreps = o3.Irreps(full_out_irreps)
+ self._n_scalar_outs.append(n_scalar_outs)
+ assert all(ir == SCALAR for _, ir in full_out_irreps[:n_scalar_outs])
+ tp = Contracter(
+ irreps_in1=o3.Irreps(
+ [
+ (
+ (
+ env_embed_multiplicity
+ if layer_idx > 0 or self.embed_initial_edge
+ else 1
+ ),
+ ir,
+ )
+ for _, ir in arg_irreps
+ ]
+ ),
+ irreps_in2=o3.Irreps(
+ [(env_embed_multiplicity, ir) for _, ir in env_embed_irreps]
+ ),
+ irreps_out=o3.Irreps(
+ [(env_embed_multiplicity, ir) for _, ir in full_out_irreps]
+ ),
+ instructions=instr,
+ # For the first layer, we have the unprocessed edges
+ # coming in from the input if `not self.embed_initial_edge`.
+ # These don't match the embedding in mul, so we have
+ # to use uvv --- since the input edges should be mul
+ # of one in normal circumstances, this is still plenty fast.
+ # For this reason it also doesn't increase the number of weights.
+ connection_mode=(
+ "uuu" if layer_idx > 0 or self.embed_initial_edge else "uvv"
+ ),
+ shared_weights=False,
+ has_weight=False,
+ pad_to_alignment=pad_to_alignment,
+ sparse_mode=sparse_mode,
+ )
+ self.tps.append(tp)
+ # we extract the scalars from the first irrep of the tp
+ assert out_irreps[0].ir == SCALAR
+
+ # Make env embed mlp
+ generate_n_weights = (
+ self._env_weighter.weight_numel
+ ) # the weight for the edge embedding
+ if layer_idx == 0 and self.embed_initial_edge:
+ # also need weights to embed the edge itself
+ # this is because the 2 body latent is mixed in with the first layer
+ # in terms of code
+ generate_n_weights += self._env_weighter.weight_numel
+
+ # the linear acts after the extractor
+ # this linear act on the reduced V and gives a out_irreps that is just reduced
+ self.linears.append(
+ Linear(
+ full_out_irreps,
+ [(env_embed_multiplicity, ir) for _, ir in out_irreps],
+ shared_weights=True,
+ internal_weights=True,
+ pad_to_alignment=pad_to_alignment,
+ )
+ )
+
+ if layer_idx == 0:
+ # at the first layer, we have no invariants from previous TPs
+ self.latents.append(
+ two_body_latent(
+ mlp_input_dimension=(
+ (
+ # Node invariants for center and neighbor (chemistry)
+ 2 * self.irreps_in[self.node_invariant_field].num_irreps
+ # Plus edge invariants for the edge (radius).
+ + self.irreps_in[self.edge_invariant_field].num_irreps
+ )
+ ),
+ mlp_output_dimension=None,
+ **two_body_latent_kwargs,
+ )
+ )
+
+ else:
+ self.latents.append(
+ latent(
+ mlp_input_dimension=(
+ (
+ # the embedded latent invariants from the previous layer(s)
+ self.latents[-1].out_features
+ # and the invariants extracted from the last layer's TP:
+ + env_embed_multiplicity * n_scalar_outs
+ )
+ ),
+ mlp_output_dimension=None,
+ )
+ )
+ # the env embed MLP takes the last latent's output as input
+ # and outputs enough weights for the env embedder
+ self.env_embed_mlps.append(
+ env_embed(
+ mlp_input_dimension=self.latents[-1].out_features,
+ mlp_output_dimension=generate_n_weights,
+ )
+ )
+
+ # For the final layer, we specialize:
+ # we don't need to propagate nonscalars, so there is no TP
+ # thus we only need the latent:
+ self.final_latent = latent(
+ mlp_input_dimension=self.latents[-1].out_features
+ + env_embed_multiplicity * n_scalar_outs,
+ mlp_output_dimension=None,
+ )
+ # - end build modules -
+
+ # - layer resnet update weights -
+ if latent_resnet_update_ratios is None:
+ # We initialize to zeros, which under the sigmoid() become 0.5
+ # so 1/2 * layer_1 + 1/4 * layer_2 + ...
+ # note that the sigmoid of these are the factor _between_ layers
+ # so the first entry is the ratio for the latent resnet of the first and second layers, etc.
+ # e.g. if there are 3 layers, there are 2 ratios: l1:l2, l2:l3
+ latent_resnet_update_params = torch.zeros(self.num_layers)
+ else:
+ latent_resnet_update_ratios = torch.as_tensor(
+ latent_resnet_update_ratios, dtype=torch.get_default_dtype()
+ )
+ assert latent_resnet_update_ratios.min() > 0.0
+ assert latent_resnet_update_ratios.min() < 1.0
+ latent_resnet_update_params = torch.special.logit(
+ latent_resnet_update_ratios
+ )
+ # The sigmoid is mostly saturated at ±6, keep it in a reasonable range
+ latent_resnet_update_params.clamp_(-6.0, 6.0)
+ assert latent_resnet_update_params.shape == (
+ num_layers,
+ ), f"There must be {num_layers} layer resnet update ratios (layer0:layer1, layer1:layer2)"
+ if latent_resnet_update_ratios_learnable:
+ self._latent_resnet_update_params = torch.nn.Parameter(
+ latent_resnet_update_params
+ )
+ else:
+ self.register_buffer(
+ "_latent_resnet_update_params", latent_resnet_update_params
+ )
+
+ # - Per-layer cutoffs -
+ if per_layer_cutoffs is None:
+ per_layer_cutoffs = torch.full((num_layers + 1,), r_max)
+ self.register_buffer("per_layer_cutoffs", torch.as_tensor(per_layer_cutoffs))
+ assert torch.all(self.per_layer_cutoffs <= r_max)
+ assert self.per_layer_cutoffs.shape == (
+ num_layers + 1,
+ ), "Must be one per-layer cutoff for layer 0 and every layer for a total of {num_layers} cutoffs (the first applies to the two body latent, which is 'layer 0')"
+ assert (
+ self.per_layer_cutoffs[1:] <= self.per_layer_cutoffs[:-1]
+ ).all(), "Per-layer cutoffs must be equal or decreasing"
+ assert (
+ self.per_layer_cutoffs.min() > 0
+ ), "Per-layer cutoffs must be >0. To remove higher layers entirely, lower `num_layers`."
+ self._latent_dim = self.final_latent.out_features
+ self.register_buffer("_zero", torch.as_tensor(0.0))
+
+ self.irreps_out.update(
+ {
+ self.latent_out_field: o3.Irreps(
+ [(self.final_latent.out_features, (0, 1))]
+ ),
+ }
+ )
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ """Evaluate.
+
+ :param data: AtomicDataDict.Type
+ :return: AtomicDataDict.Type
+ """
+ edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0]
+ edge_neighbor = data[AtomicDataDict.EDGE_INDEX_KEY][1]
+
+ edge_attr = data[self.field]
+ # pad edge_attr
+ if self._input_pad > 0:
+ edge_attr = torch.cat(
+ (
+ edge_attr,
+ self._zero.expand(len(edge_attr), self._input_pad),
+ ),
+ dim=-1,
+ )
+
+ edge_length = data[AtomicDataDict.EDGE_LENGTH_KEY]
+ num_edges: int = len(edge_attr)
+ edge_invariants = data[self.edge_invariant_field]
+ node_invariants = data[self.node_invariant_field]
+ # pre-declare variables as Tensors for TorchScript
+ scalars = self._zero
+ coefficient_old = scalars
+ coefficient_new = scalars
+ # Initialize state
+ latents = torch.zeros(
+ (num_edges, self._latent_dim),
+ dtype=edge_attr.dtype,
+ device=edge_attr.device,
+ )
+ active_edges = torch.arange(
+ num_edges,
+ device=edge_attr.device,
+ )
+
+ # For the first layer, we use the input invariants:
+ # The center and neighbor invariants and edge invariants
+ latent_inputs_to_cat = [
+ node_invariants[edge_center],
+ node_invariants[edge_neighbor],
+ edge_invariants,
+ ]
+ # The nonscalar features. Initially, the edge data.
+ features = edge_attr
+
+ layer_index: int = 0
+ # compute the sigmoids vectorized instead of each loop
+ layer_update_coefficients = self._latent_resnet_update_params.sigmoid()
+
+ # Vectorized precompute per layer cutoffs
+ if self.cutoff_type == "cosine":
+ cutoff_coeffs_all = cosine_cutoff(
+ edge_length,
+ self.per_layer_cutoffs,
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ )
+ elif self.cutoff_type == "polynomial":
+ cutoff_coeffs_all = polynomial_cutoff(
+ edge_length, self.per_layer_cutoffs, p=self.polynomial_cutoff_p
+ )
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+
+ # !!!! REMEMBER !!!! update final layer if update the code in main loop!!!
+ # This goes through layer0, layer1, ..., layer_max-1
+ for latent, env_embed_mlp, env_linear, tp, linear in zip(
+ self.latents, self.env_embed_mlps, self.env_linears, self.tps, self.linears
+ ):
+ # Determine which edges are still in play
+ cutoff_coeffs = cutoff_coeffs_all[layer_index]
+ prev_mask = cutoff_coeffs[active_edges] > 0
+ active_edges = (cutoff_coeffs > 0).nonzero().squeeze(-1)
+
+ # Compute latents
+ new_latents = latent(torch.cat(latent_inputs_to_cat, dim=-1)[prev_mask])
+ # Apply cutoff, which propagates through to everything else
+ new_latents = cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+
+ if self.latent_resnet and layer_index > 0:
+ this_layer_update_coeff = layer_update_coefficients[layer_index - 1]
+ # At init, we assume new and old to be approximately uncorrelated
+ # Thus their variances add
+ # we always want the latent space to be normalized to variance = 1.0,
+ # because it is critical for learnability. Still, we want to preserve
+ # the _relative_ magnitudes of the current latent and the residual update
+ # to be controled by `this_layer_update_coeff`
+ # Solving the simple system for the two coefficients:
+ # a^2 + b^2 = 1 (variances add) & a * this_layer_update_coeff = b
+ # gives
+ # a = 1 / sqrt(1 + this_layer_update_coeff^2) & b = this_layer_update_coeff / sqrt(1 + this_layer_update_coeff^2)
+ # rsqrt is reciprocal sqrt
+ coefficient_old = torch.rsqrt(this_layer_update_coeff.square() + 1)
+ coefficient_new = this_layer_update_coeff * coefficient_old
+ # Residual update
+ # Note that it only runs when there are latents to resnet with, so not at the first layer
+ # index_add adds only to the edges for which we have something to contribute
+ latents = torch.index_add(
+ coefficient_old * latents,
+ 0,
+ active_edges,
+ coefficient_new * new_latents,
+ )
+ else:
+ # Normal (non-residual) update
+ # index_copy replaces, unlike index_add
+ latents = torch.index_copy(latents, 0, active_edges, new_latents)
+
+ # From the latents, compute the weights for active edges:
+ weights = env_embed_mlp(latents[active_edges])
+ w_index: int = 0
+
+ if self.embed_initial_edge and layer_index == 0:
+ # embed initial edge
+ env_w = weights.narrow(-1, w_index, self._env_weighter.weight_numel)
+ w_index += self._env_weighter.weight_numel
+ features = self._env_weighter(
+ features[prev_mask], env_w
+ ) # features is edge_attr
+ else:
+ # just take the previous features that we still need
+ features = features[prev_mask]
+
+ # Extract weights for the environment builder
+ env_w = weights.narrow(-1, w_index, self._env_weighter.weight_numel)
+ w_index += self._env_weighter.weight_numel
+
+ # Build the local environments
+ # This local environment should only be a sum over neighbors
+ # who are within the cutoff of the _current_ layer
+ # Those are the active edges, which are the only ones we
+ # have weights for (env_w) anyway.
+ # So we mask out the edges in the sum:
+ local_env_per_edge = scatter(
+ self._env_weighter(edge_attr[active_edges], env_w),
+ edge_center[active_edges],
+ dim=0,
+ )
+ if self.env_sum_normalizations.ndim < 2:
+ # it's a scalar per layer
+ norm_const = self.env_sum_normalizations[layer_index]
+ else:
+ # it's per type
+ # get shape [N_atom, 1] for broadcasting
+ norm_const = self.env_sum_normalizations[
+ layer_index, data[AtomicDataDict.ATOM_TYPE_KEY]
+ ].unsqueeze(-1)
+ local_env_per_edge = local_env_per_edge * norm_const
+ local_env_per_edge = env_linear(local_env_per_edge)
+ # Copy to get per-edge
+ # Large allocation, but no better way to do this:
+ local_env_per_edge = local_env_per_edge[edge_center[active_edges]]
+
+ # Now do the TP
+ # recursively tp current features with the environment embeddings
+ features = tp(features, local_env_per_edge)
+
+ # Get invariants
+ # features has shape [z][mul][k]
+ # we know scalars are first
+ scalars = features[:, :, : self._n_scalar_outs[layer_index]].reshape(
+ features.shape[0], -1
+ )
+
+ # do the linear
+ features = linear(features)
+
+ # For layer2+, use the previous latents and scalars
+ # This makes it deep
+ latent_inputs_to_cat = [
+ latents[active_edges],
+ scalars,
+ ]
+
+ # increment counter
+ layer_index += 1
+
+ # - final layer -
+ # due to TorchScript limitations, we have to
+ # copy and repeat the code here --- no way to
+ # escape the final iteration of the loop early
+ cutoff_coeffs = cutoff_coeffs_all[layer_index]
+ prev_mask = cutoff_coeffs[active_edges] > 0
+ active_edges = (cutoff_coeffs > 0).nonzero().squeeze(-1)
+ new_latents = self.final_latent(
+ torch.cat(latent_inputs_to_cat, dim=-1)[prev_mask]
+ )
+ new_latents = cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ if self.latent_resnet:
+ this_layer_update_coeff = layer_update_coefficients[layer_index - 1]
+ coefficient_old = torch.rsqrt(this_layer_update_coeff.square() + 1)
+ coefficient_new = this_layer_update_coeff * coefficient_old
+ latents = torch.index_add(
+ coefficient_old * latents,
+ 0,
+ active_edges,
+ coefficient_new * new_latents,
+ )
+ else:
+ latents = torch.index_copy(latents, 0, active_edges, new_latents)
+ # - end final layer -
+
+ # final latents
+ data[self.latent_out_field] = latents
+
+ return data
\ No newline at end of file
diff --git a/dptb/nn/embedding/baseline.py b/dptb/nn/embedding/baseline.py
new file mode 100644
index 00000000..703c419f
--- /dev/null
+++ b/dptb/nn/embedding/baseline.py
@@ -0,0 +1,316 @@
+from torch_geometric.nn import MessagePassing
+from torch_geometric.nn import Aggregation
+import torch
+from typing import Optional, Tuple, Union
+from dptb.data import AtomicDataDict
+from dptb.nn.embedding.emb import Embedding
+from ..base import ResNet, FFN
+from torch.nn import Linear
+from dptb.utils.constants import dtype_dict
+from ..type_encode.one_hot import OneHotAtomEncoding
+from ..cutoff import polynomial_cutoff
+from ..radial_basis import BesselBasis
+from torch_runstats.scatter import scatter
+
+def get_neuron_config(nl):
+ n = len(nl)
+ if n % 2 == 0:
+ d_out = nl[-1]
+ nl = nl[:-1]
+ config = []
+ for i in range(1,len(nl)-1, 2):
+ config.append({'in_features': nl[i-1], 'hidden_features': nl[i], 'out_features': nl[i+1]})
+
+ if n % 2 == 0:
+ config.append({'in_features': nl[-1], 'out_features': d_out})
+
+ return config
+
+@Embedding.register("baseline")
+class BASELINE(torch.nn.Module):
+ def __init__(
+ self,
+ rc:Union[float, torch.Tensor],
+ p:Union[int, torch.LongTensor],
+ n_axis: Union[int, torch.LongTensor, None]=None,
+ n_basis: Union[int, torch.LongTensor, None]=None,
+ n_radial: Union[int, torch.LongTensor, None]=None,
+ n_sqrt_radial: Union[int, torch.LongTensor, None]=None,
+ n_atom: int=1,
+ n_layer: int=1,
+ radial_net: dict={},
+ hidden_net: dict={},
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,):
+
+ super(BASELINE, self).__init__()
+
+ assert n_axis <= n_sqrt_radial
+ self.n_radial = n_radial
+ self.n_sqrt_radial = n_sqrt_radial
+ self.n_axis = n_axis
+
+ if isinstance(rc, float):
+ self.rc = torch.tensor(rc, dtype=dtype, device=device)
+ else:
+ self.rc = rc
+
+ self.p = p
+ self.node_emb_layer = _NODE_EMB(rc=self.rc, p=p, n_axis=n_axis, n_basis=n_basis, n_radial=n_radial, n_sqrt_radial=n_sqrt_radial, n_atom=n_atom, radial_net=radial_net, dtype=dtype, device=device)
+ self.layers = torch.nn.ModuleList([])
+ for _ in range(n_layer):
+ self.layers.append(BaselineLayer(n_atom=n_atom, rc=self.rc, p=p, n_radial=n_radial, n_sqrt_radial=n_sqrt_radial, n_axis=n_axis, n_hidden=n_axis*n_sqrt_radial, hidden_net=hidden_net, radial_net=radial_net, dtype=dtype, device=device))
+ self.onehot = OneHotAtomEncoding(num_types=n_atom, set_features=False)
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ data = self.onehot(data)
+ data = AtomicDataDict.with_env_vectors(data, with_lengths=True)
+ data = AtomicDataDict.with_edge_vectors(data, with_lengths=True)
+
+ env_radial, edge_radial, node_emb, env_hidden, edge_hidden = self.node_emb_layer(
+ env_vectors=data[AtomicDataDict.ENV_VECTORS_KEY],
+ atom_attr=data[AtomicDataDict.NODE_ATTRS_KEY],
+ env_index=data[AtomicDataDict.ENV_INDEX_KEY],
+ edge_index=data[AtomicDataDict.EDGE_INDEX_KEY],
+ env_length=data[AtomicDataDict.ENV_LENGTH_KEY],
+ edge_length=data[AtomicDataDict.EDGE_LENGTH_KEY],
+ )
+
+
+ for layer in self.layers:
+ env_radial, env_hidden, edge_radial, edge_hidden, node_emb = layer(
+ env_length=data[AtomicDataDict.ENV_LENGTH_KEY],
+ edge_length=data[AtomicDataDict.EDGE_LENGTH_KEY],
+ env_index=data[AtomicDataDict.ENV_INDEX_KEY],
+ edge_index=data[AtomicDataDict.EDGE_INDEX_KEY],
+ env_radial=env_radial,
+ edge_radial=edge_radial,
+ node_emb=node_emb,
+ env_hidden=env_hidden,
+ edge_hidden=edge_hidden,
+ )
+
+ # env_length = data[AtomicDataDict.ENV_LENGTH_KEY]
+ # data[AtomicDataDict.NODE_FEATURES_KEY] = \
+ # scatter(src=polynomial_cutoff(x=env_length, r_max=self.rc, p=self.p).reshape(-1, 1) * env_radial, index=data[AtomicDataDict.ENV_INDEX_KEY][0], dim=0, reduce="sum")
+ data[AtomicDataDict.NODE_FEATURES_KEY] = node_emb
+
+ data[AtomicDataDict.EDGE_FEATURES_KEY] = edge_radial
+
+ return data
+
+ @property
+ def out_edge_dim(self):
+ return self.n_radial
+
+ @property
+ def out_node_dim(self):
+ return self.n_sqrt_radial * self.n_axis
+
+class SE2Aggregation(Aggregation):
+ def forward(self, x: torch.Tensor, index: torch.LongTensor, **kwargs):
+ """_summary_
+
+ Parameters
+ ----------
+ x : tensor of size (N, d), where d dimension looks like (emb(s(r)), \hat{x}, \hat{y}, \hat{z})
+ The is the embedding of the env_vectors
+ index : _type_
+ _description_
+
+ Returns
+ -------
+ _type_
+ _description_
+ """
+ direct_vec = x[:, -3:]
+ x = x[:,:-3].unsqueeze(-1) * direct_vec.unsqueeze(1) # [N_env, D, 3]
+ return self.reduce(x, index, reduce="mean", dim=0) # [N_atom, D, 3] following the orders of atom index.
+
+
+class _NODE_EMB(MessagePassing):
+ def __init__(
+ self,
+ rc:Union[float, torch.Tensor],
+ p:Union[int, torch.LongTensor],
+ n_axis: Union[int, torch.LongTensor, None]=None,
+ n_basis: Union[int, torch.LongTensor, None]=None,
+ n_sqrt_radial: Union[int, torch.LongTensor, None]=None,
+ n_radial: Union[int, torch.LongTensor, None]=None,
+ aggr: SE2Aggregation=SE2Aggregation(),
+ radial_net: dict={},
+ n_atom: int=1,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"), **kwargs):
+
+ super(_NODE_EMB, self).__init__(aggr=aggr, **kwargs)
+
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+
+ if n_axis == None:
+ self.n_axis = n_sqrt_radial
+ else:
+ self.n_axis = n_axis
+
+ radial_net["config"] = get_neuron_config([2*n_atom+n_basis]+radial_net["neurons"]+[n_radial])
+ self.mlp_radial = FFN(**radial_net, device=device, dtype=dtype)
+ radial_net["config"] = get_neuron_config([2*n_atom+n_basis]+radial_net["neurons"]+[n_sqrt_radial])
+ self.mlp_sqrt_radial = FFN(**radial_net, device=device, dtype=dtype)
+ self.mlp_emb = Linear(n_radial, self.n_axis*n_sqrt_radial, device=device, dtype=dtype)
+ if isinstance(rc, float):
+ self.rc = torch.tensor(rc, dtype=dtype, device=device)
+ else:
+ self.rc = rc
+
+ self.p = p
+
+ self.n_axis = n_axis
+ self.device = device
+ self.dtype = dtype
+
+ self.n_out = self.n_axis * n_sqrt_radial
+
+ self.bessel = BesselBasis(r_max=rc, num_basis=n_basis, trainable=True)
+ self.node_layer_norm = torch.nn.LayerNorm(self.n_out, elementwise_affine=True)
+ self.edge_layer_norm = torch.nn.LayerNorm(n_radial, elementwise_affine=True)
+
+ def forward(self, env_vectors, atom_attr, env_index, edge_index, env_length, edge_length):
+ n_env = env_index.shape[1]
+ n_edge = edge_index.shape[1]
+ env_attr = atom_attr[env_index].transpose(1,0).reshape(n_env,-1)
+ edge_attr = atom_attr[edge_index].transpose(1,0).reshape(n_edge,-1)
+ ud_env = polynomial_cutoff(x=env_length, r_max=self.rc, p=self.p).reshape(-1, 1)
+ ud_edge = polynomial_cutoff(x=edge_length, r_max=self.rc, p=self.p).reshape(-1, 1)
+
+ env_sqrt_radial = self.mlp_sqrt_radial(torch.cat([env_attr, ud_env * self.bessel(env_length)], dim=-1)) * ud_env
+
+ env_radial = self.edge_layer_norm(self.mlp_radial(torch.cat([env_attr, ud_env * self.bessel(env_length)], dim=-1))) * ud_env
+ edge_radial = self.edge_layer_norm(self.mlp_radial(torch.cat([edge_attr, ud_edge * self.bessel(edge_length)], dim=-1))) * ud_edge
+
+ node_emb = self.propagate(env_index, env_vectors=env_vectors, env_length=env_length, ud=ud_env, env_sqrt_radial=env_sqrt_radial) # [N_atom, D, 3]
+ env_hidden = self.mlp_emb(env_radial) * (node_emb[env_index[1]]+node_emb[env_index[0]]) * 0.5
+ edge_hidden = self.mlp_emb(edge_radial) * (node_emb[edge_index[1]]+node_emb[edge_index[0]]) * 0.5
+
+ return env_radial, edge_radial, node_emb, env_hidden, edge_hidden
+
+ def message(self, env_vectors, env_length, env_sqrt_radial, ud):
+ snorm = env_length.unsqueeze(-1) * ud
+ env_vectors = snorm * env_vectors / env_length.unsqueeze(-1)
+ return torch.cat([env_sqrt_radial, env_vectors], dim=-1) # [N_env, D_emb + 3]
+
+ def update(self, aggr_out):
+ """_summary_
+
+ Parameters
+ ----------
+ aggr_out : The output of the aggregation layer, which is the mean of the message vectors as size [N, D, 3]
+
+ Returns
+ -------
+ _type_
+ _description_
+ """
+ out = torch.bmm(aggr_out, aggr_out.transpose(1, 2))[:,:,:self.n_axis].flatten(start_dim=1, end_dim=2)
+
+
+ return self.node_layer_norm(out) # [N, D*D]
+
+
+class BaselineLayer(MessagePassing):
+ def __init__(
+ self,
+ rc:Union[float, torch.Tensor],
+ p:Union[int, torch.LongTensor],
+ n_radial: int,
+ n_sqrt_radial: int,
+ n_axis: int,
+ n_atom: int,
+ n_hidden: int,
+ radial_net: dict={},
+ hidden_net: dict={},
+ aggr="mean",
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"), **kwargs):
+
+ super(BaselineLayer, self).__init__(aggr=aggr, **kwargs)
+
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+
+ if isinstance(rc, float):
+ self.rc = torch.tensor(rc, dtype=dtype, device=device)
+ else:
+ self.rc = rc
+
+ self.p = p
+
+ self.mlp_emb = Linear(n_radial, n_axis*n_sqrt_radial, device=device, dtype=dtype)
+ hidden_net["config"] = get_neuron_config([n_axis*n_sqrt_radial+n_hidden]+hidden_net["neurons"]+[n_hidden])
+ self.mlp_hid = FFN(**hidden_net, device=device, dtype=dtype)
+ radial_net["config"] = get_neuron_config([n_radial+n_hidden]+radial_net["neurons"]+[n_radial])
+ self.mlp_radial = ResNet(**radial_net, dtype=dtype, device=device)
+
+ self.node_layer_norm = torch.nn.LayerNorm(n_axis*n_sqrt_radial, elementwise_affine=True)
+ self.edge_layer_norm = torch.nn.LayerNorm(n_radial, elementwise_affine=True)
+
+ self.device = device
+ self.dtype = dtype
+
+ def forward(self, env_length, edge_length, edge_index, env_index, env_radial, edge_radial, node_emb, env_hidden, edge_hidden):
+ # n_env = env_index.shape[1]
+ # n_edge = edge_index.shape[1]
+ # env_attr = atom_attr[env_index].transpose(1,0).reshape(n_env,-1)
+ # edge_attr = atom_attr[edge_index].transpose(1,0).reshape(n_edge,-1)
+
+ env_weight = self.mlp_emb(env_radial)
+ # node_emb can descripe the node very well
+ node_emb = 0.89442719 * node_emb + 0.4472 * self.propagate(env_index, node_emb=node_emb[env_index[1]], env_weight=env_weight) # [N_atom, D, 3]
+ # import matplotlib.pyplot as plt
+ # fig = plt.figure(figsize=(15,4))
+ # plt.plot(node_emb.detach().T)
+ # plt.title("node_emb")
+ # plt.show()
+
+ # env_hidden 长得太像了
+ env_hidden = self.mlp_hid(torch.cat([node_emb[env_index[0]], env_hidden], dim=-1))
+ edge_hidden = self.mlp_hid(torch.cat([node_emb[edge_index[0]], edge_hidden], dim=-1))
+ # node_emb = _node_emb + node_emb
+
+ # import matplotlib.pyplot as plt
+ # fig = plt.figure(figsize=(15,4))
+ # plt.plot(edge_hidden.detach().T)
+ # plt.title("edge_hidden")
+ # plt.show()
+
+ ud_env = polynomial_cutoff(x=env_length, r_max=self.rc, p=self.p).reshape(-1, 1)
+ ud_edge = polynomial_cutoff(x=edge_length, r_max=self.rc, p=self.p).reshape(-1, 1)
+ env_radial = 0.89442719 * env_radial + 0.4472 * ud_env * self.edge_layer_norm(self.mlp_radial(torch.cat([env_radial, env_hidden], dim=-1)))
+ edge_radial = 0.89442719 * edge_radial + 0.4472 * ud_edge * self.edge_layer_norm(self.mlp_radial(torch.cat([edge_radial, edge_hidden], dim=-1)))
+
+ return env_radial, env_hidden, edge_radial, edge_hidden, node_emb
+
+ def message(self, node_emb, env_weight):
+
+ return env_weight * node_emb
+
+ def update(self, aggr_out):
+ """_summary_
+
+ Parameters
+ ----------
+ aggr_out : The output of the aggregation layer, which is the mean of the message vectors as size [N, D, 3]
+
+ Returns
+ -------
+ _type_
+ _description_
+ """
+
+ aggr_out = aggr_out.reshape(aggr_out.shape[0], -1)
+ return self.node_layer_norm(aggr_out) # [N, D*D]
\ No newline at end of file
diff --git a/dptb/nn/embedding/deephe3.py b/dptb/nn/embedding/deephe3.py
new file mode 100644
index 00000000..912f41e5
--- /dev/null
+++ b/dptb/nn/embedding/deephe3.py
@@ -0,0 +1,99 @@
+from .from_deephe3.deephe3 import Net
+import torch
+import torch.nn as nn
+import e3nn.o3 as o3
+from dptb.data.transforms import OrbitalMapper
+from dptb.data import AtomicData, AtomicDataDict
+from dptb.data.AtomicDataDict import with_edge_vectors, with_env_vectors, with_batch
+from dptb.nn.embedding.emb import Embedding
+from typing import Dict, Union, List, Tuple, Optional, Any
+
+
+@Embedding.register("deeph-e3")
+class E3DeePH(nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ n_atom: int=1,
+ irreps_embed: o3.Irreps=o3.Irreps("64e"),
+ lmax: int=3,
+ irreps_mid: o3.Irreps=o3.Irreps("64x0e+32x1o+16x2e+8x3o+8x4e+4x5o"),
+ n_layer: int=3,
+ rc: float=5.0,
+ n_basis: int=128,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ):
+
+ super(E3DeePH, self).__init__()
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ self.dtype = dtype
+ self.device = device
+
+ irreps_mid = o3.Irreps(irreps_mid)
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb")
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ self.basis = self.idp.basis
+
+ self.idp.get_irreps(no_parity=False)
+ irreps_sh=o3.Irreps([(1, (i, (-1) ** i)) for i in range(lmax + 1)])
+ # if not no_parity:
+ # irreps_sh=o3.Irreps([(1, (i, (-1) ** i)) for i in range(lmax + 1)])
+ # else:
+ # irreps_sh=o3.Irreps([(1, (i, 1)) for i in range(lmax + 1)])
+
+ self.net = Net(
+ num_species=n_atom,
+ irreps_embed_node=irreps_embed,
+ irreps_sh=irreps_sh,
+ irreps_mid_node=irreps_mid,
+ irreps_post_node=self.idp.orbpair_irreps.sort()[0].simplify(), # it can be derived from the basis
+ irreps_out_node=self.idp.orbpair_irreps, # it can be dervied from the basis
+ irreps_edge_init=irreps_embed,
+ irreps_mid_edge=irreps_mid,
+ irreps_post_edge=self.idp.orbpair_irreps.sort()[0].simplify(), # it can be dervied from the basis
+ irreps_out_edge=self.idp.orbpair_irreps, # it can be dervied from the basis
+ num_block=n_layer,
+ r_max=rc,
+ use_sc=False,
+ no_parity=False,
+ use_sbf=False,
+ selftp=False,
+ edge_upd=True,
+ only_ij=False,
+ num_basis=n_basis
+ )
+
+
+ self.net.to(self.device)
+
+ self.out_irreps = self.idp.orbpair_irreps
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ data = with_edge_vectors(data, with_lengths=True)
+ data = with_batch(data)
+
+ node_feature, edge_feature = self.net(data)
+ data[AtomicDataDict.NODE_FEATURES_KEY] = node_feature
+ data[AtomicDataDict.EDGE_FEATURES_KEY] = edge_feature
+
+ return data
+
+ @property
+ def out_edge_irreps(self):
+ return self.out_irreps
+
+ @property
+ def out_node_irreps(self):
+ return self.out_irreps
diff --git a/dptb/nn/embedding/e3baseline.py b/dptb/nn/embedding/e3baseline.py
new file mode 100644
index 00000000..8617bf7e
--- /dev/null
+++ b/dptb/nn/embedding/e3baseline.py
@@ -0,0 +1,926 @@
+from typing import Optional, List, Union, Dict
+import math
+import functools
+import warnings
+
+import torch
+from torch_runstats.scatter import scatter
+
+from torch import fx
+from e3nn.util.codegen import CodeGenMixin
+from e3nn import o3
+from e3nn.nn import Gate, Activation
+from e3nn.nn._batchnorm import BatchNorm
+from e3nn.o3 import TensorProduct, Linear, SphericalHarmonics, FullyConnectedTensorProduct
+from e3nn.math import normalize2mom
+from e3nn.util.jit import compile_mode
+
+from dptb.data import AtomicDataDict
+from dptb.nn.embedding.emb import Embedding
+from ..radial_basis import BesselBasis
+from dptb.nn.graph_mixin import GraphModuleMixin
+from dptb.nn.embedding.from_deephe3.deephe3 import tp_path_exists
+from dptb.data import _keys
+from dptb.nn.cutoff import cosine_cutoff, polynomial_cutoff
+import math
+from dptb.data.transforms import OrbitalMapper
+from ..type_encode.one_hot import OneHotAtomEncoding
+from dptb.data.AtomicDataDict import with_edge_vectors, with_env_vectors, with_batch
+
+from math import ceil
+
+@Embedding.register("e3baseline")
+class E3BaseLineModel(torch.nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ # required params
+ n_atom: int=1,
+ n_layers: int=3,
+ n_radial_basis: int=10,
+ r_max: float=5.0,
+ lmax: int=4,
+ irreps_hidden: o3.Irreps=None,
+ avg_num_neighbors: Optional[float] = None,
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ cutoff_type: str = "polynomial",
+ # general hyperparameters:
+ linear_after_env_embed: bool = False,
+ env_embed_multiplicity: int = 32,
+ sh_normalized: bool = True,
+ sh_normalization: str = "component",
+ # MLP parameters:
+ latent_kwargs={
+ "mlp_latent_dimensions": [256, 256, 512],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ latent_resnet: bool = True,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ):
+
+ super(E3BaseLineModel, self).__init__()
+
+ irreps_hidden = o3.Irreps(irreps_hidden)
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ self.dtype = dtype
+ self.device = device
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb")
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ self.basis = self.idp.basis
+ self.idp.get_irreps(no_parity=False)
+
+ irreps_sh=o3.Irreps([(1, (i, (-1) ** i)) for i in range(lmax + 1)])
+ orbpair_irreps = self.idp.orbpair_irreps.sort()[0].simplify()
+
+ # check if the irreps setting satisfied the requirement of idp
+ irreps_out = []
+ for mul, ir1 in irreps_hidden:
+ for _, ir2 in orbpair_irreps:
+ irreps_out += [o3.Irrep(str(irr)) for irr in ir1*ir2]
+ irreps_out = o3.Irreps(irreps_out).sort()[0].simplify()
+
+ assert all(ir in irreps_out for _, ir in orbpair_irreps), "hidden irreps should at least cover all the reqired irreps in the hamiltonian data {}".format(pair_irreps)
+
+ self.sh = SphericalHarmonics(
+ irreps_sh, sh_normalized, sh_normalization
+ )
+ self.onehot = OneHotAtomEncoding(num_types=n_atom, set_features=False)
+
+ self.init_layer = InitLayer(
+ idp=self.idp,
+ num_types=n_atom,
+ n_radial_basis=n_radial_basis,
+ r_max=r_max,
+ irreps_sh=irreps_sh,
+ env_embed_multiplicity=env_embed_multiplicity,
+ # MLP parameters:
+ two_body_latent_kwargs=latent_kwargs,
+ env_embed_kwargs = {
+ "mlp_latent_dimensions": [],
+ "mlp_nonlinearity": None,
+ "mlp_initialization": "uniform"
+ },
+ # cutoffs
+ r_start_cos_ratio=r_start_cos_ratio,
+ PolynomialCutoff_p=PolynomialCutoff_p,
+ cutoff_type=cutoff_type,
+ device=device,
+ dtype=dtype,
+ )
+
+ self.layers = torch.nn.ModuleList()
+ latent_in =latent_kwargs["mlp_latent_dimensions"][-1]
+ # actually, we can derive the least required irreps_in and out from the idp's node and pair irreps
+ for i in range(n_layers):
+ if i == 0:
+ irreps_in = self.init_layer.irreps_out
+ else:
+ irreps_in = irreps_hidden
+
+ if i == n_layers - 1:
+ irreps_out = orbpair_irreps.sort()[0].simplify()
+ else:
+ irreps_out = irreps_hidden
+
+ self.layers.append(Layer(
+ num_types=n_atom,
+ avg_num_neighbors=avg_num_neighbors,
+ irreps_sh=irreps_sh,
+ irreps_in=irreps_in,
+ irreps_out=irreps_out,
+ # general hyperparameters:
+ linear_after_env_embed=linear_after_env_embed,
+ env_embed_multiplicity=env_embed_multiplicity,
+ # MLP parameters:
+ latent_kwargs=latent_kwargs,
+ latent_in=latent_in,
+ latent_resnet=latent_resnet,
+ latent_resnet_update_ratios=latent_resnet_update_ratios,
+ latent_resnet_update_ratios_learnable=latent_resnet_update_ratios_learnable,
+ )
+ )
+
+ # initilize output_layer
+ self.out_edge = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=True)
+ self.out_node = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=True)
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ data = with_edge_vectors(data, with_lengths=True)
+ # data = with_env_vectors(data, with_lengths=True)
+ data = with_batch(data)
+
+ edge_index = data[_keys.EDGE_INDEX_KEY]
+ edge_sh = self.sh(data[_keys.EDGE_VECTORS_KEY][:,[1,2,0]])
+ edge_length = data[_keys.EDGE_LENGTH_KEY]
+
+
+ data = self.onehot(data)
+ node_one_hot = data[_keys.NODE_ATTRS_KEY]
+ atom_type = data[_keys.ATOM_TYPE_KEY].flatten()
+ bond_type = data[_keys.EDGE_TYPE_KEY].flatten()
+ latents, features, cutoff_coeffs, active_edges = self.init_layer(edge_index, bond_type, edge_sh, edge_length, node_one_hot)
+
+ for layer in self.layers:
+ latents, features, cutoff_coeffs, active_edges = layer(edge_index, edge_sh, atom_type, latents, features, cutoff_coeffs, active_edges)
+
+
+ if self.layers[-1].env_sum_normalizations.ndim < 1:
+ norm_const = self.layers[-1].env_sum_normalizations
+ else:
+ norm_const = self.layers[-1].env_sum_normalizations[atom_type.flatten()].unsqueeze(-1)
+
+ data[_keys.EDGE_FEATURES_KEY] = torch.zeros(edge_index.shape[1], self.idp.orbpair_irreps.dim, dtype=self.dtype, device=self.device)
+ data[_keys.EDGE_FEATURES_KEY] = torch.index_copy(data[_keys.EDGE_FEATURES_KEY], 0, active_edges, self.out_edge(features))
+ node_features = scatter(features, edge_index[0][active_edges], dim=0)
+ data[_keys.NODE_FEATURES_KEY] = self.out_node(node_features * norm_const)
+
+ return data
+
+def tp_path_exists(irreps_in1, irreps_in2, ir_out):
+ irreps_in1 = o3.Irreps(irreps_in1).simplify()
+ irreps_in2 = o3.Irreps(irreps_in2).simplify()
+ ir_out = o3.Irrep(ir_out)
+
+ for _, ir1 in irreps_in1:
+ for _, ir2 in irreps_in2:
+ if ir_out in ir1 * ir2:
+ return True
+ return False
+
+def get_gate_nonlin(irreps_in1, irreps_in2, irreps_out,
+ act={1: torch.nn.functional.silu, -1: torch.tanh},
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+ ):
+ # get gate nonlinearity after tensor product
+ # irreps_in1 and irreps_in2 are irreps to be multiplied in tensor product
+ # irreps_out is desired irreps after gate nonlin
+ # notice that nonlin.irreps_out might not be exactly equal to irreps_out
+
+ irreps_scalars = o3.Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l == 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ irreps_gated = o3.Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l > 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ if irreps_gated.dim > 0:
+ if tp_path_exists(irreps_in1, irreps_in2, "0e"):
+ ir = "0e"
+ elif tp_path_exists(irreps_in1, irreps_in2, "0o"):
+ ir = "0o"
+ warnings.warn('Using odd representations as gates')
+ else:
+ raise ValueError(
+ f"irreps_in1={irreps_in1} times irreps_in2={irreps_in2} is unable to produce gates needed for irreps_gated={irreps_gated}")
+ else:
+ ir = None
+ irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify()
+
+ gate_nonlin = Gate(
+ irreps_scalars, [act[ir.p] for _, ir in irreps_scalars], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ return gate_nonlin
+
+
+@compile_mode("script")
+class MakeWeightedChannels(torch.nn.Module):
+ weight_numel: int
+ multiplicity_out: Union[int, list]
+ _num_irreps: int
+
+ def __init__(
+ self,
+ irreps_in,
+ multiplicity_out: Union[int, list],
+ pad_to_alignment: int = 1,
+ ):
+ super().__init__()
+ assert all(mul == 1 for mul, _ in irreps_in)
+ assert multiplicity_out >= 1
+ # Each edgewise output multiplicity is a per-irrep weighted sum over the input
+ # So we need to apply the weight for the ith irrep to all DOF in that irrep
+ w_index = []
+ idx = 0
+ self._num_irreps = 0
+ for (mul, ir) in irreps_in:
+ w_index += sum(([ix] * ir.dim for ix in range(idx, idx + mul)), [])
+ idx += mul
+ self._num_irreps += mul
+ # w_index = sum(([i] * ir.dim for i, (mul, ir) in enumerate(irreps_in)), [])
+ # pad to padded length
+ n_pad = (
+ int(ceil(irreps_in.dim / pad_to_alignment)) * pad_to_alignment
+ - irreps_in.dim
+ )
+ # use the last weight, what we use doesn't matter much
+ w_index += [w_index[-1]] * n_pad
+ self.register_buffer("_w_index", torch.as_tensor(w_index, dtype=torch.long))
+ # there is
+ self.multiplicity_out = multiplicity_out
+ self.weight_numel = self._num_irreps * multiplicity_out
+
+ def forward(self, edge_attr, weights):
+ # weights are [z, u, num_i]
+ # edge_attr are [z, i]
+ # i runs over all irreps, which is why the weights need
+ # to be indexed in order to go from [num_i] to [i]
+ return torch.einsum(
+ "zi,zui->zui",
+ edge_attr,
+ weights.view(
+ -1,
+ self.multiplicity_out,
+ self._num_irreps,
+ )[:, :, self._w_index],
+ )
+
+@torch.jit.script
+def ShiftedSoftPlus(x):
+ return torch.nn.functional.softplus(x) - math.log(2.0)
+
+class ScalarMLPFunction(CodeGenMixin, torch.nn.Module):
+ """Module implementing an MLP according to provided options."""
+
+ in_features: int
+ out_features: int
+
+ def __init__(
+ self,
+ mlp_input_dimension: Optional[int],
+ mlp_latent_dimensions: List[int],
+ mlp_output_dimension: Optional[int],
+ mlp_nonlinearity: Optional[str] = "silu",
+ mlp_initialization: str = "normal",
+ mlp_dropout_p: float = 0.0,
+ mlp_batchnorm: bool = False,
+ ):
+ super().__init__()
+ nonlinearity = {
+ None: None,
+ "silu": torch.nn.functional.silu,
+ "ssp": ShiftedSoftPlus,
+ }[mlp_nonlinearity]
+ if nonlinearity is not None:
+ nonlin_const = normalize2mom(nonlinearity).cst
+ else:
+ nonlin_const = 1.0
+
+ dimensions = (
+ ([mlp_input_dimension] if mlp_input_dimension is not None else [])
+ + mlp_latent_dimensions
+ + ([mlp_output_dimension] if mlp_output_dimension is not None else [])
+ )
+ assert len(dimensions) >= 2 # Must have input and output dim
+ num_layers = len(dimensions) - 1
+
+ self.in_features = dimensions[0]
+ self.out_features = dimensions[-1]
+
+ # Code
+ params = {}
+ graph = fx.Graph()
+ tracer = fx.proxy.GraphAppendingTracer(graph)
+
+ def Proxy(n):
+ return fx.Proxy(n, tracer=tracer)
+
+ features = Proxy(graph.placeholder("x"))
+ norm_from_last: float = 1.0
+
+ base = torch.nn.Module()
+
+ for layer, (h_in, h_out) in enumerate(zip(dimensions, dimensions[1:])):
+ # do dropout
+ if mlp_dropout_p > 0:
+ # only dropout if it will do something
+ # dropout before linear projection- https://stats.stackexchange.com/a/245137
+ features = Proxy(graph.call_module("_dropout", (features.node,)))
+
+ # make weights
+ w = torch.empty(h_in, h_out)
+
+ if mlp_initialization == "normal":
+ w.normal_()
+ elif mlp_initialization == "uniform":
+ # these values give < x^2 > = 1
+ w.uniform_(-math.sqrt(3), math.sqrt(3))
+ elif mlp_initialization == "orthogonal":
+ # this rescaling gives < x^2 > = 1
+ torch.nn.init.orthogonal_(w, gain=math.sqrt(max(w.shape)))
+ else:
+ raise NotImplementedError(
+ f"Invalid mlp_initialization {mlp_initialization}"
+ )
+
+ # generate code
+ params[f"_weight_{layer}"] = w
+ w = Proxy(graph.get_attr(f"_weight_{layer}"))
+ w = w * (
+ norm_from_last / math.sqrt(float(h_in))
+ ) # include any nonlinearity normalization from previous layers
+ features = torch.matmul(features, w)
+
+ if mlp_batchnorm:
+ # if we call batchnorm, do it after the nonlinearity
+ features = Proxy(graph.call_module(f"_bn_{layer}", (features.node,)))
+ setattr(base, f"_bn_{layer}", torch.nn.BatchNorm1d(h_out))
+
+ # generate nonlinearity code
+ if nonlinearity is not None and layer < num_layers - 1:
+ features = nonlinearity(features)
+ # add the normalization const in next layer
+ norm_from_last = nonlin_const
+
+ graph.output(features.node)
+
+ for pname, p in params.items():
+ setattr(base, pname, torch.nn.Parameter(p))
+
+ if mlp_dropout_p > 0:
+ # with normal dropout everything blows up
+ base._dropout = torch.nn.AlphaDropout(p=mlp_dropout_p)
+
+ self._codegen_register({"_forward": fx.GraphModule(base, graph)})
+
+ def forward(self, x):
+ return self._forward(x)
+
+class InitLayer(torch.nn.Module):
+ def __init__(
+ self,
+ # required params
+ idp,
+ num_types: int,
+ n_radial_basis: int,
+ r_max: float,
+ irreps_sh: o3.Irreps=None,
+ env_embed_multiplicity: int = 32,
+ # MLP parameters:
+ two_body_latent_kwargs={
+ "mlp_latent_dimensions": [128, 256, 512, 1024],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ env_embed_kwargs = {
+ "mlp_latent_dimensions": [],
+ "mlp_nonlinearity": None,
+ "mlp_initialization": "uniform"
+ },
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ cutoff_type: str = "polynomial",
+ device: Union[str, torch.device] = torch.device("cpu"),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ ):
+ super(InitLayer, self).__init__()
+ SCALAR = o3.Irrep("0e")
+ self.num_types = num_types
+ if isinstance(r_max, float) or isinstance(r_max, int):
+ self.r_max = torch.tensor(r_max, device=device, dtype=dtype)
+ self.r_max_dict = None
+ elif isinstance(r_max, dict):
+ c_set = set(list(r_max.values()))
+ self.r_max = torch.tensor(max(list(r_max.values())), device=device, dtype=dtype)
+ if len(r_max) == 1 or len(c_set) == 1:
+ self.r_max_dict = None
+ else:
+ self.r_max_dict = {}
+ for k,v in r_max.items():
+ self.r_max_dict[k] = torch.tensor(v, device=device, dtype=dtype)
+ else:
+ raise TypeError("r_max should be either float, int or dict")
+
+ self.idp = idp
+ self.two_body_latent_kwargs = two_body_latent_kwargs
+ self.r_start_cos_ratio = r_start_cos_ratio
+ self.polynomial_cutoff_p = PolynomialCutoff_p
+ self.cutoff_type = cutoff_type
+ self.device = device
+ self.dtype = dtype
+ self.irreps_out = o3.Irreps([(env_embed_multiplicity, ir) for _, ir in irreps_sh])
+
+ assert all(mul==1 for mul, _ in irreps_sh)
+ # env_embed_irreps = o3.Irreps([(1, ir) for _, ir in irreps_sh])
+ assert (
+ irreps_sh[0].ir == SCALAR
+ ), "env_embed_irreps must start with scalars"
+
+ # Node invariants for center and neighbor (chemistry)
+ # Plus edge invariants for the edge (radius).
+ self.two_body_latent = ScalarMLPFunction(
+ mlp_input_dimension=(2 * num_types + n_radial_basis),
+ mlp_output_dimension=None,
+ **two_body_latent_kwargs,
+ )
+
+ self._env_weighter = Linear(
+ irreps_in=irreps_sh,
+ irreps_out=self.irreps_out,
+ internal_weights=False,
+ shared_weights=False,
+ path_normalization = "element", # if path normalization is element and input irreps has 1 mul, it should not have effect !
+ )
+
+ # self.bn = BatchNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # instance=False,
+ # normalization="component",
+ # )
+
+ self.env_embed_mlp = ScalarMLPFunction(
+ mlp_input_dimension=self.two_body_latent.out_features,
+ mlp_output_dimension=self._env_weighter.weight_numel,
+ **env_embed_kwargs,
+ )
+ self.bessel = BesselBasis(r_max=self.r_max, num_basis=n_radial_basis, trainable=True)
+
+
+
+ def forward(self, edge_index, bond_type, edge_sh, edge_length, node_one_hot):
+ edge_center = edge_index[0]
+ edge_neighbor = edge_index[1]
+
+ edge_invariants = self.bessel(edge_length)
+ node_invariants = node_one_hot
+
+ # Vectorized precompute per layer cutoffs
+ if self.r_max_dict is None:
+ if self.cutoff_type == "cosine":
+ cutoff_coeffs = cosine_cutoff(
+ edge_length,
+ self.r_max.reshape(-1),
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ ).flatten()
+
+ elif self.cutoff_type == "polynomial":
+ cutoff_coeffs = polynomial_cutoff(
+ edge_length, self.r_max.reshape(-1), p=self.polynomial_cutoff_p
+ ).flatten()
+
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+ else:
+ cutoff_coeffs = torch.zeros(edge_index.shape[1], dtype=self.dtype, device=self.device)
+
+ for bond, ty in self.idp.bond_to_type.items():
+ mask = bond_type == ty
+ index = mask.nonzero().squeeze(-1)
+
+ if mask.any():
+ iatom, jatom = bond.split("-")
+ if self.cutoff_type == "cosine":
+ c_coeff = cosine_cutoff(
+ edge_length[mask],
+ 0.5*(self.r_max_dict[iatom]+self.r_max_dict[jatom]),
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ ).flatten()
+ elif self.cutoff_type == "polynomial":
+ c_coeff = polynomial_cutoff(
+ edge_length[mask],
+ 0.5*(self.r_max_dict[iatom]+self.r_max_dict[jatom]),
+ p=self.polynomial_cutoff_p
+ ).flatten()
+
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+
+ cutoff_coeffs = torch.index_copy(cutoff_coeffs, 0, index, c_coeff)
+
+ # Determine which edges are still in play
+ prev_mask = cutoff_coeffs > 0
+ active_edges = (cutoff_coeffs > 0).nonzero().squeeze(-1)
+
+ # Compute latents
+ latents = torch.zeros(
+ (edge_sh.shape[0], self.two_body_latent.out_features),
+ dtype=edge_sh.dtype,
+ device=edge_sh.device,
+ )
+
+ new_latents = self.two_body_latent(torch.cat([
+ node_invariants[edge_center],
+ node_invariants[edge_neighbor],
+ edge_invariants,
+ ], dim=-1)[prev_mask])
+
+ # Apply cutoff, which propagates through to everything else
+ new_latents = cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ latents = torch.index_copy(latents, 0, active_edges, new_latents)
+ weights = self.env_embed_mlp(latents[active_edges])
+
+ # embed initial edge
+ features = self._env_weighter(
+ edge_sh[prev_mask], weights
+ ) # features is edge_attr
+ # features = self.bn(features)
+
+ return latents, features, cutoff_coeffs, active_edges # the radial embedding x and the sperical hidden V
+
+class Layer(torch.nn.Module):
+ def __init__(
+ self,
+ # required params
+ num_types: int,
+ avg_num_neighbors: Optional[float] = None,
+ irreps_sh: o3.Irreps=None,
+ irreps_in: o3.Irreps=None,
+ irreps_out: o3.Irreps=None,
+ # general hyperparameters:
+ linear_after_env_embed: bool = False,
+ env_embed_multiplicity: int = 32,
+ # MLP parameters:
+ latent_kwargs={
+ "mlp_latent_dimensions": [128, 256, 512, 1024],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ latent_in: int=1024,
+ latent_resnet: bool = True,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ ):
+ super().__init__()
+ SCALAR = o3.Irrep("0e")
+ self.latent_resnet = latent_resnet
+ self.avg_num_neighbors = avg_num_neighbors
+ self.linear_after_env_embed = linear_after_env_embed
+ self.irreps_in = irreps_in
+ self.irreps_out = irreps_out
+
+ assert all(mul==1 for mul, _ in irreps_sh)
+
+ # for normalization of env embed sums
+ # one per layer
+ self.register_buffer(
+ "env_sum_normalizations",
+ # dividing by sqrt(N)
+ torch.as_tensor(avg_num_neighbors).rsqrt(),
+ )
+
+ latent = functools.partial(ScalarMLPFunction, **latent_kwargs)
+
+ self.latents = None
+ self.env_embed_mlps = None
+ self.tps = None
+ self.linears = None
+ self.env_linears = None
+
+ # Prune impossible paths
+ self.irreps_out = o3.Irreps(
+ [
+ (mul, ir)
+ for mul, ir in self.irreps_out
+ if tp_path_exists(irreps_sh, irreps_in, ir)
+ ]
+ )
+
+ mul_irreps_sh = o3.Irreps([(env_embed_multiplicity, ir) for _, ir in irreps_sh])
+ self._env_weighter = Linear(
+ irreps_in=irreps_sh,
+ irreps_out=mul_irreps_sh,
+ internal_weights=False,
+ shared_weights=False,
+ path_normalization = "element",
+ )
+
+ # == Remove unneeded paths ==
+ #TODO: add the remove unseen paths
+
+ if self.linear_after_env_embed:
+ self.env_linears = Linear(
+ mul_irreps_sh,
+ mul_irreps_sh,
+ shared_weights=True,
+ internal_weights=True,
+ )
+ else:
+ self.env_linears = torch.nn.Identity()
+
+ # Make TP
+ tmp_i_out: int = 0
+ instr = []
+ n_scalar_outs: int = 0
+ n_scalar_mul = []
+ full_out_irreps = []
+ for i_out, (mul_out, ir_out) in enumerate(self.irreps_out):
+ for i_1, (mul1, ir_1) in enumerate(self.irreps_in): # what if feature_irreps_in has mul?
+ for i_2, (mul2, ir_2) in enumerate(self._env_weighter.irreps_out+self._env_weighter.irreps_out):
+ if ir_out in ir_1 * ir_2:
+ if ir_out == SCALAR:
+ n_scalar_outs += 1
+ n_scalar_mul.append(mul2)
+ # assert mul_out == mul1 == mul2
+ instr.append((i_1, i_2, tmp_i_out, 'uvv', True))
+ full_out_irreps.append((mul2, ir_out))
+ assert full_out_irreps[-1][0] == mul2
+ tmp_i_out += 1
+ full_out_irreps = o3.Irreps(full_out_irreps)
+ assert all(ir == SCALAR for _, ir in full_out_irreps[:n_scalar_outs])
+ self.n_scalar_mul = sum(n_scalar_mul)
+
+ self.lin_pre = Linear(
+ irreps_in=self.irreps_in,
+ irreps_out=self.irreps_in,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ self.tp = TensorProduct(
+ irreps_in1=o3.Irreps(
+ [(mul, ir) for mul, ir in self.irreps_in]
+ ),
+ irreps_in2=o3.Irreps(
+ [(mul, ir) for mul, ir in self._env_weighter.irreps_out+self._env_weighter.irreps_out]
+ ),
+ irreps_out=o3.Irreps(
+ [(mul, ir) for mul, ir in full_out_irreps]
+ ),
+ irrep_normalization="component",
+ instructions=instr,
+ shared_weights=True,
+ internal_weights=True,
+ )
+
+
+
+ # self.sc = FullyConnectedTensorProduct(
+ # irreps_in,
+ # o3.Irreps(str(2*num_types)+"x0e"),
+ # self.irreps_out,
+ # shared_weights=True,
+ # internal_weights=True
+ # )
+
+ self.lin_post = Linear(
+ self.irreps_out,
+ self.irreps_out,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ self.bn = BatchNorm(
+ irreps=self.irreps_out,
+ affine=True,
+ instance=False,
+ normalization="component",
+ )
+
+ self.linear_res = Linear(
+ self.irreps_in,
+ self.irreps_out,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # build activation
+ irreps_scalar = o3.Irreps(str(self.irreps_out[0]))
+ irreps_gated = o3.Irreps([(mul, ir) for mul, ir in self.irreps_out if ir.l > 0]).simplify()
+ irreps_gates = o3.Irreps([(mul, (0,1)) for mul, _ in irreps_gated]).simplify()
+ act={1: torch.nn.functional.silu, -1: torch.tanh}
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+
+ self.activation = Gate(
+ irreps_scalar, [act[ir.p] for _, ir in irreps_scalar], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ # we extract the scalars from the first irrep of the tp
+ assert self.irreps_out[0].ir == SCALAR
+ self.linears = Linear(
+ irreps_in=full_out_irreps,
+ irreps_out=self.activation.irreps_in,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # the embedded latent invariants from the previous layer(s)
+ # and the invariants extracted from the last layer's TP:
+ self.latents = latent(
+ mlp_input_dimension=latent_in+self.n_scalar_mul,
+ mlp_output_dimension=None,
+ )
+
+ # the env embed MLP takes the last latent's output as input
+ # and outputs enough weights for the env embedder
+ self.env_embed_mlps = ScalarMLPFunction(
+ mlp_input_dimension=latent_in,
+ mlp_latent_dimensions=[],
+ mlp_output_dimension=self._env_weighter.weight_numel,
+ )
+ # - layer resnet update weights -
+ if latent_resnet_update_ratios is None:
+ # We initialize to zeros, which under the sigmoid() become 0.5
+ # so 1/2 * layer_1 + 1/4 * layer_2 + ...
+ # note that the sigmoid of these are the factor _between_ layers
+ # so the first entry is the ratio for the latent resnet of the first and second layers, etc.
+ # e.g. if there are 3 layers, there are 2 ratios: l1:l2, l2:l3
+ latent_resnet_update_params = torch.zeros(1)
+ else:
+ latent_resnet_update_ratios = torch.as_tensor(
+ latent_resnet_update_ratios, dtype=torch.get_default_dtype()
+ )
+ assert latent_resnet_update_ratios > 0.0
+ assert latent_resnet_update_ratios < 1.0
+ latent_resnet_update_params = torch.special.logit(
+ latent_resnet_update_ratios
+ )
+ # The sigmoid is mostly saturated at ±6, keep it in a reasonable range
+ latent_resnet_update_params.clamp_(-6.0, 6.0)
+
+ if latent_resnet_update_ratios_learnable:
+ self._latent_resnet_update_params = torch.nn.Parameter(
+ latent_resnet_update_params
+ )
+ else:
+ self.register_buffer(
+ "_latent_resnet_update_params", latent_resnet_update_params
+ )
+
+ def forward(self, edge_index, edge_sh, atom_type, latents, features, cutoff_coeffs, active_edges):
+ # update V
+ # update X
+ # edge_index: [2, num_edges]
+ # irreps_sh: [num_edges, irreps_sh]
+ # latents: [num_edges, latent_in]
+ # fetures: [num_active_edges, in_irreps]
+ # cutoff_coeffs: [num_edges]
+ # active_edges: [num_active_edges]
+
+ edge_center = edge_index[0]
+ edge_neighbor = edge_index[1]
+
+ prev_mask = cutoff_coeffs > 0
+
+ # sc_features = self.sc(features, node_one_hot[edge_index].transpose(0,1).flatten(1,2)[active_edges])
+ # update V
+ weights = self.env_embed_mlps(latents[active_edges])
+
+ # Build the local environments
+ # This local environment should only be a sum over neighbors
+ # who are within the cutoff of the _current_ layer
+ # Those are the active edges, which are the only ones we
+ # have weights for (env_w) anyway.
+ # So we mask out the edges in the sum:
+ local_env_per_edge = scatter(
+ self._env_weighter(edge_sh[active_edges], weights),
+ edge_center[active_edges],
+ dim=0,
+ )
+
+ # currently, we have a sum over neighbors of constant number for each layer,
+ # the env_sum_normalization can be a scalar or list
+ # the different cutoff can be added in the future
+
+ if self.env_sum_normalizations.ndim < 1:
+ norm_const = self.env_sum_normalizations
+ else:
+ norm_const = self.env_sum_normalizations[atom_type.flatten()].unsqueeze(-1)
+
+ local_env_per_edge = local_env_per_edge * norm_const
+ local_env_per_edge = self.env_linears(local_env_per_edge)
+
+ # local_env_per_edge = torch.cat([local_env_per_edge[edge_center[active_edges]], local_env_per_edge[edge_neighbor[active_edges]]], dim=-1)
+ local_env_per_edge = local_env_per_edge[edge_center[active_edges]]
+ # Now do the TP
+ # recursively tp current features with the environment embeddings
+ new_features = self.tp(
+ self.lin_pre(features),
+ torch.cat(
+ [
+ local_env_per_edge[edge_center[active_edges]],
+ local_env_per_edge[edge_neighbor[active_edges]]
+ ], dim=-1
+ )) # full_out_irreps
+
+
+ # features has shape [N_edge, full_feature_out.dim]
+ # we know scalars are first
+ scalars = new_features[:, :self.n_scalar_mul]
+ assert len(scalars.shape) == 2
+
+ # do the linear
+ new_features = self.linears(new_features)
+ new_features = self.activation(new_features)
+
+ new_features = self.lin_post(new_features)
+
+ new_features = self.bn(new_features)
+
+ if self.latent_resnet:
+ update_coefficients = self._latent_resnet_update_params.sigmoid()
+ coefficient_old = torch.rsqrt(update_coefficients.square() + 1)
+ coefficient_new = update_coefficients * coefficient_old
+ features = coefficient_new * new_features + coefficient_old * self.linear_res(features)
+ else:
+ features = new_features
+
+ # update X
+ latent_inputs_to_cat = [
+ latents[active_edges],
+ scalars,
+ ]
+
+ new_latents = self.latents(torch.cat(latent_inputs_to_cat, dim=-1))
+ new_latents = cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ # At init, we assume new and old to be approximately uncorrelated
+ # Thus their variances add
+ # we always want the latent space to be normalized to variance = 1.0,
+ # because it is critical for learnability. Still, we want to preserve
+ # the _relative_ magnitudes of the current latent and the residual update
+ # to be controled by `this_layer_update_coeff`
+ # Solving the simple system for the two coefficients:
+ # a^2 + b^2 = 1 (variances add) & a * this_layer_update_coeff = b
+ # gives
+ # a = 1 / sqrt(1 + this_layer_update_coeff^2) & b = this_layer_update_coeff / sqrt(1 + this_layer_update_coeff^2)
+ # rsqrt is reciprocal sqrt
+ if self.latent_resnet:
+ update_coefficients = self._latent_resnet_update_params.sigmoid()
+ coefficient_old = torch.rsqrt(update_coefficients.square() + 1)
+ coefficient_new = update_coefficients * coefficient_old
+ latents = torch.index_add(
+ coefficient_old * latents,
+ 0,
+ active_edges,
+ coefficient_new * new_latents,
+ )
+ else:
+ latents = torch.index_copy(latents, 0, active_edges, new_latents)
+
+ return latents, features, cutoff_coeffs, active_edges
\ No newline at end of file
diff --git a/dptb/nn/embedding/e3baseline_local.py b/dptb/nn/embedding/e3baseline_local.py
new file mode 100644
index 00000000..f4971ad6
--- /dev/null
+++ b/dptb/nn/embedding/e3baseline_local.py
@@ -0,0 +1,1056 @@
+from typing import Optional, List, Union, Dict
+import math
+import functools
+import warnings
+
+import torch
+from torch_runstats.scatter import scatter
+
+from torch import fx
+from e3nn.util.codegen import CodeGenMixin
+from dptb.nn.norm import TypeNorm
+from e3nn import o3
+from e3nn.nn import Gate
+from e3nn.nn._batchnorm import BatchNorm
+from torch_scatter import scatter_mean
+from e3nn.o3 import Linear, SphericalHarmonics
+from e3nn.math import normalize2mom
+from e3nn.util.jit import compile_mode
+from dptb.nn.rescale import E3PerSpeciesScaleShift, E3PerEdgeSpeciesScaleShift
+
+from dptb.data import AtomicDataDict
+from dptb.nn.embedding.emb import Embedding
+from ..radial_basis import BesselBasis
+from dptb.nn.embedding.from_deephe3.deephe3 import tp_path_exists
+from dptb.nn.embedding.from_deephe3.e3module import SeparateWeightTensorProduct
+from dptb.data import _keys
+from dptb.nn.cutoff import cosine_cutoff, polynomial_cutoff
+from dptb.nn.rescale import E3ElementLinear
+import math
+from dptb.data.transforms import OrbitalMapper
+from ..type_encode.one_hot import OneHotAtomEncoding
+from dptb.data.AtomicDataDict import with_edge_vectors, with_env_vectors, with_batch
+
+from math import ceil
+
+@Embedding.register("e3baseline_local")
+class E3BaseLineModelLocal(torch.nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ # required params
+ n_atom: int=1,
+ n_layers: int=3,
+ n_radial_basis: int=10,
+ r_max: float=5.0,
+ lmax: int=4,
+ irreps_hidden: o3.Irreps=None,
+ avg_num_neighbors: Optional[float] = None,
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ cutoff_type: str = "polynomial",
+ # general hyperparameters:
+ linear_after_env_embed: bool = False,
+ env_embed_multiplicity: int = 32,
+ sh_normalized: bool = True,
+ sh_normalization: str = "component",
+ # MLP parameters:
+ latent_kwargs={
+ "mlp_latent_dimensions": [256, 256, 512],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ latent_resnet: bool = True,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ):
+
+ super(E3BaseLineModelLocal, self).__init__()
+
+ irreps_hidden = o3.Irreps(irreps_hidden)
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ self.dtype = dtype
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.device = device
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb")
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ self.basis = self.idp.basis
+ self.idp.get_irreps(no_parity=False)
+ self.n_atom = n_atom
+
+ irreps_sh=o3.Irreps([(1, (i, (-1) ** i)) for i in range(lmax + 1)])
+ orbpair_irreps = self.idp.orbpair_irreps.sort()[0].simplify()
+
+ # check if the irreps setting satisfied the requirement of idp
+ irreps_out = []
+ for mul, ir1 in irreps_hidden:
+ for _, ir2 in orbpair_irreps:
+ irreps_out += [o3.Irrep(str(irr)) for irr in ir1*ir2]
+ irreps_out = o3.Irreps(irreps_out).sort()[0].simplify()
+
+ assert all(ir in irreps_out for _, ir in orbpair_irreps), "hidden irreps should at least cover all the reqired irreps in the hamiltonian data {}".format(orbpair_irreps)
+
+ # TODO: check if the tp in first layer can produce the required irreps for hidden states
+
+ self.sh = SphericalHarmonics(
+ irreps_sh, sh_normalized, sh_normalization
+ )
+ self.onehot = OneHotAtomEncoding(num_types=n_atom, set_features=False)
+
+ self.init_layer = InitLayer(
+ idp=self.idp,
+ num_types=n_atom,
+ n_radial_basis=n_radial_basis,
+ r_max=r_max,
+ irreps_sh=irreps_sh,
+ env_embed_multiplicity=env_embed_multiplicity,
+ # MLP parameters:
+ two_body_latent_kwargs=latent_kwargs,
+ env_embed_kwargs = {
+ "mlp_latent_dimensions": [],
+ "mlp_nonlinearity": None,
+ "mlp_initialization": "uniform"
+ },
+ # cutoffs
+ r_start_cos_ratio=r_start_cos_ratio,
+ PolynomialCutoff_p=PolynomialCutoff_p,
+ cutoff_type=cutoff_type,
+ device=device,
+ dtype=dtype,
+ )
+
+ self.layers = torch.nn.ModuleList()
+ latent_in =latent_kwargs["mlp_latent_dimensions"][-1]
+ # actually, we can derive the least required irreps_in and out from the idp's node and pair irreps
+ last_layer = False
+ for i in range(n_layers):
+ if i == 0:
+ irreps_in = self.init_layer.irreps_out
+ else:
+ irreps_in = irreps_hidden
+
+ if i == n_layers - 1:
+ irreps_out = orbpair_irreps.sort()[0].simplify()
+ last_layer = True
+ else:
+ irreps_out = irreps_hidden
+
+ self.layers.append(Layer(
+ num_types=n_atom,
+ avg_num_neighbors=avg_num_neighbors,
+ irreps_sh=irreps_sh,
+ irreps_in=irreps_in,
+ irreps_out=irreps_out,
+ # general hyperparameters:
+ linear_after_env_embed=linear_after_env_embed,
+ env_embed_multiplicity=env_embed_multiplicity,
+ # MLP parameters:
+ latent_kwargs=latent_kwargs,
+ latent_in=latent_in,
+ latent_resnet=latent_resnet,
+ latent_resnet_update_ratios=latent_resnet_update_ratios,
+ latent_resnet_update_ratios_learnable=latent_resnet_update_ratios_learnable,
+ last_layer=last_layer,
+ dtype=dtype,
+ device=device,
+ )
+ )
+
+ # initilize output_layer
+ self.out_edge = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=True)
+ self.out_node = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=True)
+
+ @property
+ def out_edge_irreps(self):
+ return self.idp.orbpair_irreps
+
+ @property
+ def out_node_irreps(self):
+ return self.idp.orbpair_irreps
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ data = with_edge_vectors(data, with_lengths=True)
+ # data = with_env_vectors(data, with_lengths=True)
+ data = with_batch(data)
+ batch = data[_keys.BATCH_KEY]
+
+ edge_index = data[_keys.EDGE_INDEX_KEY]
+ edge_sh = self.sh(data[_keys.EDGE_VECTORS_KEY][:,[1,2,0]])
+ edge_length = data[_keys.EDGE_LENGTH_KEY]
+
+
+ data = self.onehot(data)
+ node_one_hot = data[_keys.NODE_ATTRS_KEY]
+ atom_type = data[_keys.ATOM_TYPE_KEY].flatten()
+ bond_type = data[_keys.EDGE_TYPE_KEY].flatten()
+ latents, features, cutoff_coeffs, active_edges = self.init_layer(edge_index, bond_type, edge_sh, edge_length, node_one_hot)
+
+ for layer in self.layers:
+ latents, features, cutoff_coeffs, active_edges = layer(edge_index, edge_sh, atom_type, bond_type, latents, features, cutoff_coeffs, active_edges, batch)
+
+ data[_keys.NODE_FEATURES_KEY] = self.out_node(latents)
+ data[_keys.EDGE_FEATURES_KEY] = torch.zeros(edge_index.shape[1], self.idp.orbpair_irreps.dim, dtype=self.dtype, device=self.device)
+ data[_keys.EDGE_FEATURES_KEY] = torch.index_copy(data[_keys.EDGE_FEATURES_KEY], 0, active_edges, self.out_edge(features))
+
+ return data
+
+def tp_path_exists(irreps_in1, irreps_in2, ir_out):
+ irreps_in1 = o3.Irreps(irreps_in1).simplify()
+ irreps_in2 = o3.Irreps(irreps_in2).simplify()
+ ir_out = o3.Irrep(ir_out)
+
+ for _, ir1 in irreps_in1:
+ for _, ir2 in irreps_in2:
+ if ir_out in ir1 * ir2:
+ return True
+ return False
+
+def get_gate_nonlin(irreps_in1, irreps_in2, irreps_out,
+ act={1: torch.nn.functional.silu, -1: torch.tanh},
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+ ):
+ # get gate nonlinearity after tensor product
+ # irreps_in1 and irreps_in2 are irreps to be multiplied in tensor product
+ # irreps_out is desired irreps after gate nonlin
+ # notice that nonlin.irreps_out might not be exactly equal to irreps_out
+
+ irreps_scalars = o3.Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l == 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ irreps_gated = o3.Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l > 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ if irreps_gated.dim > 0:
+ if tp_path_exists(irreps_in1, irreps_in2, "0e"):
+ ir = "0e"
+ elif tp_path_exists(irreps_in1, irreps_in2, "0o"):
+ ir = "0o"
+ warnings.warn('Using odd representations as gates')
+ else:
+ raise ValueError(
+ f"irreps_in1={irreps_in1} times irreps_in2={irreps_in2} is unable to produce gates needed for irreps_gated={irreps_gated}")
+ else:
+ ir = None
+ irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify()
+
+ gate_nonlin = Gate(
+ irreps_scalars, [act[ir.p] for _, ir in irreps_scalars], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ return gate_nonlin
+
+
+@compile_mode("script")
+class MakeWeightedChannels(torch.nn.Module):
+ weight_numel: int
+ multiplicity_out: Union[int, list]
+ _num_irreps: int
+
+ def __init__(
+ self,
+ irreps_in: o3.Irreps,
+ multiplicity_out: Union[int, list],
+ pad_to_alignment: int = 1,
+ ):
+ super().__init__()
+ assert all(mul == 1 for mul, _ in irreps_in)
+ assert multiplicity_out >= 1
+ # Each edgewise output multiplicity is a per-irrep weighted sum over the input
+ # So we need to apply the weight for the ith irrep to all DOF in that irrep
+ w_index = []
+ idx = 0
+ self._num_irreps = 0
+ for (mul, ir) in irreps_in:
+ w_index += sum(([ix] * ir.dim for ix in range(idx, idx + mul)), [])
+ idx += mul
+ self._num_irreps += mul
+ # w_index = sum(([i] * ir.dim for i, (mul, ir) in enumerate(irreps_in)), [])
+ # pad to padded length
+ n_pad = (
+ int(ceil(irreps_in.dim / pad_to_alignment)) * pad_to_alignment
+ - irreps_in.dim
+ )
+ # use the last weight, what we use doesn't matter much
+ w_index += [w_index[-1]] * n_pad
+ self.register_buffer("_w_index", torch.as_tensor(w_index, dtype=torch.long))
+ # there is
+ self.multiplicity_out = multiplicity_out
+ self.weight_numel = self._num_irreps * multiplicity_out
+
+ def forward(self, edge_attr, weights):
+ # weights are [z, u, num_i]
+ # edge_attr are [z, i]
+ # i runs over all irreps, which is why the weights need
+ # to be indexed in order to go from [num_i] to [i]
+ return torch.einsum(
+ "zi,zui->zui",
+ edge_attr,
+ weights.view(
+ -1,
+ self.multiplicity_out,
+ self._num_irreps,
+ )[:, :, self._w_index],
+ )
+
+@torch.jit.script
+def ShiftedSoftPlus(x: torch.Tensor):
+ return torch.nn.functional.softplus(x) - math.log(2.0)
+
+class ScalarMLPFunction(CodeGenMixin, torch.nn.Module):
+ """Module implementing an MLP according to provided options."""
+
+ in_features: int
+ out_features: int
+
+ def __init__(
+ self,
+ mlp_input_dimension: Optional[int],
+ mlp_latent_dimensions: List[int],
+ mlp_output_dimension: Optional[int],
+ mlp_nonlinearity: Optional[str] = "silu",
+ mlp_initialization: str = "normal",
+ mlp_dropout_p: float = 0.0,
+ mlp_batchnorm: bool = False,
+ ):
+ super().__init__()
+ nonlinearity = {
+ None: None,
+ "silu": torch.nn.functional.silu,
+ "ssp": ShiftedSoftPlus,
+ }[mlp_nonlinearity]
+ if nonlinearity is not None:
+ nonlin_const = normalize2mom(nonlinearity).cst
+ else:
+ nonlin_const = 1.0
+
+ dimensions = (
+ ([mlp_input_dimension] if mlp_input_dimension is not None else [])
+ + mlp_latent_dimensions
+ + ([mlp_output_dimension] if mlp_output_dimension is not None else [])
+ )
+ assert len(dimensions) >= 2 # Must have input and output dim
+ num_layers = len(dimensions) - 1
+
+ self.in_features = dimensions[0]
+ self.out_features = dimensions[-1]
+
+ # Code
+ params = {}
+ graph = fx.Graph()
+ tracer = fx.proxy.GraphAppendingTracer(graph)
+
+ def Proxy(n):
+ return fx.Proxy(n, tracer=tracer)
+
+ features = Proxy(graph.placeholder("x"))
+ norm_from_last: float = 1.0
+
+ base = torch.nn.Module()
+
+ for layer, (h_in, h_out) in enumerate(zip(dimensions, dimensions[1:])):
+ # do dropout
+ if mlp_dropout_p > 0:
+ # only dropout if it will do something
+ # dropout before linear projection- https://stats.stackexchange.com/a/245137
+ features = Proxy(graph.call_module("_dropout", (features.node,)))
+
+ # make weights
+ w = torch.empty(h_in, h_out)
+
+ if mlp_initialization == "normal":
+ w.normal_()
+ elif mlp_initialization == "uniform":
+ # these values give < x^2 > = 1
+ w.uniform_(-math.sqrt(3), math.sqrt(3))
+ elif mlp_initialization == "orthogonal":
+ # this rescaling gives < x^2 > = 1
+ torch.nn.init.orthogonal_(w, gain=math.sqrt(max(w.shape)))
+ else:
+ raise NotImplementedError(
+ f"Invalid mlp_initialization {mlp_initialization}"
+ )
+
+ # generate code
+ params[f"_weight_{layer}"] = w
+ w = Proxy(graph.get_attr(f"_weight_{layer}"))
+ w = w * (
+ norm_from_last / math.sqrt(float(h_in))
+ ) # include any nonlinearity normalization from previous layers
+ features = torch.matmul(features, w)
+
+ if mlp_batchnorm:
+ # if we call batchnorm, do it after the nonlinearity
+ features = Proxy(graph.call_module(f"_bn_{layer}", (features.node,)))
+ setattr(base, f"_bn_{layer}", torch.nn.BatchNorm1d(h_out))
+
+ # generate nonlinearity code
+ if nonlinearity is not None and layer < num_layers - 1:
+ features = nonlinearity(features)
+ # add the normalization const in next layer
+ norm_from_last = nonlin_const
+
+ graph.output(features.node)
+
+ for pname, p in params.items():
+ setattr(base, pname, torch.nn.Parameter(p))
+
+ if mlp_dropout_p > 0:
+ # with normal dropout everything blows up
+ base._dropout = torch.nn.AlphaDropout(p=mlp_dropout_p)
+
+ self._codegen_register({"_forward": fx.GraphModule(base, graph)})
+
+ def forward(self, x):
+ return self._forward(x)
+
+class InitLayer(torch.nn.Module):
+ def __init__(
+ self,
+ # required params
+ idp,
+ num_types: int,
+ n_radial_basis: int,
+ r_max: float,
+ irreps_sh: o3.Irreps=None,
+ env_embed_multiplicity: int = 32,
+ # MLP parameters:
+ two_body_latent_kwargs={
+ "mlp_latent_dimensions": [128, 256, 512, 1024],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ env_embed_kwargs = {
+ "mlp_latent_dimensions": [],
+ "mlp_nonlinearity": None,
+ "mlp_initialization": "uniform"
+ },
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ cutoff_type: str = "polynomial",
+ device: Union[str, torch.device] = torch.device("cpu"),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ ):
+ super(InitLayer, self).__init__()
+ SCALAR = o3.Irrep("0e")
+ self.num_types = num_types
+ if isinstance(r_max, float) or isinstance(r_max, int):
+ self.r_max = torch.tensor(r_max, device=device, dtype=dtype)
+ self.r_max_dict = None
+ elif isinstance(r_max, dict):
+ c_set = set(list(r_max.values()))
+ self.r_max = torch.tensor(max(list(r_max.values())), device=device, dtype=dtype)
+ if len(r_max) == 1 or len(c_set) == 1:
+ self.r_max_dict = None
+ else:
+ self.r_max_dict = {}
+ for k,v in r_max.items():
+ self.r_max_dict[k] = torch.tensor(v, device=device, dtype=dtype)
+ else:
+ raise TypeError("r_max should be either float, int or dict")
+
+ self.idp = idp
+ self.two_body_latent_kwargs = two_body_latent_kwargs
+ self.r_start_cos_ratio = r_start_cos_ratio
+ self.polynomial_cutoff_p = PolynomialCutoff_p
+ self.cutoff_type = cutoff_type
+ self.device = device
+ self.dtype = dtype
+ self.irreps_out = o3.Irreps([(env_embed_multiplicity, ir) for _, ir in irreps_sh])
+
+ assert all(mul==1 for mul, _ in irreps_sh)
+ # env_embed_irreps = o3.Irreps([(1, ir) for _, ir in irreps_sh])
+ assert (
+ irreps_sh[0].ir == SCALAR
+ ), "env_embed_irreps must start with scalars"
+
+ # Node invariants for center and neighbor (chemistry)
+ # Plus edge invariants for the edge (radius).
+ self.two_body_latent = ScalarMLPFunction(
+ mlp_input_dimension=(2 * num_types + n_radial_basis),
+ mlp_output_dimension=None,
+ **two_body_latent_kwargs,
+ )
+
+ self._env_weighter = Linear(
+ irreps_in=irreps_sh,
+ irreps_out=self.irreps_out,
+ internal_weights=False,
+ shared_weights=False,
+ path_normalization = "element", # if path normalization is element and input irreps has 1 mul, it should not have effect !
+ )
+
+ # self.bn = BatchNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # instance=False,
+ # normalization="component",
+ # )
+
+ self.env_embed_mlp = ScalarMLPFunction(
+ mlp_input_dimension=self.two_body_latent.out_features,
+ mlp_output_dimension=self._env_weighter.weight_numel,
+ **env_embed_kwargs,
+ )
+
+ self.bessel = BesselBasis(r_max=self.r_max, num_basis=n_radial_basis, trainable=True)
+
+
+
+ def forward(self, edge_index, bond_type, edge_sh, edge_length, node_one_hot):
+ edge_center = edge_index[0]
+ edge_neighbor = edge_index[1]
+
+ edge_invariants = self.bessel(edge_length)
+ node_invariants = node_one_hot
+
+ # Vectorized precompute per layer cutoffs
+ if self.r_max_dict is None:
+ if self.cutoff_type == "cosine":
+ cutoff_coeffs = cosine_cutoff(
+ edge_length,
+ self.r_max.reshape(-1),
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ ).flatten()
+
+ elif self.cutoff_type == "polynomial":
+ cutoff_coeffs = polynomial_cutoff(
+ edge_length, self.r_max.reshape(-1), p=self.polynomial_cutoff_p
+ ).flatten()
+
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+ else:
+ cutoff_coeffs = torch.zeros(edge_index.shape[1], dtype=self.dtype, device=self.device)
+
+ for bond, ty in self.idp.bond_to_type.items():
+ mask = bond_type == ty
+ index = mask.nonzero().squeeze(-1)
+
+ if mask.any():
+ iatom, jatom = bond.split("-")
+ if self.cutoff_type == "cosine":
+ c_coeff = cosine_cutoff(
+ edge_length[mask],
+ 0.5*(self.r_max_dict[iatom]+self.r_max_dict[jatom]),
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ ).flatten()
+ elif self.cutoff_type == "polynomial":
+ c_coeff = polynomial_cutoff(
+ edge_length[mask],
+ 0.5*(self.r_max_dict[iatom]+self.r_max_dict[jatom]),
+ p=self.polynomial_cutoff_p
+ ).flatten()
+
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+
+ cutoff_coeffs = torch.index_copy(cutoff_coeffs, 0, index, c_coeff)
+
+ # Determine which edges are still in play
+ prev_mask = cutoff_coeffs > 0
+ active_edges = (cutoff_coeffs > 0).nonzero().squeeze(-1)
+
+ # Compute latents
+ latents = torch.zeros(
+ (edge_sh.shape[0], self.two_body_latent.out_features),
+ dtype=edge_sh.dtype,
+ device=edge_sh.device,
+ )
+
+ new_latents = self.two_body_latent(torch.cat([
+ node_invariants[edge_center],
+ node_invariants[edge_neighbor],
+ edge_invariants,
+ ], dim=-1)[prev_mask])
+
+ # Apply cutoff, which propagates through to everything else
+ latents = torch.index_copy(
+ latents, 0, active_edges,
+ cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ )
+ weights = self.env_embed_mlp(new_latents)
+
+ # embed initial edge
+ features = self._env_weighter(
+ edge_sh[prev_mask], weights
+ ) # features is edge_attr
+ # features = self.bn(features)
+
+ return latents, features, cutoff_coeffs, active_edges # the radial embedding x and the sperical hidden V
+
+class Layer(torch.nn.Module):
+ def __init__(
+ self,
+ # required params
+ num_types: int,
+ avg_num_neighbors: Optional[float] = None,
+ irreps_sh: o3.Irreps=None,
+ irreps_in: o3.Irreps=None,
+ irreps_out: o3.Irreps=None,
+ # general hyperparameters:
+ linear_after_env_embed: bool = False,
+ env_embed_multiplicity: int = 32,
+ # MLP parameters:
+ latent_kwargs={
+ "mlp_latent_dimensions": [128, 256, 512, 1024],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ latent_in: int=1024,
+ latent_resnet: bool = True,
+ last_layer: bool = False,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ):
+ super().__init__()
+
+ assert latent_in == latent_kwargs["mlp_latent_dimensions"][-1]
+
+ SCALAR = o3.Irrep("0e")
+ self.latent_resnet = latent_resnet
+ self.avg_num_neighbors = avg_num_neighbors
+ self.linear_after_env_embed = linear_after_env_embed
+ self.irreps_in = irreps_in
+ self.irreps_out = irreps_out
+ self.last_layer = last_layer
+ self.dtype = dtype
+ self.device = device
+
+ assert all(mul==1 for mul, _ in irreps_sh)
+
+ # for normalization of env embed sums
+ # one per layer
+ self.register_buffer(
+ "env_sum_normalizations",
+ # dividing by sqrt(N)
+ torch.as_tensor(avg_num_neighbors).rsqrt(),
+ )
+
+ latent = functools.partial(ScalarMLPFunction, **latent_kwargs)
+
+ self.latents = None
+ self.env_embed_mlps = None
+ self.tps = None
+ self.linears = None
+ self.env_linears = None
+
+ # Prune impossible paths
+ self.irreps_out = o3.Irreps(
+ [
+ (mul, ir)
+ for mul, ir in self.irreps_out
+ if tp_path_exists(irreps_sh, irreps_in, ir)
+ ]
+ )
+
+ mul_irreps_sh = o3.Irreps([(env_embed_multiplicity, ir) for _, ir in irreps_sh])
+ self._env_weighter = Linear(
+ irreps_in=irreps_sh,
+ irreps_out=mul_irreps_sh,
+ internal_weights=False,
+ shared_weights=False,
+ path_normalization = "element",
+ )
+
+ if last_layer:
+ self._node_weighter = E3ElementLinear(
+ irreps_in=irreps_out,
+ dtype=dtype,
+ device=device,
+ )
+
+ self._edge_weighter = E3ElementLinear(
+ irreps_in=irreps_out,
+ dtype=dtype,
+ device=device,
+ )
+
+ # == Remove unneeded paths ==
+ #TODO: add the remove unseen paths
+
+ if self.linear_after_env_embed:
+ self.env_linears = Linear(
+ mul_irreps_sh,
+ mul_irreps_sh,
+ shared_weights=True,
+ internal_weights=True,
+ )
+
+ else:
+ self.env_linears = torch.nn.Identity()
+
+ # # Make TP
+ # tmp_i_out: int = 0
+ # instr = []
+ # n_scalar_outs: int = 0
+ # n_scalar_mul = []
+ # full_out_irreps = []
+ # for i_out, (mul_out, ir_out) in enumerate(self.irreps_out):
+ # for i_1, (mul1, ir_1) in enumerate(self.irreps_in): # what if feature_irreps_in has mul?
+ # for i_2, (mul2, ir_2) in enumerate(self._env_weighter.irreps_out):
+ # if ir_out in ir_1 * ir_2:
+ # if ir_out == SCALAR:
+ # n_scalar_outs += 1
+ # n_scalar_mul.append(mul2)
+ # # assert mul_out == mul1 == mul2
+ # instr.append((i_1, i_2, tmp_i_out, 'uvv', True))
+ # full_out_irreps.append((mul2, ir_out))
+ # assert full_out_irreps[-1][0] == mul2
+ # tmp_i_out += 1
+ # full_out_irreps = o3.Irreps(full_out_irreps)
+ # assert all(ir == SCALAR for _, ir in full_out_irreps[:n_scalar_outs])
+ # self.n_scalar_mul = sum(n_scalar_mul)
+
+ self.lin_pre = Linear(
+ irreps_in=self.irreps_in,
+ irreps_out=self.irreps_in,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # self.tp = TensorProduct(
+ # irreps_in1=o3.Irreps(
+ # [(mul, ir) for mul, ir in self.irreps_in]
+ # ),
+ # irreps_in2=o3.Irreps(
+ # [(mul, ir) for mul, ir in self._env_weighter.irreps_out]
+ # ),
+ # irreps_out=o3.Irreps(
+ # [(mul, ir) for mul, ir in full_out_irreps]
+ # ),
+ # irrep_normalization="component",
+ # instructions=instr,
+ # shared_weights=True,
+ # internal_weights=True,
+ # )
+ # build activation
+
+ irreps_scalar = o3.Irreps([(mul, ir) for mul, ir in self.irreps_out if ir.l == 0]).simplify()
+ irreps_gated = o3.Irreps([(mul, ir) for mul, ir in self.irreps_out if ir.l > 0]).simplify()
+
+
+ irreps_gates = o3.Irreps([(mul, (0,1)) for mul, _ in irreps_gated]).simplify()
+ act={1: torch.nn.functional.silu, -1: torch.tanh}
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+
+ self.activation = Gate(
+ irreps_scalar, [act[ir.p] for _, ir in irreps_scalar], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ self.tp = SeparateWeightTensorProduct(
+ irreps_in1=self.irreps_in,
+ irreps_in2=self._env_weighter.irreps_out,
+ irreps_out=self.activation.irreps_in,
+ )
+
+ if self.last_layer:
+ self.tp_out = SeparateWeightTensorProduct(
+ irreps_in1=self.irreps_out+self._env_weighter.irreps_out+self._env_weighter.irreps_out,
+ irreps_in2=irreps_sh,
+ irreps_out=self.irreps_out,
+ )
+
+ # self.sc = FullyConnectedTensorProduct(
+ # irreps_in,
+ # o3.Irreps(str(2*num_types)+"x0e"),
+ # self.irreps_out,
+ # shared_weights=True,
+ # internal_weights=True
+ # )
+
+ self.lin_post = Linear(
+ self.activation.irreps_out,
+ self.irreps_out,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # self.bn = TypeNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # num_type=num_types*num_types,
+ # normalization="component",
+ # )
+
+ # self.bn = BatchNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # normalization="component",
+ # )
+
+ if latent_resnet:
+ self.linear_res = Linear(
+ self.irreps_in,
+ self.irreps_out,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # we extract the scalars from the first irrep of the tp
+ # assert full_out_irreps[0].ir == SCALAR
+ # self.linears = Linear(
+ # irreps_in=full_out_irreps,
+ # irreps_out=self.activation.irreps_in,
+ # shared_weights=True,
+ # internal_weights=True,
+ # biases=True,
+ # )
+
+ # the embedded latent invariants from the previous layer(s)
+ # and the invariants extracted from the last layer's TP:
+ # we need to make sure all scalars in tp.irreps_out all contains in the first irreps
+ all_tp_scalar = o3.Irreps([(mul, ir) for mul, ir in self.tp.irreps_out if ir.l == 0]).simplify()
+ assert all_tp_scalar.dim == self.tp.irreps_out[0].dim
+ self.latents = latent(
+ mlp_input_dimension=latent_in+self.tp.irreps_out[0].dim,
+ mlp_output_dimension=None,
+ )
+
+ # the env embed MLP takes the last latent's output as input
+ # and outputs enough weights for the env embedder
+ self.env_embed_mlps = ScalarMLPFunction(
+ mlp_input_dimension=latent_in,
+ mlp_latent_dimensions=[],
+ mlp_output_dimension=self._env_weighter.weight_numel,
+ )
+
+ if last_layer:
+ self.node_embed_mlps = ScalarMLPFunction(
+ mlp_input_dimension=latent_in,
+ mlp_latent_dimensions=[],
+ mlp_output_dimension=self._node_weighter.weight_numel,
+ )
+
+ self.edge_embed_mlps = ScalarMLPFunction(
+ mlp_input_dimension=latent_in,
+ mlp_latent_dimensions=[],
+ mlp_output_dimension=self._edge_weighter.weight_numel,
+ )
+
+ # self.node_bn = TypeNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # num_type=num_types,
+ # normalization="component",
+ # )
+
+ # self.node_bn = BatchNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # normalization="norm",
+ # )
+
+ # - layer resnet update weights -
+ if latent_resnet_update_ratios is None:
+ # We initialize to zeros, which under the sigmoid() become 0.5
+ # so 1/2 * layer_1 + 1/4 * layer_2 + ...
+ # note that the sigmoid of these are the factor _between_ layers
+ # so the first entry is the ratio for the latent resnet of the first and second layers, etc.
+ # e.g. if there are 3 layers, there are 2 ratios: l1:l2, l2:l3
+ latent_resnet_update_params = torch.zeros(1)
+ else:
+ latent_resnet_update_ratios = torch.as_tensor(
+ latent_resnet_update_ratios, dtype=torch.get_default_dtype()
+ )
+ assert latent_resnet_update_ratios > 0.0
+ assert latent_resnet_update_ratios < 1.0
+ latent_resnet_update_params = torch.special.logit(
+ latent_resnet_update_ratios
+ )
+ # The sigmoid is mostly saturated at ±6, keep it in a reasonable range
+ latent_resnet_update_params.clamp_(-6.0, 6.0)
+
+ if latent_resnet_update_ratios_learnable:
+ self._latent_resnet_update_params = torch.nn.Parameter(
+ latent_resnet_update_params
+ )
+ else:
+ self.register_buffer(
+ "_latent_resnet_update_params", latent_resnet_update_params
+ )
+
+ def forward(self, edge_index, edge_sh, atom_type, bond_type, latents, features, cutoff_coeffs, active_edges, batch):
+ # update V
+ # update X
+ # edge_index: [2, num_edges]
+ # irreps_sh: [num_edges, irreps_sh]
+ # latents: [num_edges, latent_in]
+ # fetures: [num_active_edges, in_irreps]
+ # cutoff_coeffs: [num_edges]
+ # active_edges: [num_active_edges]
+
+ edge_center = edge_index[0]
+ edge_neighbor = edge_index[1]
+
+ prev_mask = cutoff_coeffs > 0
+
+ # sc_features = self.sc(features, node_one_hot[edge_index].transpose(0,1).flatten(1,2)[active_edges])
+ # update V
+ weights = self.env_embed_mlps(latents[active_edges])
+
+ # Build the local environments
+ # This local environment should only be a sum over neighbors
+ # who are within the cutoff of the _current_ layer
+ # Those are the active edges, which are the only ones we
+ # have weights for (env_w) anyway.
+ # So we mask out the edges in the sum:
+ local_env_per_edge = scatter(
+ self._env_weighter(edge_sh[active_edges], weights),
+ edge_center[active_edges],
+ dim=0,
+ )
+
+ # currently, we have a sum over neighbors of constant number for each layer,
+ # the env_sum_normalization can be a scalar or list
+ # the different cutoff can be added in the future
+
+ if self.env_sum_normalizations.ndim < 1:
+ norm_const = self.env_sum_normalizations
+ else:
+ norm_const = self.env_sum_normalizations[atom_type.flatten()].unsqueeze(-1)
+
+ local_env_per_edge = local_env_per_edge * norm_const
+ local_env_per_edge = self.env_linears(local_env_per_edge)
+
+ # local_env_per_edge = torch.cat([local_env_per_edge[edge_center[active_edges]], local_env_per_edge[edge_neighbor[active_edges]]], dim=-1)
+ # local_env_per_edge = local_env_per_edge[edge_center[active_edges]]
+ # Now do the TP
+ # recursively tp current features with the environment embeddings
+ new_features = self.tp(self.lin_pre(features), local_env_per_edge[edge_center[active_edges]]) # full_out_irreps
+
+ scalars = new_features[:, :self.tp.irreps_out[0].dim]
+ new_features = self.activation(new_features)
+ # # do the linear
+ # new_features = self.linears(new_features)
+
+
+ # features has shape [N_edge, full_feature_out.dim]
+ # we know scalars are first
+ assert len(scalars.shape) == 2
+
+ new_features = self.lin_post(new_features)
+
+ # new_features = self.bn(new_features, bond_type[active_edges])
+ # new_features = new_features - scatter_mean(new_features, batch[edge_center[active_edges]], dim=0, dim_size=batch.max()+1)[batch[edge_center[active_edges]]]
+ # new_features = self.bn(new_features)
+
+ if self.latent_resnet:
+ update_coefficients = self._latent_resnet_update_params.sigmoid()
+ coefficient_old = torch.rsqrt(update_coefficients.square() + 1)
+ coefficient_new = update_coefficients * coefficient_old
+ features = coefficient_new * new_features + coefficient_old * self.linear_res(features)
+ else:
+ features = new_features
+
+ # whether it is the last layer
+
+ latent_inputs_to_cat = [
+ latents[active_edges],
+ scalars,
+ ]
+
+ new_latents = self.latents(torch.cat(latent_inputs_to_cat, dim=-1))
+ new_latents = cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ # At init, we assume new and old to be approximately uncorrelated
+ # Thus their variances add
+ # we always want the latent space to be normalized to variance = 1.0,
+ # because it is critical for learnability. Still, we want to preserve
+ # the _relative_ magnitudes of the current latent and the residual update
+ # to be controled by `this_layer_update_coeff`
+ # Solving the simple system for the two coefficients:
+ # a^2 + b^2 = 1 (variances add) & a * this_layer_update_coeff = b
+ # gives
+ # a = 1 / sqrt(1 + this_layer_update_coeff^2) & b = this_layer_update_coeff / sqrt(1 + this_layer_update_coeff^2)
+ # rsqrt is reciprocal sqrt
+
+ if self.latent_resnet:
+ update_coefficients = self._latent_resnet_update_params.sigmoid()
+ coefficient_old = torch.rsqrt(update_coefficients.square() + 1)
+ coefficient_new = update_coefficients * coefficient_old
+ latents = torch.index_add(
+ coefficient_old * latents,
+ 0,
+ active_edges,
+ coefficient_new * new_latents,
+ )
+
+ else:
+ latents = torch.index_copy(latents, 0, active_edges, new_latents)
+
+ if self.last_layer:
+ node_weights = self.node_embed_mlps(latents[active_edges])
+
+ node_features = scatter(
+ self._node_weighter(
+ features,
+ node_weights,
+ ),
+ edge_center[active_edges],
+ dim=0,
+ )
+
+ node_features = node_features * norm_const
+
+ # node_features = self.node_bn(node_features, atom_type)
+ # node_features = self.node_bn(node_features)
+
+ edge_weights = self.edge_embed_mlps(latents[active_edges])
+
+ # the features's inclusion of the radial weight here is the only place
+ # where features are weighted according to the radial distance
+ features = self.tp_out(
+ torch.cat(
+ [
+ features,
+ local_env_per_edge[edge_center[active_edges]],
+ local_env_per_edge[edge_neighbor[active_edges]],
+ ], dim=-1
+ ),
+ edge_sh[active_edges],
+ )
+
+ features = self._edge_weighter(
+ features,
+ edge_weights,
+ )
+
+ return node_features, features, cutoff_coeffs, active_edges
+ else:
+ return latents, features, cutoff_coeffs, active_edges
+
\ No newline at end of file
diff --git a/dptb/nn/embedding/e3baseline_local1.py b/dptb/nn/embedding/e3baseline_local1.py
new file mode 100644
index 00000000..cd126a4b
--- /dev/null
+++ b/dptb/nn/embedding/e3baseline_local1.py
@@ -0,0 +1,1136 @@
+from typing import Optional, List, Union, Dict
+import math
+import functools
+import warnings
+
+import torch
+from torch_runstats.scatter import scatter
+
+from torch import fx
+from e3nn.util.codegen import CodeGenMixin
+from dptb.nn.norm import TypeNorm
+from e3nn import o3
+from e3nn.nn import Gate
+from e3nn.nn._batchnorm import BatchNorm
+from torch_scatter import scatter_mean
+from e3nn.o3 import Linear, SphericalHarmonics
+from e3nn.math import normalize2mom
+from e3nn.util.jit import compile_mode
+from dptb.nn.rescale import E3PerSpeciesScaleShift, E3PerEdgeSpeciesScaleShift
+
+from dptb.data import AtomicDataDict
+from dptb.nn.embedding.emb import Embedding
+from ..radial_basis import BesselBasis
+from dptb.nn.embedding.from_deephe3.deephe3 import tp_path_exists
+from dptb.nn.embedding.from_deephe3.e3module import SeparateWeightTensorProduct
+from dptb.data import _keys
+from dptb.nn.cutoff import cosine_cutoff, polynomial_cutoff
+from dptb.nn.rescale import E3ElementLinear
+import math
+from dptb.data.transforms import OrbitalMapper
+from ..type_encode.one_hot import OneHotAtomEncoding
+from dptb.data.AtomicDataDict import with_edge_vectors, with_env_vectors, with_batch
+
+from math import ceil
+
+@Embedding.register("e3baseline_local_wnode")
+class E3BaseLineModelLocal1(torch.nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ # required params
+ n_atom: int=1,
+ n_layers: int=3,
+ n_radial_basis: int=10,
+ r_max: float=5.0,
+ lmax: int=4,
+ irreps_hidden: o3.Irreps=None,
+ avg_num_neighbors: Optional[float] = None,
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ cutoff_type: str = "polynomial",
+ # general hyperparameters:
+ linear_after_env_embed: bool = False,
+ env_embed_multiplicity: int = 32,
+ sh_normalized: bool = True,
+ sh_normalization: str = "component",
+ # MLP parameters:
+ latent_kwargs={
+ "mlp_latent_dimensions": [256, 256, 512],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ latent_resnet: bool = True,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ):
+
+ super(E3BaseLineModelLocal1, self).__init__()
+
+ irreps_hidden = o3.Irreps(irreps_hidden)
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ self.dtype = dtype
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.device = device
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb")
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ self.basis = self.idp.basis
+ self.idp.get_irreps(no_parity=False)
+ self.n_atom = n_atom
+
+ irreps_sh=o3.Irreps([(1, (i, (-1) ** i)) for i in range(lmax + 1)])
+ orbpair_irreps = self.idp.orbpair_irreps.sort()[0].simplify()
+
+ # check if the irreps setting satisfied the requirement of idp
+ irreps_out = []
+ for mul, ir1 in irreps_hidden:
+ for _, ir2 in orbpair_irreps:
+ irreps_out += [o3.Irrep(str(irr)) for irr in ir1*ir2]
+ irreps_out = o3.Irreps(irreps_out).sort()[0].simplify()
+
+ assert all(ir in irreps_out for _, ir in orbpair_irreps), "hidden irreps should at least cover all the reqired irreps in the hamiltonian data {}".format(orbpair_irreps)
+
+ # TODO: check if the tp in first layer can produce the required irreps for hidden states
+
+ self.sh = SphericalHarmonics(
+ irreps_sh, sh_normalized, sh_normalization
+ )
+ self.onehot = OneHotAtomEncoding(num_types=n_atom, set_features=False)
+
+ self.init_layer = InitLayer(
+ idp=self.idp,
+ num_types=n_atom,
+ n_radial_basis=n_radial_basis,
+ r_max=r_max,
+ irreps_sh=irreps_sh,
+ env_embed_multiplicity=env_embed_multiplicity,
+ # MLP parameters:
+ two_body_latent_kwargs=latent_kwargs,
+ env_embed_kwargs = {
+ "mlp_latent_dimensions": [],
+ "mlp_nonlinearity": None,
+ "mlp_initialization": "uniform"
+ },
+ # cutoffs
+ r_start_cos_ratio=r_start_cos_ratio,
+ PolynomialCutoff_p=PolynomialCutoff_p,
+ cutoff_type=cutoff_type,
+ device=device,
+ dtype=dtype,
+ )
+
+ self.layers = torch.nn.ModuleList()
+ latent_in =latent_kwargs["mlp_latent_dimensions"][-1]
+ # actually, we can derive the least required irreps_in and out from the idp's node and pair irreps
+ last_layer = False
+ for i in range(n_layers):
+ if i == 0:
+ irreps_in = self.init_layer.irreps_out
+ else:
+ irreps_in = irreps_hidden
+
+ if i == n_layers - 1:
+ irreps_out = orbpair_irreps.sort()[0].simplify()
+ last_layer = True
+ else:
+ irreps_out = irreps_hidden
+
+ self.layers.append(Layer(
+ num_types=n_atom,
+ avg_num_neighbors=avg_num_neighbors,
+ irreps_sh=irreps_sh,
+ irreps_in=irreps_in,
+ irreps_out=irreps_out,
+ # general hyperparameters:
+ linear_after_env_embed=linear_after_env_embed,
+ env_embed_multiplicity=env_embed_multiplicity,
+ # MLP parameters:
+ latent_kwargs=latent_kwargs,
+ latent_in=latent_in,
+ latent_resnet=latent_resnet,
+ latent_resnet_update_ratios=latent_resnet_update_ratios,
+ latent_resnet_update_ratios_learnable=latent_resnet_update_ratios_learnable,
+ last_layer=last_layer,
+ dtype=dtype,
+ device=device,
+ )
+ )
+
+ # initilize output_layer
+ sorted_irs = self.idp.orbpair_irreps.sort()[0]
+ irs = self.idp.orbpair_irreps
+ sorted_to_origin = []
+ for ind in self.idp.orbpair_irreps.sort().p:
+ ir = sorted_irs[ind]
+ sorted_to_origin += list(range(sorted_irs[:ind].dim, sorted_irs[:ind].dim+ir.dim))
+ self.sorted_to_origin = torch.LongTensor(sorted_to_origin)
+ self.out_edge = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=True)
+ self.out_node_mean = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=True)
+ self.out_node_var = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=False)
+ # self.out_node_var_norm = BatchNorm(
+ # irreps=self.out_node_irreps,
+ # affine=True,
+ # normalization="component",
+ # )
+
+ self.out_node_mean_scale = E3PerSpeciesScaleShift(
+ field=_keys.NODE_FEATURES_KEY,
+ num_types=n_atom,
+ irreps_in=self.out_node_irreps,
+ out_field = _keys.NODE_FEATURES_KEY,
+ shifts=0.,
+ scales=1.,
+ dtype=self.dtype,
+ device=self.device,
+ scales_trainable=True,
+ shifts_trainable=True,
+ )
+
+ self.out_node_var_scale = E3PerSpeciesScaleShift(
+ field=_keys.NODE_FEATURES_KEY,
+ num_types=n_atom,
+ irreps_in=self.out_node_irreps,
+ out_field = _keys.NODE_FEATURES_KEY,
+ shifts=None,
+ scales=1.,
+ dtype=self.dtype,
+ device=self.device,
+ scales_trainable=True,
+ shifts_trainable=True,
+ )
+
+ # self.out_edge_scale = E3PerEdgeSpeciesScaleShift(
+ # field=_keys.EDGE_FEATURES_KEY,
+ # num_types=n_atom,
+ # irreps_in=self.out_edge_irreps,
+ # out_field = _keys.EDGE_FEATURES_KEY,
+ # shifts=0.,
+ # scales=1.,
+ # dtype=self.dtype,
+ # device=self.device,
+ # scales_trainable=False,
+ # shifts_trainable=False,
+ # )
+
+ # self.node_bn = BatchNorm(
+ # irreps=self.out_node_irreps,
+ # affine=True,
+ # normalization="component",
+ # )
+
+ # self.nodetype_bn = TypeNorm(
+ # irreps=self.out_node_irreps,
+ # affine=True,
+ # num_type=n_atom,
+ # normalization="component",
+ # )
+
+ @property
+ def out_edge_irreps(self):
+ return self.idp.orbpair_irreps
+
+ @property
+ def out_node_irreps(self):
+ return self.idp.orbpair_irreps
+
+ # def set_out_scales(self, node_scales: torch.Tensor, node_shifts: torch.Tensor, edge_scales: torch.Tensor, edge_shifts: torch.Tensor):
+ # assert node_scales.shape == self.out_node_scale.scales.shape
+ # assert node_shifts.shape == self.out_node_scale.shifts.shape
+ # assert edge_scales.shape == self.out_edge_scale.scales.shape
+ # assert edge_shifts.shape == self.out_edge_scale.shifts.shape
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ data = with_edge_vectors(data, with_lengths=True)
+ # data = with_env_vectors(data, with_lengths=True)
+ data = with_batch(data)
+ batch = data[_keys.BATCH_KEY]
+
+ edge_index = data[_keys.EDGE_INDEX_KEY]
+ edge_sh = self.sh(data[_keys.EDGE_VECTORS_KEY][:,[1,2,0]])
+ edge_length = data[_keys.EDGE_LENGTH_KEY]
+
+
+ data = self.onehot(data)
+ node_one_hot = data[_keys.NODE_ATTRS_KEY]
+ atom_type = data[_keys.ATOM_TYPE_KEY].flatten()
+ bond_type = data[_keys.EDGE_TYPE_KEY].flatten()
+ latents, features, cutoff_coeffs, active_edges = self.init_layer(edge_index, bond_type, edge_sh, edge_length, node_one_hot)
+
+ for layer in self.layers:
+ latents, features, cutoff_coeffs, active_edges = layer(edge_index, edge_sh, atom_type, bond_type, latents, features, cutoff_coeffs, active_edges, batch)
+ scatter_index = batch * self.n_atom + atom_type
+ latents_mean = scatter_mean(latents, scatter_index, dim=0, dim_size=(batch.max()+1)*self.n_atom)
+ latents_var = latents - latents_mean[scatter_index]
+ latents_mean = self.out_node_mean(latents_mean)[scatter_index]
+ latents_var = self.out_node_var(latents_var)
+
+ data[_keys.NODE_FEATURES_KEY] = latents_mean
+ latents_mean = self.out_node_mean_scale(data)[_keys.NODE_FEATURES_KEY]
+ data[_keys.NODE_FEATURES_KEY] = latents_var
+ data = self.out_node_var_scale(data)
+ data[_keys.NODE_FEATURES_KEY] = latents_mean + data[_keys.NODE_FEATURES_KEY]
+ data[_keys.EDGE_FEATURES_KEY] = torch.zeros(edge_index.shape[1], self.idp.orbpair_irreps.dim, dtype=self.dtype, device=self.device)
+ data[_keys.EDGE_FEATURES_KEY] = torch.index_copy(data[_keys.EDGE_FEATURES_KEY], 0, active_edges, self.out_edge(features))
+
+ return data
+
+def tp_path_exists(irreps_in1, irreps_in2, ir_out):
+ irreps_in1 = o3.Irreps(irreps_in1).simplify()
+ irreps_in2 = o3.Irreps(irreps_in2).simplify()
+ ir_out = o3.Irrep(ir_out)
+
+ for _, ir1 in irreps_in1:
+ for _, ir2 in irreps_in2:
+ if ir_out in ir1 * ir2:
+ return True
+ return False
+
+def get_gate_nonlin(irreps_in1, irreps_in2, irreps_out,
+ act={1: torch.nn.functional.silu, -1: torch.tanh},
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+ ):
+ # get gate nonlinearity after tensor product
+ # irreps_in1 and irreps_in2 are irreps to be multiplied in tensor product
+ # irreps_out is desired irreps after gate nonlin
+ # notice that nonlin.irreps_out might not be exactly equal to irreps_out
+
+ irreps_scalars = o3.Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l == 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ irreps_gated = o3.Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l > 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ if irreps_gated.dim > 0:
+ if tp_path_exists(irreps_in1, irreps_in2, "0e"):
+ ir = "0e"
+ elif tp_path_exists(irreps_in1, irreps_in2, "0o"):
+ ir = "0o"
+ warnings.warn('Using odd representations as gates')
+ else:
+ raise ValueError(
+ f"irreps_in1={irreps_in1} times irreps_in2={irreps_in2} is unable to produce gates needed for irreps_gated={irreps_gated}")
+ else:
+ ir = None
+ irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify()
+
+ gate_nonlin = Gate(
+ irreps_scalars, [act[ir.p] for _, ir in irreps_scalars], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ return gate_nonlin
+
+
+@compile_mode("script")
+class MakeWeightedChannels(torch.nn.Module):
+ weight_numel: int
+ multiplicity_out: Union[int, list]
+ _num_irreps: int
+
+ def __init__(
+ self,
+ irreps_in: o3.Irreps,
+ multiplicity_out: Union[int, list],
+ pad_to_alignment: int = 1,
+ ):
+ super().__init__()
+ assert all(mul == 1 for mul, _ in irreps_in)
+ assert multiplicity_out >= 1
+ # Each edgewise output multiplicity is a per-irrep weighted sum over the input
+ # So we need to apply the weight for the ith irrep to all DOF in that irrep
+ w_index = []
+ idx = 0
+ self._num_irreps = 0
+ for (mul, ir) in irreps_in:
+ w_index += sum(([ix] * ir.dim for ix in range(idx, idx + mul)), [])
+ idx += mul
+ self._num_irreps += mul
+ # w_index = sum(([i] * ir.dim for i, (mul, ir) in enumerate(irreps_in)), [])
+ # pad to padded length
+ n_pad = (
+ int(ceil(irreps_in.dim / pad_to_alignment)) * pad_to_alignment
+ - irreps_in.dim
+ )
+ # use the last weight, what we use doesn't matter much
+ w_index += [w_index[-1]] * n_pad
+ self.register_buffer("_w_index", torch.as_tensor(w_index, dtype=torch.long))
+ # there is
+ self.multiplicity_out = multiplicity_out
+ self.weight_numel = self._num_irreps * multiplicity_out
+
+ def forward(self, edge_attr, weights):
+ # weights are [z, u, num_i]
+ # edge_attr are [z, i]
+ # i runs over all irreps, which is why the weights need
+ # to be indexed in order to go from [num_i] to [i]
+ return torch.einsum(
+ "zi,zui->zui",
+ edge_attr,
+ weights.view(
+ -1,
+ self.multiplicity_out,
+ self._num_irreps,
+ )[:, :, self._w_index],
+ )
+
+@torch.jit.script
+def ShiftedSoftPlus(x: torch.Tensor):
+ return torch.nn.functional.softplus(x) - math.log(2.0)
+
+class ScalarMLPFunction(CodeGenMixin, torch.nn.Module):
+ """Module implementing an MLP according to provided options."""
+
+ in_features: int
+ out_features: int
+
+ def __init__(
+ self,
+ mlp_input_dimension: Optional[int],
+ mlp_latent_dimensions: List[int],
+ mlp_output_dimension: Optional[int],
+ mlp_nonlinearity: Optional[str] = "silu",
+ mlp_initialization: str = "normal",
+ mlp_dropout_p: float = 0.0,
+ mlp_batchnorm: bool = False,
+ ):
+ super().__init__()
+ nonlinearity = {
+ None: None,
+ "silu": torch.nn.functional.silu,
+ "ssp": ShiftedSoftPlus,
+ }[mlp_nonlinearity]
+ if nonlinearity is not None:
+ nonlin_const = normalize2mom(nonlinearity).cst
+ else:
+ nonlin_const = 1.0
+
+ dimensions = (
+ ([mlp_input_dimension] if mlp_input_dimension is not None else [])
+ + mlp_latent_dimensions
+ + ([mlp_output_dimension] if mlp_output_dimension is not None else [])
+ )
+ assert len(dimensions) >= 2 # Must have input and output dim
+ num_layers = len(dimensions) - 1
+
+ self.in_features = dimensions[0]
+ self.out_features = dimensions[-1]
+
+ # Code
+ params = {}
+ graph = fx.Graph()
+ tracer = fx.proxy.GraphAppendingTracer(graph)
+
+ def Proxy(n):
+ return fx.Proxy(n, tracer=tracer)
+
+ features = Proxy(graph.placeholder("x"))
+ norm_from_last: float = 1.0
+
+ base = torch.nn.Module()
+
+ for layer, (h_in, h_out) in enumerate(zip(dimensions, dimensions[1:])):
+ # do dropout
+ if mlp_dropout_p > 0:
+ # only dropout if it will do something
+ # dropout before linear projection- https://stats.stackexchange.com/a/245137
+ features = Proxy(graph.call_module("_dropout", (features.node,)))
+
+ # make weights
+ w = torch.empty(h_in, h_out)
+
+ if mlp_initialization == "normal":
+ w.normal_()
+ elif mlp_initialization == "uniform":
+ # these values give < x^2 > = 1
+ w.uniform_(-math.sqrt(3), math.sqrt(3))
+ elif mlp_initialization == "orthogonal":
+ # this rescaling gives < x^2 > = 1
+ torch.nn.init.orthogonal_(w, gain=math.sqrt(max(w.shape)))
+ else:
+ raise NotImplementedError(
+ f"Invalid mlp_initialization {mlp_initialization}"
+ )
+
+ # generate code
+ params[f"_weight_{layer}"] = w
+ w = Proxy(graph.get_attr(f"_weight_{layer}"))
+ w = w * (
+ norm_from_last / math.sqrt(float(h_in))
+ ) # include any nonlinearity normalization from previous layers
+ features = torch.matmul(features, w)
+
+ if mlp_batchnorm:
+ # if we call batchnorm, do it after the nonlinearity
+ features = Proxy(graph.call_module(f"_bn_{layer}", (features.node,)))
+ setattr(base, f"_bn_{layer}", torch.nn.BatchNorm1d(h_out))
+
+ # generate nonlinearity code
+ if nonlinearity is not None and layer < num_layers - 1:
+ features = nonlinearity(features)
+ # add the normalization const in next layer
+ norm_from_last = nonlin_const
+
+ graph.output(features.node)
+
+ for pname, p in params.items():
+ setattr(base, pname, torch.nn.Parameter(p))
+
+ if mlp_dropout_p > 0:
+ # with normal dropout everything blows up
+ base._dropout = torch.nn.AlphaDropout(p=mlp_dropout_p)
+
+ self._codegen_register({"_forward": fx.GraphModule(base, graph)})
+
+ def forward(self, x):
+ return self._forward(x)
+
+class InitLayer(torch.nn.Module):
+ def __init__(
+ self,
+ # required params
+ idp,
+ num_types: int,
+ n_radial_basis: int,
+ r_max: float,
+ irreps_sh: o3.Irreps=None,
+ env_embed_multiplicity: int = 32,
+ # MLP parameters:
+ two_body_latent_kwargs={
+ "mlp_latent_dimensions": [128, 256, 512, 1024],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ env_embed_kwargs = {
+ "mlp_latent_dimensions": [],
+ "mlp_nonlinearity": None,
+ "mlp_initialization": "uniform"
+ },
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ cutoff_type: str = "polynomial",
+ device: Union[str, torch.device] = torch.device("cpu"),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ ):
+ super(InitLayer, self).__init__()
+ SCALAR = o3.Irrep("0e")
+ self.num_types = num_types
+ if isinstance(r_max, float) or isinstance(r_max, int):
+ self.r_max = torch.tensor(r_max, device=device, dtype=dtype)
+ self.r_max_dict = None
+ elif isinstance(r_max, dict):
+ c_set = set(list(r_max.values()))
+ self.r_max = torch.tensor(max(list(r_max.values())), device=device, dtype=dtype)
+ if len(r_max) == 1 or len(c_set) == 1:
+ self.r_max_dict = None
+ else:
+ self.r_max_dict = {}
+ for k,v in r_max.items():
+ self.r_max_dict[k] = torch.tensor(v, device=device, dtype=dtype)
+ else:
+ raise TypeError("r_max should be either float, int or dict")
+
+ self.idp = idp
+ self.two_body_latent_kwargs = two_body_latent_kwargs
+ self.r_start_cos_ratio = r_start_cos_ratio
+ self.polynomial_cutoff_p = PolynomialCutoff_p
+ self.cutoff_type = cutoff_type
+ self.device = device
+ self.dtype = dtype
+ self.irreps_out = o3.Irreps([(env_embed_multiplicity, ir) for _, ir in irreps_sh])
+
+ assert all(mul==1 for mul, _ in irreps_sh)
+ # env_embed_irreps = o3.Irreps([(1, ir) for _, ir in irreps_sh])
+ assert (
+ irreps_sh[0].ir == SCALAR
+ ), "env_embed_irreps must start with scalars"
+
+ # Node invariants for center and neighbor (chemistry)
+ # Plus edge invariants for the edge (radius).
+ self.two_body_latent = ScalarMLPFunction(
+ mlp_input_dimension=(2 * num_types + n_radial_basis),
+ mlp_output_dimension=None,
+ **two_body_latent_kwargs,
+ )
+
+ self._env_weighter = Linear(
+ irreps_in=irreps_sh,
+ irreps_out=self.irreps_out,
+ internal_weights=False,
+ shared_weights=False,
+ path_normalization = "element", # if path normalization is element and input irreps has 1 mul, it should not have effect !
+ )
+
+ # self.bn = BatchNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # instance=False,
+ # normalization="component",
+ # )
+
+ self.env_embed_mlp = ScalarMLPFunction(
+ mlp_input_dimension=self.two_body_latent.out_features,
+ mlp_output_dimension=self._env_weighter.weight_numel,
+ **env_embed_kwargs,
+ )
+
+ self.bessel = BesselBasis(r_max=self.r_max, num_basis=n_radial_basis, trainable=True)
+
+
+
+ def forward(self, edge_index, bond_type, edge_sh, edge_length, node_one_hot):
+ edge_center = edge_index[0]
+ edge_neighbor = edge_index[1]
+
+ edge_invariants = self.bessel(edge_length)
+ node_invariants = node_one_hot
+
+ # Vectorized precompute per layer cutoffs
+ if self.r_max_dict is None:
+ if self.cutoff_type == "cosine":
+ cutoff_coeffs = cosine_cutoff(
+ edge_length,
+ self.r_max.reshape(-1),
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ ).flatten()
+
+ elif self.cutoff_type == "polynomial":
+ cutoff_coeffs = polynomial_cutoff(
+ edge_length, self.r_max.reshape(-1), p=self.polynomial_cutoff_p
+ ).flatten()
+
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+ else:
+ cutoff_coeffs = torch.zeros(edge_index.shape[1], dtype=self.dtype, device=self.device)
+
+ for bond, ty in self.idp.bond_to_type.items():
+ mask = bond_type == ty
+ index = mask.nonzero().squeeze(-1)
+
+ if mask.any():
+ iatom, jatom = bond.split("-")
+ if self.cutoff_type == "cosine":
+ c_coeff = cosine_cutoff(
+ edge_length[mask],
+ 0.5*(self.r_max_dict[iatom]+self.r_max_dict[jatom]),
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ ).flatten()
+ elif self.cutoff_type == "polynomial":
+ c_coeff = polynomial_cutoff(
+ edge_length[mask],
+ 0.5*(self.r_max_dict[iatom]+self.r_max_dict[jatom]),
+ p=self.polynomial_cutoff_p
+ ).flatten()
+
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+
+ cutoff_coeffs = torch.index_copy(cutoff_coeffs, 0, index, c_coeff)
+
+ # Determine which edges are still in play
+ prev_mask = cutoff_coeffs > 0
+ active_edges = (cutoff_coeffs > 0).nonzero().squeeze(-1)
+
+ # Compute latents
+ latents = torch.zeros(
+ (edge_sh.shape[0], self.two_body_latent.out_features),
+ dtype=edge_sh.dtype,
+ device=edge_sh.device,
+ )
+
+ new_latents = self.two_body_latent(torch.cat([
+ node_invariants[edge_center],
+ node_invariants[edge_neighbor],
+ edge_invariants,
+ ], dim=-1)[prev_mask])
+
+ # Apply cutoff, which propagates through to everything else
+ latents = torch.index_copy(
+ latents, 0, active_edges,
+ cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ )
+ weights = self.env_embed_mlp(new_latents)
+
+ # embed initial edge
+ features = self._env_weighter(
+ edge_sh[prev_mask], weights
+ ) # features is edge_attr
+ # features = self.bn(features)
+
+ return latents, features, cutoff_coeffs, active_edges # the radial embedding x and the sperical hidden V
+
+class Layer(torch.nn.Module):
+ def __init__(
+ self,
+ # required params
+ num_types: int,
+ avg_num_neighbors: Optional[float] = None,
+ irreps_sh: o3.Irreps=None,
+ irreps_in: o3.Irreps=None,
+ irreps_out: o3.Irreps=None,
+ # general hyperparameters:
+ linear_after_env_embed: bool = False,
+ env_embed_multiplicity: int = 32,
+ # MLP parameters:
+ latent_kwargs={
+ "mlp_latent_dimensions": [128, 256, 512, 1024],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ latent_in: int=1024,
+ latent_resnet: bool = True,
+ last_layer: bool = False,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ):
+ super().__init__()
+
+ assert latent_in == latent_kwargs["mlp_latent_dimensions"][-1]
+
+ SCALAR = o3.Irrep("0e")
+ self.latent_resnet = latent_resnet
+ self.avg_num_neighbors = avg_num_neighbors
+ self.linear_after_env_embed = linear_after_env_embed
+ self.irreps_in = irreps_in
+ self.irreps_out = irreps_out
+ self.last_layer = last_layer
+ self.dtype = dtype
+ self.device = device
+
+ assert all(mul==1 for mul, _ in irreps_sh)
+
+ # for normalization of env embed sums
+ # one per layer
+ self.register_buffer(
+ "env_sum_normalizations",
+ # dividing by sqrt(N)
+ torch.as_tensor(avg_num_neighbors).rsqrt(),
+ )
+
+ latent = functools.partial(ScalarMLPFunction, **latent_kwargs)
+
+ self.latents = None
+ self.env_embed_mlps = None
+ self.tps = None
+ self.linears = None
+ self.env_linears = None
+
+ # Prune impossible paths
+ self.irreps_out = o3.Irreps(
+ [
+ (mul, ir)
+ for mul, ir in self.irreps_out
+ if tp_path_exists(irreps_sh, irreps_in, ir)
+ ]
+ )
+
+ mul_irreps_sh = o3.Irreps([(env_embed_multiplicity, ir) for _, ir in irreps_sh])
+ self._env_weighter = Linear(
+ irreps_in=irreps_sh,
+ irreps_out=mul_irreps_sh,
+ internal_weights=False,
+ shared_weights=False,
+ path_normalization = "element",
+ )
+
+ if last_layer:
+ self._node_weighter = E3ElementLinear(
+ irreps_in=irreps_out,
+ dtype=dtype,
+ device=device,
+ )
+
+ self._edge_weighter = E3ElementLinear(
+ irreps_in=irreps_out,
+ dtype=dtype,
+ device=device,
+ )
+
+ # == Remove unneeded paths ==
+ #TODO: add the remove unseen paths
+
+ if self.linear_after_env_embed:
+ self.env_linears = Linear(
+ mul_irreps_sh,
+ mul_irreps_sh,
+ shared_weights=True,
+ internal_weights=True,
+ )
+
+ else:
+ self.env_linears = torch.nn.Identity()
+
+ # # Make TP
+ # tmp_i_out: int = 0
+ # instr = []
+ # n_scalar_outs: int = 0
+ # n_scalar_mul = []
+ # full_out_irreps = []
+ # for i_out, (mul_out, ir_out) in enumerate(self.irreps_out):
+ # for i_1, (mul1, ir_1) in enumerate(self.irreps_in): # what if feature_irreps_in has mul?
+ # for i_2, (mul2, ir_2) in enumerate(self._env_weighter.irreps_out):
+ # if ir_out in ir_1 * ir_2:
+ # if ir_out == SCALAR:
+ # n_scalar_outs += 1
+ # n_scalar_mul.append(mul2)
+ # # assert mul_out == mul1 == mul2
+ # instr.append((i_1, i_2, tmp_i_out, 'uvv', True))
+ # full_out_irreps.append((mul2, ir_out))
+ # assert full_out_irreps[-1][0] == mul2
+ # tmp_i_out += 1
+ # full_out_irreps = o3.Irreps(full_out_irreps)
+ # assert all(ir == SCALAR for _, ir in full_out_irreps[:n_scalar_outs])
+ # self.n_scalar_mul = sum(n_scalar_mul)
+
+ self.lin_pre = Linear(
+ irreps_in=self.irreps_in,
+ irreps_out=self.irreps_in,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # self.tp = TensorProduct(
+ # irreps_in1=o3.Irreps(
+ # [(mul, ir) for mul, ir in self.irreps_in]
+ # ),
+ # irreps_in2=o3.Irreps(
+ # [(mul, ir) for mul, ir in self._env_weighter.irreps_out]
+ # ),
+ # irreps_out=o3.Irreps(
+ # [(mul, ir) for mul, ir in full_out_irreps]
+ # ),
+ # irrep_normalization="component",
+ # instructions=instr,
+ # shared_weights=True,
+ # internal_weights=True,
+ # )
+ # build activation
+
+ irreps_scalar = o3.Irreps([(mul, ir) for mul, ir in self.irreps_out if ir.l == 0]).simplify()
+ irreps_gated = o3.Irreps([(mul, ir) for mul, ir in self.irreps_out if ir.l > 0]).simplify()
+
+
+ irreps_gates = o3.Irreps([(mul, (0,1)) for mul, _ in irreps_gated]).simplify()
+ act={1: torch.nn.functional.silu, -1: torch.tanh}
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+
+ self.activation = Gate(
+ irreps_scalar, [act[ir.p] for _, ir in irreps_scalar], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ self.tp = SeparateWeightTensorProduct(
+ irreps_in1=self.irreps_in,
+ irreps_in2=self._env_weighter.irreps_out,
+ irreps_out=self.activation.irreps_in,
+ )
+
+ if self.last_layer:
+ self.tp_out = SeparateWeightTensorProduct(
+ irreps_in1=self.irreps_out+self._env_weighter.irreps_out+self._env_weighter.irreps_out,
+ irreps_in2=irreps_sh,
+ irreps_out=self.irreps_out,
+ )
+
+ # self.sc = FullyConnectedTensorProduct(
+ # irreps_in,
+ # o3.Irreps(str(2*num_types)+"x0e"),
+ # self.irreps_out,
+ # shared_weights=True,
+ # internal_weights=True
+ # )
+
+ self.lin_post = Linear(
+ self.activation.irreps_out,
+ self.irreps_out,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # self.bn = TypeNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # num_type=num_types*num_types,
+ # normalization="component",
+ # )
+
+ self.bn = BatchNorm(
+ irreps=self.irreps_out,
+ affine=True,
+ normalization="component",
+ )
+
+ if latent_resnet:
+ self.linear_res = Linear(
+ self.irreps_in,
+ self.irreps_out,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # we extract the scalars from the first irrep of the tp
+ # assert full_out_irreps[0].ir == SCALAR
+ # self.linears = Linear(
+ # irreps_in=full_out_irreps,
+ # irreps_out=self.activation.irreps_in,
+ # shared_weights=True,
+ # internal_weights=True,
+ # biases=True,
+ # )
+
+ # the embedded latent invariants from the previous layer(s)
+ # and the invariants extracted from the last layer's TP:
+ # we need to make sure all scalars in tp.irreps_out all contains in the first irreps
+ all_tp_scalar = o3.Irreps([(mul, ir) for mul, ir in self.tp.irreps_out if ir.l == 0]).simplify()
+ assert all_tp_scalar.dim == self.tp.irreps_out[0].dim
+ self.latents = latent(
+ mlp_input_dimension=latent_in+self.tp.irreps_out[0].dim,
+ mlp_output_dimension=None,
+ )
+
+ # the env embed MLP takes the last latent's output as input
+ # and outputs enough weights for the env embedder
+ self.env_embed_mlps = ScalarMLPFunction(
+ mlp_input_dimension=latent_in,
+ mlp_latent_dimensions=[],
+ mlp_output_dimension=self._env_weighter.weight_numel,
+ )
+
+ if last_layer:
+ self.node_embed_mlps = ScalarMLPFunction(
+ mlp_input_dimension=latent_in,
+ mlp_latent_dimensions=[],
+ mlp_output_dimension=self._node_weighter.weight_numel,
+ )
+
+ self.edge_embed_mlps = ScalarMLPFunction(
+ mlp_input_dimension=latent_in,
+ mlp_latent_dimensions=[],
+ mlp_output_dimension=self._edge_weighter.weight_numel,
+ )
+
+ # self.node_bn = TypeNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # num_type=num_types,
+ # normalization="component",
+ # )
+
+ self.node_bn = BatchNorm(
+ irreps=self.irreps_out,
+ affine=True,
+ normalization="component",
+ )
+
+ # - layer resnet update weights -
+ if latent_resnet_update_ratios is None:
+ # We initialize to zeros, which under the sigmoid() become 0.5
+ # so 1/2 * layer_1 + 1/4 * layer_2 + ...
+ # note that the sigmoid of these are the factor _between_ layers
+ # so the first entry is the ratio for the latent resnet of the first and second layers, etc.
+ # e.g. if there are 3 layers, there are 2 ratios: l1:l2, l2:l3
+ latent_resnet_update_params = torch.zeros(1)
+ else:
+ latent_resnet_update_ratios = torch.as_tensor(
+ latent_resnet_update_ratios, dtype=torch.get_default_dtype()
+ )
+ assert latent_resnet_update_ratios > 0.0
+ assert latent_resnet_update_ratios < 1.0
+ latent_resnet_update_params = torch.special.logit(
+ latent_resnet_update_ratios
+ )
+ # The sigmoid is mostly saturated at ±6, keep it in a reasonable range
+ latent_resnet_update_params.clamp_(-6.0, 6.0)
+
+ if latent_resnet_update_ratios_learnable:
+ self._latent_resnet_update_params = torch.nn.Parameter(
+ latent_resnet_update_params
+ )
+ else:
+ self.register_buffer(
+ "_latent_resnet_update_params", latent_resnet_update_params
+ )
+
+ def forward(self, edge_index, edge_sh, atom_type, bond_type, latents, features, cutoff_coeffs, active_edges, batch):
+ # update V
+ # update X
+ # edge_index: [2, num_edges]
+ # irreps_sh: [num_edges, irreps_sh]
+ # latents: [num_edges, latent_in]
+ # fetures: [num_active_edges, in_irreps]
+ # cutoff_coeffs: [num_edges]
+ # active_edges: [num_active_edges]
+
+ edge_center = edge_index[0]
+ edge_neighbor = edge_index[1]
+
+ prev_mask = cutoff_coeffs > 0
+
+ # sc_features = self.sc(features, node_one_hot[edge_index].transpose(0,1).flatten(1,2)[active_edges])
+ # update V
+ weights = self.env_embed_mlps(latents[active_edges])
+
+ # Build the local environments
+ # This local environment should only be a sum over neighbors
+ # who are within the cutoff of the _current_ layer
+ # Those are the active edges, which are the only ones we
+ # have weights for (env_w) anyway.
+ # So we mask out the edges in the sum:
+ local_env_per_edge = scatter(
+ self._env_weighter(edge_sh[active_edges], weights),
+ edge_center[active_edges],
+ dim=0,
+ )
+
+ # currently, we have a sum over neighbors of constant number for each layer,
+ # the env_sum_normalization can be a scalar or list
+ # the different cutoff can be added in the future
+
+ if self.env_sum_normalizations.ndim < 1:
+ norm_const = self.env_sum_normalizations
+ else:
+ norm_const = self.env_sum_normalizations[atom_type.flatten()].unsqueeze(-1)
+
+ local_env_per_edge = local_env_per_edge * norm_const
+ local_env_per_edge = self.env_linears(local_env_per_edge)
+
+ # local_env_per_edge = torch.cat([local_env_per_edge[edge_center[active_edges]], local_env_per_edge[edge_neighbor[active_edges]]], dim=-1)
+ # local_env_per_edge = local_env_per_edge[edge_center[active_edges]]
+ # Now do the TP
+ # recursively tp current features with the environment embeddings
+ new_features = self.tp(self.lin_pre(features), local_env_per_edge[edge_center[active_edges]]) # full_out_irreps
+
+ scalars = new_features[:, :self.tp.irreps_out[0].dim]
+ new_features = self.activation(new_features)
+ # # do the linear
+ # new_features = self.linears(new_features)
+
+
+ # features has shape [N_edge, full_feature_out.dim]
+ # we know scalars are first
+ assert len(scalars.shape) == 2
+
+ new_features = self.lin_post(new_features)
+
+ # new_features = self.bn(new_features, bond_type[active_edges])
+ # new_features = new_features - scatter_mean(new_features, batch[edge_center[active_edges]], dim=0, dim_size=batch.max()+1)[batch[edge_center[active_edges]]]
+ new_features = self.bn(new_features)
+
+ if self.latent_resnet:
+ update_coefficients = self._latent_resnet_update_params.sigmoid()
+ coefficient_old = torch.rsqrt(update_coefficients.square() + 1)
+ coefficient_new = update_coefficients * coefficient_old
+ features = coefficient_new * new_features + coefficient_old * self.linear_res(features)
+ else:
+ features = new_features
+
+ # whether it is the last layer
+
+ latent_inputs_to_cat = [
+ latents[active_edges],
+ scalars,
+ ]
+
+ new_latents = self.latents(torch.cat(latent_inputs_to_cat, dim=-1))
+ new_latents = cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ # At init, we assume new and old to be approximately uncorrelated
+ # Thus their variances add
+ # we always want the latent space to be normalized to variance = 1.0,
+ # because it is critical for learnability. Still, we want to preserve
+ # the _relative_ magnitudes of the current latent and the residual update
+ # to be controled by `this_layer_update_coeff`
+ # Solving the simple system for the two coefficients:
+ # a^2 + b^2 = 1 (variances add) & a * this_layer_update_coeff = b
+ # gives
+ # a = 1 / sqrt(1 + this_layer_update_coeff^2) & b = this_layer_update_coeff / sqrt(1 + this_layer_update_coeff^2)
+ # rsqrt is reciprocal sqrt
+
+ if self.latent_resnet:
+ update_coefficients = self._latent_resnet_update_params.sigmoid()
+ coefficient_old = torch.rsqrt(update_coefficients.square() + 1)
+ coefficient_new = update_coefficients * coefficient_old
+ latents = torch.index_add(
+ coefficient_old * latents,
+ 0,
+ active_edges,
+ coefficient_new * new_latents,
+ )
+
+ else:
+ latents = torch.index_copy(latents, 0, active_edges, new_latents)
+
+ if self.last_layer:
+ node_weights = self.node_embed_mlps(latents[active_edges])
+
+ node_features = scatter(
+ self._node_weighter(
+ features,
+ node_weights,
+ ),
+ edge_center[active_edges],
+ dim=0,
+ )
+
+ node_features = node_features * norm_const
+
+ # node_features = self.node_bn(node_features, atom_type)
+ node_features = self.node_bn(node_features)
+
+ edge_weights = self.edge_embed_mlps(latents[active_edges])
+
+ # the features's inclusion of the radial weight here is the only place
+ # where features are weighted according to the radial distance
+ features = self.tp_out(
+ torch.cat(
+ [
+ features,
+ local_env_per_edge[edge_center[active_edges]],
+ local_env_per_edge[edge_neighbor[active_edges]],
+ ], dim=-1
+ ),
+ edge_sh[active_edges],
+ )
+
+ features = self._edge_weighter(
+ features,
+ edge_weights,
+ )
+
+ return node_features, features, cutoff_coeffs, active_edges
+ else:
+ return latents, features, cutoff_coeffs, active_edges
+
\ No newline at end of file
diff --git a/dptb/nn/embedding/e3baseline_nonlocal.py b/dptb/nn/embedding/e3baseline_nonlocal.py
new file mode 100644
index 00000000..636859a0
--- /dev/null
+++ b/dptb/nn/embedding/e3baseline_nonlocal.py
@@ -0,0 +1,934 @@
+from typing import Optional, List, Union, Dict
+import math
+import functools
+import warnings
+
+import torch
+from torch_runstats.scatter import scatter
+
+from torch import fx
+from e3nn.util.codegen import CodeGenMixin
+from e3nn import o3
+from e3nn.nn import Gate, Activation
+from e3nn.nn._batchnorm import BatchNorm
+from e3nn.o3 import TensorProduct, Linear, SphericalHarmonics, FullyConnectedTensorProduct
+from e3nn.math import normalize2mom
+from e3nn.util.jit import compile_mode
+
+from dptb.data import AtomicDataDict
+from dptb.nn.embedding.emb import Embedding
+from ..radial_basis import BesselBasis
+from dptb.nn.graph_mixin import GraphModuleMixin
+from dptb.nn.embedding.from_deephe3.deephe3 import tp_path_exists
+from dptb.nn.embedding.from_deephe3.e3module import SeparateWeightTensorProduct
+from dptb.data import _keys
+from dptb.nn.cutoff import cosine_cutoff, polynomial_cutoff
+import math
+from dptb.data.transforms import OrbitalMapper
+from ..type_encode.one_hot import OneHotAtomEncoding
+from dptb.data.AtomicDataDict import with_edge_vectors, with_env_vectors, with_batch
+
+from math import ceil
+
+@Embedding.register("e3baseline_nonlocal")
+class E3BaseLineModelNonLocal(torch.nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ # required params
+ n_atom: int=1,
+ n_layers: int=3,
+ n_radial_basis: int=10,
+ r_max: float=5.0,
+ lmax: int=4,
+ irreps_hidden: o3.Irreps=None,
+ avg_num_neighbors: Optional[float] = None,
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ cutoff_type: str = "polynomial",
+ # general hyperparameters:
+ linear_after_env_embed: bool = False,
+ env_embed_multiplicity: int = 32,
+ sh_normalized: bool = True,
+ sh_normalization: str = "component",
+ # MLP parameters:
+ latent_kwargs={
+ "mlp_latent_dimensions": [256, 256, 512],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ latent_resnet: bool = True,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ):
+
+ super(E3BaseLineModelNonLocal, self).__init__()
+
+ irreps_hidden = o3.Irreps(irreps_hidden)
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ self.dtype = dtype
+ self.device = device
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb")
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ self.basis = self.idp.basis
+ self.idp.get_irreps(no_parity=False)
+
+ irreps_sh=o3.Irreps([(1, (i, (-1) ** i)) for i in range(lmax + 1)])
+ orbpair_irreps = self.idp.orbpair_irreps.sort()[0].simplify()
+
+ # check if the irreps setting satisfied the requirement of idp
+ irreps_out = []
+ for mul, ir1 in irreps_hidden:
+ for _, ir2 in orbpair_irreps:
+ irreps_out += [o3.Irrep(str(irr)) for irr in ir1*ir2]
+ irreps_out = o3.Irreps(irreps_out).sort()[0].simplify()
+
+ assert all(ir in irreps_out for _, ir in orbpair_irreps), "hidden irreps should at least cover all the reqired irreps in the hamiltonian data {}".format(orbpair_irreps)
+
+ # TODO: check if the tp in first layer can produce the required irreps for hidden states
+
+ self.sh = SphericalHarmonics(
+ irreps_sh, sh_normalized, sh_normalization
+ )
+ self.onehot = OneHotAtomEncoding(num_types=n_atom, set_features=False)
+
+ self.init_layer = InitLayer(
+ idp=self.idp,
+ num_types=n_atom,
+ n_radial_basis=n_radial_basis,
+ r_max=r_max,
+ irreps_sh=irreps_sh,
+ env_embed_multiplicity=env_embed_multiplicity,
+ # MLP parameters:
+ two_body_latent_kwargs=latent_kwargs,
+ env_embed_kwargs = {
+ "mlp_latent_dimensions": [],
+ "mlp_nonlinearity": None,
+ "mlp_initialization": "uniform"
+ },
+ # cutoffs
+ r_start_cos_ratio=r_start_cos_ratio,
+ PolynomialCutoff_p=PolynomialCutoff_p,
+ cutoff_type=cutoff_type,
+ device=device,
+ dtype=dtype,
+ )
+
+ self.layers = torch.nn.ModuleList()
+ latent_in =latent_kwargs["mlp_latent_dimensions"][-1]
+ # actually, we can derive the least required irreps_in and out from the idp's node and pair irreps
+ for i in range(n_layers):
+ if i == 0:
+ irreps_in = self.init_layer.irreps_out
+ else:
+ irreps_in = irreps_hidden
+
+ if i == n_layers - 1:
+ irreps_out = orbpair_irreps.sort()[0].simplify()
+ else:
+ irreps_out = irreps_hidden
+
+ self.layers.append(Layer(
+ num_types=n_atom,
+ avg_num_neighbors=avg_num_neighbors,
+ irreps_sh=irreps_sh,
+ irreps_in=irreps_in,
+ irreps_out=irreps_out,
+ # general hyperparameters:
+ linear_after_env_embed=linear_after_env_embed,
+ env_embed_multiplicity=env_embed_multiplicity,
+ # MLP parameters:
+ latent_kwargs=latent_kwargs,
+ latent_in=latent_in,
+ latent_resnet=latent_resnet,
+ latent_resnet_update_ratios=latent_resnet_update_ratios,
+ latent_resnet_update_ratios_learnable=latent_resnet_update_ratios_learnable,
+ )
+ )
+
+ # initilize output_layer
+ self.out_edge = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=True)
+ self.out_node = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=True)
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ data = with_edge_vectors(data, with_lengths=True)
+ # data = with_env_vectors(data, with_lengths=True)
+ data = with_batch(data)
+
+ edge_index = data[_keys.EDGE_INDEX_KEY]
+ edge_sh = self.sh(data[_keys.EDGE_VECTORS_KEY][:,[1,2,0]])
+ edge_length = data[_keys.EDGE_LENGTH_KEY]
+
+
+ data = self.onehot(data)
+ node_one_hot = data[_keys.NODE_ATTRS_KEY]
+ atom_type = data[_keys.ATOM_TYPE_KEY].flatten()
+ bond_type = data[_keys.EDGE_TYPE_KEY].flatten()
+ latents, features, cutoff_coeffs, active_edges = self.init_layer(edge_index, bond_type, edge_sh, edge_length, node_one_hot)
+
+ for layer in self.layers:
+ latents, features, cutoff_coeffs, active_edges = layer(edge_index, edge_sh, atom_type, latents, features, cutoff_coeffs, active_edges)
+
+ if self.layers[-1].env_sum_normalizations.ndim < 1:
+ norm_const = self.layers[-1].env_sum_normalizations
+ else:
+ norm_const = self.layers[-1].env_sum_normalizations[atom_type.flatten()].unsqueeze(-1)
+
+ data[_keys.EDGE_FEATURES_KEY] = torch.zeros(edge_index.shape[1], self.idp.orbpair_irreps.dim, dtype=self.dtype, device=self.device)
+ data[_keys.EDGE_FEATURES_KEY] = torch.index_copy(data[_keys.EDGE_FEATURES_KEY], 0, active_edges, self.out_edge(features))
+ node_features = scatter(features, edge_index[0][active_edges], dim=0)
+ data[_keys.NODE_FEATURES_KEY] = self.out_node(node_features * norm_const)
+
+ return data
+
+def tp_path_exists(irreps_in1, irreps_in2, ir_out):
+ irreps_in1 = o3.Irreps(irreps_in1).simplify()
+ irreps_in2 = o3.Irreps(irreps_in2).simplify()
+ ir_out = o3.Irrep(ir_out)
+
+ for _, ir1 in irreps_in1:
+ for _, ir2 in irreps_in2:
+ if ir_out in ir1 * ir2:
+ return True
+ return False
+
+def get_gate_nonlin(irreps_in1, irreps_in2, irreps_out,
+ act={1: torch.nn.functional.silu, -1: torch.tanh},
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+ ):
+ # get gate nonlinearity after tensor product
+ # irreps_in1 and irreps_in2 are irreps to be multiplied in tensor product
+ # irreps_out is desired irreps after gate nonlin
+ # notice that nonlin.irreps_out might not be exactly equal to irreps_out
+
+ irreps_scalars = o3.Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l == 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ irreps_gated = o3.Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l > 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ if irreps_gated.dim > 0:
+ if tp_path_exists(irreps_in1, irreps_in2, "0e"):
+ ir = "0e"
+ elif tp_path_exists(irreps_in1, irreps_in2, "0o"):
+ ir = "0o"
+ warnings.warn('Using odd representations as gates')
+ else:
+ raise ValueError(
+ f"irreps_in1={irreps_in1} times irreps_in2={irreps_in2} is unable to produce gates needed for irreps_gated={irreps_gated}")
+ else:
+ ir = None
+ irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify()
+
+ gate_nonlin = Gate(
+ irreps_scalars, [act[ir.p] for _, ir in irreps_scalars], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ return gate_nonlin
+
+
+@compile_mode("script")
+class MakeWeightedChannels(torch.nn.Module):
+ weight_numel: int
+ multiplicity_out: Union[int, list]
+ _num_irreps: int
+
+ def __init__(
+ self,
+ irreps_in,
+ multiplicity_out: Union[int, list],
+ pad_to_alignment: int = 1,
+ ):
+ super().__init__()
+ assert all(mul == 1 for mul, _ in irreps_in)
+ assert multiplicity_out >= 1
+ # Each edgewise output multiplicity is a per-irrep weighted sum over the input
+ # So we need to apply the weight for the ith irrep to all DOF in that irrep
+ w_index = []
+ idx = 0
+ self._num_irreps = 0
+ for (mul, ir) in irreps_in:
+ w_index += sum(([ix] * ir.dim for ix in range(idx, idx + mul)), [])
+ idx += mul
+ self._num_irreps += mul
+ # w_index = sum(([i] * ir.dim for i, (mul, ir) in enumerate(irreps_in)), [])
+ # pad to padded length
+ n_pad = (
+ int(ceil(irreps_in.dim / pad_to_alignment)) * pad_to_alignment
+ - irreps_in.dim
+ )
+ # use the last weight, what we use doesn't matter much
+ w_index += [w_index[-1]] * n_pad
+ self.register_buffer("_w_index", torch.as_tensor(w_index, dtype=torch.long))
+ # there is
+ self.multiplicity_out = multiplicity_out
+ self.weight_numel = self._num_irreps * multiplicity_out
+
+ def forward(self, edge_attr, weights):
+ # weights are [z, u, num_i]
+ # edge_attr are [z, i]
+ # i runs over all irreps, which is why the weights need
+ # to be indexed in order to go from [num_i] to [i]
+ return torch.einsum(
+ "zi,zui->zui",
+ edge_attr,
+ weights.view(
+ -1,
+ self.multiplicity_out,
+ self._num_irreps,
+ )[:, :, self._w_index],
+ )
+
+@torch.jit.script
+def ShiftedSoftPlus(x):
+ return torch.nn.functional.softplus(x) - math.log(2.0)
+
+class ScalarMLPFunction(CodeGenMixin, torch.nn.Module):
+ """Module implementing an MLP according to provided options."""
+
+ in_features: int
+ out_features: int
+
+ def __init__(
+ self,
+ mlp_input_dimension: Optional[int],
+ mlp_latent_dimensions: List[int],
+ mlp_output_dimension: Optional[int],
+ mlp_nonlinearity: Optional[str] = "silu",
+ mlp_initialization: str = "normal",
+ mlp_dropout_p: float = 0.0,
+ mlp_batchnorm: bool = False,
+ ):
+ super().__init__()
+ nonlinearity = {
+ None: None,
+ "silu": torch.nn.functional.silu,
+ "ssp": ShiftedSoftPlus,
+ }[mlp_nonlinearity]
+ if nonlinearity is not None:
+ nonlin_const = normalize2mom(nonlinearity).cst
+ else:
+ nonlin_const = 1.0
+
+ dimensions = (
+ ([mlp_input_dimension] if mlp_input_dimension is not None else [])
+ + mlp_latent_dimensions
+ + ([mlp_output_dimension] if mlp_output_dimension is not None else [])
+ )
+ assert len(dimensions) >= 2 # Must have input and output dim
+ num_layers = len(dimensions) - 1
+
+ self.in_features = dimensions[0]
+ self.out_features = dimensions[-1]
+
+ # Code
+ params = {}
+ graph = fx.Graph()
+ tracer = fx.proxy.GraphAppendingTracer(graph)
+
+ def Proxy(n):
+ return fx.Proxy(n, tracer=tracer)
+
+ features = Proxy(graph.placeholder("x"))
+ norm_from_last: float = 1.0
+
+ base = torch.nn.Module()
+
+ for layer, (h_in, h_out) in enumerate(zip(dimensions, dimensions[1:])):
+ # do dropout
+ if mlp_dropout_p > 0:
+ # only dropout if it will do something
+ # dropout before linear projection- https://stats.stackexchange.com/a/245137
+ features = Proxy(graph.call_module("_dropout", (features.node,)))
+
+ # make weights
+ w = torch.empty(h_in, h_out)
+
+ if mlp_initialization == "normal":
+ w.normal_()
+ elif mlp_initialization == "uniform":
+ # these values give < x^2 > = 1
+ w.uniform_(-math.sqrt(3), math.sqrt(3))
+ elif mlp_initialization == "orthogonal":
+ # this rescaling gives < x^2 > = 1
+ torch.nn.init.orthogonal_(w, gain=math.sqrt(max(w.shape)))
+ else:
+ raise NotImplementedError(
+ f"Invalid mlp_initialization {mlp_initialization}"
+ )
+
+ # generate code
+ params[f"_weight_{layer}"] = w
+ w = Proxy(graph.get_attr(f"_weight_{layer}"))
+ w = w * (
+ norm_from_last / math.sqrt(float(h_in))
+ ) # include any nonlinearity normalization from previous layers
+ features = torch.matmul(features, w)
+
+ if mlp_batchnorm:
+ # if we call batchnorm, do it after the nonlinearity
+ features = Proxy(graph.call_module(f"_bn_{layer}", (features.node,)))
+ setattr(base, f"_bn_{layer}", torch.nn.BatchNorm1d(h_out))
+
+ # generate nonlinearity code
+ if nonlinearity is not None and layer < num_layers - 1:
+ features = nonlinearity(features)
+ # add the normalization const in next layer
+ norm_from_last = nonlin_const
+
+ graph.output(features.node)
+
+ for pname, p in params.items():
+ setattr(base, pname, torch.nn.Parameter(p))
+
+ if mlp_dropout_p > 0:
+ # with normal dropout everything blows up
+ base._dropout = torch.nn.AlphaDropout(p=mlp_dropout_p)
+
+ self._codegen_register({"_forward": fx.GraphModule(base, graph)})
+
+ def forward(self, x):
+ return self._forward(x)
+
+class InitLayer(torch.nn.Module):
+ def __init__(
+ self,
+ # required params
+ idp,
+ num_types: int,
+ n_radial_basis: int,
+ r_max: float,
+ irreps_sh: o3.Irreps=None,
+ env_embed_multiplicity: int = 32,
+ # MLP parameters:
+ two_body_latent_kwargs={
+ "mlp_latent_dimensions": [128, 256, 512, 1024],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ env_embed_kwargs = {
+ "mlp_latent_dimensions": [],
+ "mlp_nonlinearity": None,
+ "mlp_initialization": "uniform"
+ },
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ cutoff_type: str = "polynomial",
+ device: Union[str, torch.device] = torch.device("cpu"),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ ):
+ super(InitLayer, self).__init__()
+ SCALAR = o3.Irrep("0e")
+ self.num_types = num_types
+ if isinstance(r_max, float) or isinstance(r_max, int):
+ self.r_max = torch.tensor(r_max, device=device, dtype=dtype)
+ self.r_max_dict = None
+ elif isinstance(r_max, dict):
+ c_set = set(list(r_max.values()))
+ self.r_max = torch.tensor(max(list(r_max.values())), device=device, dtype=dtype)
+ if len(r_max) == 1 or len(c_set) == 1:
+ self.r_max_dict = None
+ else:
+ self.r_max_dict = {}
+ for k,v in r_max.items():
+ self.r_max_dict[k] = torch.tensor(v, device=device, dtype=dtype)
+ else:
+ raise TypeError("r_max should be either float, int or dict")
+
+ self.idp = idp
+ self.two_body_latent_kwargs = two_body_latent_kwargs
+ self.r_start_cos_ratio = r_start_cos_ratio
+ self.polynomial_cutoff_p = PolynomialCutoff_p
+ self.cutoff_type = cutoff_type
+ self.device = device
+ self.dtype = dtype
+ self.irreps_out = o3.Irreps([(env_embed_multiplicity, ir) for _, ir in irreps_sh])
+
+ assert all(mul==1 for mul, _ in irreps_sh)
+ # env_embed_irreps = o3.Irreps([(1, ir) for _, ir in irreps_sh])
+ assert (
+ irreps_sh[0].ir == SCALAR
+ ), "env_embed_irreps must start with scalars"
+
+ # Node invariants for center and neighbor (chemistry)
+ # Plus edge invariants for the edge (radius).
+ self.two_body_latent = ScalarMLPFunction(
+ mlp_input_dimension=(2 * num_types + n_radial_basis),
+ mlp_output_dimension=None,
+ **two_body_latent_kwargs,
+ )
+
+ self._env_weighter = Linear(
+ irreps_in=irreps_sh,
+ irreps_out=self.irreps_out,
+ internal_weights=False,
+ shared_weights=False,
+ path_normalization = "element", # if path normalization is element and input irreps has 1 mul, it should not have effect !
+ )
+
+ # self.bn = BatchNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # instance=False,
+ # normalization="component",
+ # )
+
+ self.env_embed_mlp = ScalarMLPFunction(
+ mlp_input_dimension=self.two_body_latent.out_features,
+ mlp_output_dimension=self._env_weighter.weight_numel,
+ **env_embed_kwargs,
+ )
+
+ self.bessel = BesselBasis(r_max=self.r_max, num_basis=n_radial_basis, trainable=True)
+
+
+
+ def forward(self, edge_index, bond_type, edge_sh, edge_length, node_one_hot):
+ edge_center = edge_index[0]
+ edge_neighbor = edge_index[1]
+
+ edge_invariants = self.bessel(edge_length)
+ node_invariants = node_one_hot
+
+ # Vectorized precompute per layer cutoffs
+ if self.r_max_dict is None:
+ if self.cutoff_type == "cosine":
+ cutoff_coeffs = cosine_cutoff(
+ edge_length,
+ self.r_max.reshape(-1),
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ ).flatten()
+
+ elif self.cutoff_type == "polynomial":
+ cutoff_coeffs = polynomial_cutoff(
+ edge_length, self.r_max.reshape(-1), p=self.polynomial_cutoff_p
+ ).flatten()
+
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+ else:
+ cutoff_coeffs = torch.zeros(edge_index.shape[1], dtype=self.dtype, device=self.device)
+
+ for bond, ty in self.idp.bond_to_type.items():
+ mask = bond_type == ty
+ index = mask.nonzero().squeeze(-1)
+
+ if mask.any():
+ iatom, jatom = bond.split("-")
+ if self.cutoff_type == "cosine":
+ c_coeff = cosine_cutoff(
+ edge_length[mask],
+ 0.5*(self.r_max_dict[iatom]+self.r_max_dict[jatom]),
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ ).flatten()
+ elif self.cutoff_type == "polynomial":
+ c_coeff = polynomial_cutoff(
+ edge_length[mask],
+ 0.5*(self.r_max_dict[iatom]+self.r_max_dict[jatom]),
+ p=self.polynomial_cutoff_p
+ ).flatten()
+
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+
+ cutoff_coeffs = torch.index_copy(cutoff_coeffs, 0, index, c_coeff)
+
+ # Determine which edges are still in play
+ prev_mask = cutoff_coeffs > 0
+ active_edges = (cutoff_coeffs > 0).nonzero().squeeze(-1)
+
+ # Compute latents
+ latents = torch.zeros(
+ (edge_sh.shape[0], self.two_body_latent.out_features),
+ dtype=edge_sh.dtype,
+ device=edge_sh.device,
+ )
+
+ new_latents = self.two_body_latent(torch.cat([
+ node_invariants[edge_center],
+ node_invariants[edge_neighbor],
+ edge_invariants,
+ ], dim=-1)[prev_mask])
+
+ # Apply cutoff, which propagates through to everything else
+ new_latents = cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ latents = torch.index_copy(latents, 0, active_edges, new_latents)
+ weights = self.env_embed_mlp(latents[active_edges])
+
+ # embed initial edge
+ features = self._env_weighter(
+ edge_sh[prev_mask], weights
+ ) # features is edge_attr
+ # features = self.bn(features)
+
+ return latents, features, cutoff_coeffs, active_edges # the radial embedding x and the sperical hidden V
+
+class Layer(torch.nn.Module):
+ def __init__(
+ self,
+ # required params
+ num_types: int,
+ avg_num_neighbors: Optional[float] = None,
+ irreps_sh: o3.Irreps=None,
+ irreps_in: o3.Irreps=None,
+ irreps_out: o3.Irreps=None,
+ # general hyperparameters:
+ linear_after_env_embed: bool = False,
+ env_embed_multiplicity: int = 32,
+ # MLP parameters:
+ latent_kwargs={
+ "mlp_latent_dimensions": [128, 256, 512, 1024],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ latent_in: int=1024,
+ latent_resnet: bool = True,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ ):
+ super().__init__()
+ SCALAR = o3.Irrep("0e")
+ self.latent_resnet = latent_resnet
+ self.avg_num_neighbors = avg_num_neighbors
+ self.linear_after_env_embed = linear_after_env_embed
+ self.irreps_in = irreps_in
+ self.irreps_out = irreps_out
+
+ assert all(mul==1 for mul, _ in irreps_sh)
+
+ # for normalization of env embed sums
+ # one per layer
+ self.register_buffer(
+ "env_sum_normalizations",
+ # dividing by sqrt(N)
+ torch.as_tensor(avg_num_neighbors).rsqrt(),
+ )
+
+ latent = functools.partial(ScalarMLPFunction, **latent_kwargs)
+
+ self.latents = None
+ self.env_embed_mlps = None
+ self.tps = None
+ self.linears = None
+ self.env_linears = None
+
+ # Prune impossible paths
+ self.irreps_out = o3.Irreps(
+ [
+ (mul, ir)
+ for mul, ir in self.irreps_out
+ if tp_path_exists(irreps_sh, irreps_in, ir)
+ ]
+ )
+
+ mul_irreps_sh = o3.Irreps([(env_embed_multiplicity, ir) for _, ir in irreps_sh])
+ self._env_weighter = Linear(
+ irreps_in=irreps_sh,
+ irreps_out=mul_irreps_sh,
+ internal_weights=False,
+ shared_weights=False,
+ path_normalization = "element",
+ )
+
+ # == Remove unneeded paths ==
+ #TODO: add the remove unseen paths
+
+ if self.linear_after_env_embed:
+ self.env_linears = Linear(
+ mul_irreps_sh,
+ mul_irreps_sh,
+ shared_weights=True,
+ internal_weights=True,
+ )
+ else:
+ self.env_linears = torch.nn.Identity()
+
+ # # Make TP
+ # tmp_i_out: int = 0
+ # instr = []
+ # n_scalar_outs: int = 0
+ # n_scalar_mul = []
+ # full_out_irreps = []
+ # for i_out, (mul_out, ir_out) in enumerate(self.irreps_out):
+ # for i_1, (mul1, ir_1) in enumerate(self.irreps_in): # what if feature_irreps_in has mul?
+ # for i_2, (mul2, ir_2) in enumerate(self._env_weighter.irreps_out):
+ # if ir_out in ir_1 * ir_2:
+ # if ir_out == SCALAR:
+ # n_scalar_outs += 1
+ # n_scalar_mul.append(mul2)
+ # # assert mul_out == mul1 == mul2
+ # instr.append((i_1, i_2, tmp_i_out, 'uvv', True))
+ # full_out_irreps.append((mul2, ir_out))
+ # assert full_out_irreps[-1][0] == mul2
+ # tmp_i_out += 1
+ # full_out_irreps = o3.Irreps(full_out_irreps)
+ # assert all(ir == SCALAR for _, ir in full_out_irreps[:n_scalar_outs])
+ # self.n_scalar_mul = sum(n_scalar_mul)
+
+ self.lin_pre = Linear(
+ irreps_in=self.irreps_in,
+ irreps_out=self.irreps_in,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # self.tp = TensorProduct(
+ # irreps_in1=o3.Irreps(
+ # [(mul, ir) for mul, ir in self.irreps_in]
+ # ),
+ # irreps_in2=o3.Irreps(
+ # [(mul, ir) for mul, ir in self._env_weighter.irreps_out]
+ # ),
+ # irreps_out=o3.Irreps(
+ # [(mul, ir) for mul, ir in full_out_irreps]
+ # ),
+ # irrep_normalization="component",
+ # instructions=instr,
+ # shared_weights=True,
+ # internal_weights=True,
+ # )
+ # build activation
+
+ irreps_scalar = o3.Irreps(str(self.irreps_out[0]))
+ irreps_gated = o3.Irreps([(mul, ir) for mul, ir in self.irreps_out if ir.l > 0]).simplify()
+ irreps_gates = o3.Irreps([(mul, (0,1)) for mul, _ in irreps_gated]).simplify()
+ act={1: torch.nn.functional.silu, -1: torch.tanh}
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+
+ self.activation = Gate(
+ irreps_scalar, [act[ir.p] for _, ir in irreps_scalar], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ self.tp = SeparateWeightTensorProduct(
+ irreps_in1=self._env_weighter.irreps_out+self.irreps_in+self._env_weighter.irreps_out,
+ irreps_in2=irreps_sh,
+ irreps_out=self.activation.irreps_in,
+ )
+
+ # self.sc = FullyConnectedTensorProduct(
+ # irreps_in,
+ # o3.Irreps(str(2*num_types)+"x0e"),
+ # self.irreps_out,
+ # shared_weights=True,
+ # internal_weights=True
+ # )
+
+ self.lin_post = Linear(
+ self.irreps_out,
+ self.irreps_out,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ self.bn = BatchNorm(
+ irreps=self.irreps_out,
+ affine=True,
+ instance=False,
+ normalization="component",
+ )
+
+ self.linear_res = Linear(
+ self.irreps_in,
+ self.irreps_out,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # we extract the scalars from the first irrep of the tp
+ # assert full_out_irreps[0].ir == SCALAR
+ # self.linears = Linear(
+ # irreps_in=full_out_irreps,
+ # irreps_out=self.activation.irreps_in,
+ # shared_weights=True,
+ # internal_weights=True,
+ # biases=True,
+ # )
+
+ # the embedded latent invariants from the previous layer(s)
+ # and the invariants extracted from the last layer's TP:
+ self.latents = latent(
+ mlp_input_dimension=latent_in+self.irreps_out[0].dim,
+ mlp_output_dimension=None,
+ )
+
+ # the env embed MLP takes the last latent's output as input
+ # and outputs enough weights for the env embedder
+ self.env_embed_mlps = ScalarMLPFunction(
+ mlp_input_dimension=latent_in,
+ mlp_latent_dimensions=[],
+ mlp_output_dimension=self._env_weighter.weight_numel,
+ )
+ # - layer resnet update weights -
+ if latent_resnet_update_ratios is None:
+ # We initialize to zeros, which under the sigmoid() become 0.5
+ # so 1/2 * layer_1 + 1/4 * layer_2 + ...
+ # note that the sigmoid of these are the factor _between_ layers
+ # so the first entry is the ratio for the latent resnet of the first and second layers, etc.
+ # e.g. if there are 3 layers, there are 2 ratios: l1:l2, l2:l3
+ latent_resnet_update_params = torch.zeros(1)
+ else:
+ latent_resnet_update_ratios = torch.as_tensor(
+ latent_resnet_update_ratios, dtype=torch.get_default_dtype()
+ )
+ assert latent_resnet_update_ratios > 0.0
+ assert latent_resnet_update_ratios < 1.0
+ latent_resnet_update_params = torch.special.logit(
+ latent_resnet_update_ratios
+ )
+ # The sigmoid is mostly saturated at ±6, keep it in a reasonable range
+ latent_resnet_update_params.clamp_(-6.0, 6.0)
+
+ if latent_resnet_update_ratios_learnable:
+ self._latent_resnet_update_params = torch.nn.Parameter(
+ latent_resnet_update_params
+ )
+ else:
+ self.register_buffer(
+ "_latent_resnet_update_params", latent_resnet_update_params
+ )
+
+ def forward(self, edge_index, edge_sh, atom_type, latents, features, cutoff_coeffs, active_edges):
+ # update V
+ # update X
+ # edge_index: [2, num_edges]
+ # irreps_sh: [num_edges, irreps_sh]
+ # latents: [num_edges, latent_in]
+ # fetures: [num_active_edges, in_irreps]
+ # cutoff_coeffs: [num_edges]
+ # active_edges: [num_active_edges]
+
+ edge_center = edge_index[0]
+ edge_neighbor = edge_index[1]
+
+ prev_mask = cutoff_coeffs > 0
+
+ # sc_features = self.sc(features, node_one_hot[edge_index].transpose(0,1).flatten(1,2)[active_edges])
+ # update V
+ weights = self.env_embed_mlps(latents[active_edges])
+
+ # Build the local environments
+ # This local environment should only be a sum over neighbors
+ # who are within the cutoff of the _current_ layer
+ # Those are the active edges, which are the only ones we
+ # have weights for (env_w) anyway.
+ # So we mask out the edges in the sum:
+ local_env_per_edge = scatter(
+ self._env_weighter(edge_sh[active_edges], weights),
+ edge_center[active_edges],
+ dim=0,
+ )
+
+ # currently, we have a sum over neighbors of constant number for each layer,
+ # the env_sum_normalization can be a scalar or list
+ # the different cutoff can be added in the future
+
+ if self.env_sum_normalizations.ndim < 1:
+ norm_const = self.env_sum_normalizations
+ else:
+ norm_const = self.env_sum_normalizations[atom_type.flatten()].unsqueeze(-1)
+
+ local_env_per_edge = local_env_per_edge * norm_const
+ local_env_per_edge = self.env_linears(local_env_per_edge)
+
+ # local_env_per_edge = torch.cat([local_env_per_edge[edge_center[active_edges]], local_env_per_edge[edge_neighbor[active_edges]]], dim=-1)
+ # local_env_per_edge = local_env_per_edge[edge_center[active_edges]]
+ # Now do the TP
+ # recursively tp current features with the environment embeddings
+ new_features = self.tp(
+ torch.cat(
+ [
+ local_env_per_edge[edge_center[active_edges]],
+ self.lin_pre(features),
+ local_env_per_edge[edge_neighbor[active_edges]]
+ ], dim=-1),
+ edge_sh[active_edges]) # full_out_irreps
+
+ new_features = self.activation(new_features)
+ # # do the linear
+ # new_features = self.linears(new_features)
+
+
+ # features has shape [N_edge, full_feature_out.dim]
+ # we know scalars are first
+ scalars = new_features[:, :self.irreps_out[0].dim]
+ assert len(scalars.shape) == 2
+
+ new_features = self.lin_post(new_features)
+
+ new_features = self.bn(new_features)
+
+ if self.latent_resnet:
+ update_coefficients = self._latent_resnet_update_params.sigmoid()
+ coefficient_old = torch.rsqrt(update_coefficients.square() + 1)
+ coefficient_new = update_coefficients * coefficient_old
+ features = coefficient_new * new_features + coefficient_old * self.linear_res(features)
+ else:
+ features = new_features
+
+ # update X
+ latent_inputs_to_cat = [
+ latents[active_edges],
+ scalars,
+ ]
+
+ new_latents = self.latents(torch.cat(latent_inputs_to_cat, dim=-1))
+ new_latents = cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ # At init, we assume new and old to be approximately uncorrelated
+ # Thus their variances add
+ # we always want the latent space to be normalized to variance = 1.0,
+ # because it is critical for learnability. Still, we want to preserve
+ # the _relative_ magnitudes of the current latent and the residual update
+ # to be controled by `this_layer_update_coeff`
+ # Solving the simple system for the two coefficients:
+ # a^2 + b^2 = 1 (variances add) & a * this_layer_update_coeff = b
+ # gives
+ # a = 1 / sqrt(1 + this_layer_update_coeff^2) & b = this_layer_update_coeff / sqrt(1 + this_layer_update_coeff^2)
+ # rsqrt is reciprocal sqrt
+ if self.latent_resnet:
+ update_coefficients = self._latent_resnet_update_params.sigmoid()
+ coefficient_old = torch.rsqrt(update_coefficients.square() + 1)
+ coefficient_new = update_coefficients * coefficient_old
+ latents = torch.index_add(
+ coefficient_old * latents,
+ 0,
+ active_edges,
+ coefficient_new * new_latents,
+ )
+ else:
+ latents = torch.index_copy(latents, 0, active_edges, new_latents)
+
+ return latents, features, cutoff_coeffs, active_edges
+
\ No newline at end of file
diff --git a/dptb/nn/embedding/e3baseline_nonlocal_wnode.py b/dptb/nn/embedding/e3baseline_nonlocal_wnode.py
new file mode 100644
index 00000000..ab45b77a
--- /dev/null
+++ b/dptb/nn/embedding/e3baseline_nonlocal_wnode.py
@@ -0,0 +1,934 @@
+from typing import Optional, List, Union, Dict
+import math
+import functools
+import warnings
+
+import torch
+from torch_runstats.scatter import scatter
+
+from torch import fx
+from e3nn.util.codegen import CodeGenMixin
+from e3nn import o3
+from e3nn.nn import Gate, Activation
+from e3nn.nn._batchnorm import BatchNorm
+from e3nn.o3 import TensorProduct, Linear, SphericalHarmonics, FullyConnectedTensorProduct
+from e3nn.math import normalize2mom
+from e3nn.util.jit import compile_mode
+
+from dptb.data import AtomicDataDict
+from dptb.nn.embedding.emb import Embedding
+from ..radial_basis import BesselBasis
+from dptb.nn.graph_mixin import GraphModuleMixin
+from dptb.nn.embedding.from_deephe3.deephe3 import tp_path_exists
+from dptb.nn.embedding.from_deephe3.e3module import SeparateWeightTensorProduct
+from dptb.data import _keys
+from dptb.nn.cutoff import cosine_cutoff, polynomial_cutoff
+import math
+from dptb.data.transforms import OrbitalMapper
+from ..type_encode.one_hot import OneHotAtomEncoding
+from dptb.data.AtomicDataDict import with_edge_vectors, with_env_vectors, with_batch
+
+from math import ceil
+
+@Embedding.register("e3baseline_nonlocal_wnode")
+class E3BaseLineModelNonLocalWNODE(torch.nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ # required params
+ n_atom: int=1,
+ n_layers: int=3,
+ n_radial_basis: int=10,
+ r_max: float=5.0,
+ lmax: int=4,
+ irreps_hidden: o3.Irreps=None,
+ avg_num_neighbors: Optional[float] = None,
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ cutoff_type: str = "polynomial",
+ # general hyperparameters:
+ linear_after_env_embed: bool = False,
+ env_embed_multiplicity: int = 32,
+ sh_normalized: bool = True,
+ sh_normalization: str = "component",
+ # MLP parameters:
+ latent_kwargs={
+ "mlp_latent_dimensions": [256, 256, 512],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ latent_resnet: bool = True,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ):
+
+ super(E3BaseLineModelNonLocalWNODE, self).__init__()
+
+ irreps_hidden = o3.Irreps(irreps_hidden)
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ self.dtype = dtype
+ self.device = device
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb")
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ self.basis = self.idp.basis
+ self.idp.get_irreps(no_parity=False)
+
+ irreps_sh=o3.Irreps([(1, (i, (-1) ** i)) for i in range(lmax + 1)])
+ orbpair_irreps = self.idp.orbpair_irreps.sort()[0].simplify()
+
+ # check if the irreps setting satisfied the requirement of idp
+ irreps_out = []
+ for mul, ir1 in irreps_hidden:
+ for _, ir2 in orbpair_irreps:
+ irreps_out += [o3.Irrep(str(irr)) for irr in ir1*ir2]
+ irreps_out = o3.Irreps(irreps_out).sort()[0].simplify()
+
+ assert all(ir in irreps_out for _, ir in orbpair_irreps), "hidden irreps should at least cover all the reqired irreps in the hamiltonian data {}".format(orbpair_irreps)
+
+ # TODO: check if the tp in first layer can produce the required irreps for hidden states
+
+ self.sh = SphericalHarmonics(
+ irreps_sh, sh_normalized, sh_normalization
+ )
+ self.onehot = OneHotAtomEncoding(num_types=n_atom, set_features=False)
+
+ self.init_layer = InitLayer(
+ idp=self.idp,
+ num_types=n_atom,
+ n_radial_basis=n_radial_basis,
+ r_max=r_max,
+ irreps_sh=irreps_sh,
+ env_embed_multiplicity=env_embed_multiplicity,
+ # MLP parameters:
+ two_body_latent_kwargs=latent_kwargs,
+ env_embed_kwargs = {
+ "mlp_latent_dimensions": [],
+ "mlp_nonlinearity": None,
+ "mlp_initialization": "uniform"
+ },
+ # cutoffs
+ r_start_cos_ratio=r_start_cos_ratio,
+ PolynomialCutoff_p=PolynomialCutoff_p,
+ cutoff_type=cutoff_type,
+ device=device,
+ dtype=dtype,
+ )
+
+ self.layers = torch.nn.ModuleList()
+ latent_in =latent_kwargs["mlp_latent_dimensions"][-1]
+ # actually, we can derive the least required irreps_in and out from the idp's node and pair irreps
+ for i in range(n_layers):
+ if i == 0:
+ irreps_in = self.init_layer.irreps_out
+ else:
+ irreps_in = irreps_hidden
+
+ if i == n_layers - 1:
+ irreps_out = orbpair_irreps.sort()[0].simplify()
+ else:
+ irreps_out = irreps_hidden
+
+ self.layers.append(Layer(
+ num_types=n_atom,
+ avg_num_neighbors=avg_num_neighbors,
+ irreps_sh=irreps_sh,
+ irreps_in=irreps_in,
+ irreps_out=irreps_out,
+ # general hyperparameters:
+ linear_after_env_embed=linear_after_env_embed,
+ env_embed_multiplicity=env_embed_multiplicity,
+ # MLP parameters:
+ latent_kwargs=latent_kwargs,
+ latent_in=latent_in,
+ latent_resnet=latent_resnet,
+ latent_resnet_update_ratios=latent_resnet_update_ratios,
+ latent_resnet_update_ratios_learnable=latent_resnet_update_ratios_learnable,
+ )
+ )
+
+ # initilize output_layer
+ self.out_edge = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=True)
+ self.out_node = Linear(self.layers[-1].irreps_out, self.idp.orbpair_irreps, shared_weights=True, internal_weights=True, biases=True)
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ data = with_edge_vectors(data, with_lengths=True)
+ # data = with_env_vectors(data, with_lengths=True)
+ data = with_batch(data)
+
+ edge_index = data[_keys.EDGE_INDEX_KEY]
+ edge_sh = self.sh(data[_keys.EDGE_VECTORS_KEY][:,[1,2,0]])
+ edge_length = data[_keys.EDGE_LENGTH_KEY]
+
+
+ data = self.onehot(data)
+ node_one_hot = data[_keys.NODE_ATTRS_KEY]
+ atom_type = data[_keys.ATOM_TYPE_KEY].flatten()
+ bond_type = data[_keys.EDGE_TYPE_KEY].flatten()
+ latents, features, cutoff_coeffs, active_edges = self.init_layer(edge_index, bond_type, edge_sh, edge_length, node_one_hot)
+
+ for layer in self.layers:
+ latents, features, cutoff_coeffs, active_edges = layer(edge_index, edge_sh, atom_type, latents, features, cutoff_coeffs, active_edges)
+
+ if self.layers[-1].env_sum_normalizations.ndim < 1:
+ norm_const = self.layers[-1].env_sum_normalizations
+ else:
+ norm_const = self.layers[-1].env_sum_normalizations[atom_type.flatten()].unsqueeze(-1)
+
+ data[_keys.EDGE_FEATURES_KEY] = torch.zeros(edge_index.shape[1], self.idp.orbpair_irreps.dim, dtype=self.dtype, device=self.device)
+ data[_keys.EDGE_FEATURES_KEY] = torch.index_copy(data[_keys.EDGE_FEATURES_KEY], 0, active_edges, self.out_edge(features))
+ node_features = scatter(features, edge_index[0][active_edges], dim=0)
+ data[_keys.NODE_FEATURES_KEY] = self.out_node(node_features * norm_const)
+
+ return data
+
+def tp_path_exists(irreps_in1, irreps_in2, ir_out):
+ irreps_in1 = o3.Irreps(irreps_in1).simplify()
+ irreps_in2 = o3.Irreps(irreps_in2).simplify()
+ ir_out = o3.Irrep(ir_out)
+
+ for _, ir1 in irreps_in1:
+ for _, ir2 in irreps_in2:
+ if ir_out in ir1 * ir2:
+ return True
+ return False
+
+def get_gate_nonlin(irreps_in1, irreps_in2, irreps_out,
+ act={1: torch.nn.functional.silu, -1: torch.tanh},
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+ ):
+ # get gate nonlinearity after tensor product
+ # irreps_in1 and irreps_in2 are irreps to be multiplied in tensor product
+ # irreps_out is desired irreps after gate nonlin
+ # notice that nonlin.irreps_out might not be exactly equal to irreps_out
+
+ irreps_scalars = o3.Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l == 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ irreps_gated = o3.Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l > 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ if irreps_gated.dim > 0:
+ if tp_path_exists(irreps_in1, irreps_in2, "0e"):
+ ir = "0e"
+ elif tp_path_exists(irreps_in1, irreps_in2, "0o"):
+ ir = "0o"
+ warnings.warn('Using odd representations as gates')
+ else:
+ raise ValueError(
+ f"irreps_in1={irreps_in1} times irreps_in2={irreps_in2} is unable to produce gates needed for irreps_gated={irreps_gated}")
+ else:
+ ir = None
+ irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify()
+
+ gate_nonlin = Gate(
+ irreps_scalars, [act[ir.p] for _, ir in irreps_scalars], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ return gate_nonlin
+
+
+@compile_mode("script")
+class MakeWeightedChannels(torch.nn.Module):
+ weight_numel: int
+ multiplicity_out: Union[int, list]
+ _num_irreps: int
+
+ def __init__(
+ self,
+ irreps_in,
+ multiplicity_out: Union[int, list],
+ pad_to_alignment: int = 1,
+ ):
+ super().__init__()
+ assert all(mul == 1 for mul, _ in irreps_in)
+ assert multiplicity_out >= 1
+ # Each edgewise output multiplicity is a per-irrep weighted sum over the input
+ # So we need to apply the weight for the ith irrep to all DOF in that irrep
+ w_index = []
+ idx = 0
+ self._num_irreps = 0
+ for (mul, ir) in irreps_in:
+ w_index += sum(([ix] * ir.dim for ix in range(idx, idx + mul)), [])
+ idx += mul
+ self._num_irreps += mul
+ # w_index = sum(([i] * ir.dim for i, (mul, ir) in enumerate(irreps_in)), [])
+ # pad to padded length
+ n_pad = (
+ int(ceil(irreps_in.dim / pad_to_alignment)) * pad_to_alignment
+ - irreps_in.dim
+ )
+ # use the last weight, what we use doesn't matter much
+ w_index += [w_index[-1]] * n_pad
+ self.register_buffer("_w_index", torch.as_tensor(w_index, dtype=torch.long))
+ # there is
+ self.multiplicity_out = multiplicity_out
+ self.weight_numel = self._num_irreps * multiplicity_out
+
+ def forward(self, edge_attr, weights):
+ # weights are [z, u, num_i]
+ # edge_attr are [z, i]
+ # i runs over all irreps, which is why the weights need
+ # to be indexed in order to go from [num_i] to [i]
+ return torch.einsum(
+ "zi,zui->zui",
+ edge_attr,
+ weights.view(
+ -1,
+ self.multiplicity_out,
+ self._num_irreps,
+ )[:, :, self._w_index],
+ )
+
+@torch.jit.script
+def ShiftedSoftPlus(x):
+ return torch.nn.functional.softplus(x) - math.log(2.0)
+
+class ScalarMLPFunction(CodeGenMixin, torch.nn.Module):
+ """Module implementing an MLP according to provided options."""
+
+ in_features: int
+ out_features: int
+
+ def __init__(
+ self,
+ mlp_input_dimension: Optional[int],
+ mlp_latent_dimensions: List[int],
+ mlp_output_dimension: Optional[int],
+ mlp_nonlinearity: Optional[str] = "silu",
+ mlp_initialization: str = "normal",
+ mlp_dropout_p: float = 0.0,
+ mlp_batchnorm: bool = False,
+ ):
+ super().__init__()
+ nonlinearity = {
+ None: None,
+ "silu": torch.nn.functional.silu,
+ "ssp": ShiftedSoftPlus,
+ }[mlp_nonlinearity]
+ if nonlinearity is not None:
+ nonlin_const = normalize2mom(nonlinearity).cst
+ else:
+ nonlin_const = 1.0
+
+ dimensions = (
+ ([mlp_input_dimension] if mlp_input_dimension is not None else [])
+ + mlp_latent_dimensions
+ + ([mlp_output_dimension] if mlp_output_dimension is not None else [])
+ )
+ assert len(dimensions) >= 2 # Must have input and output dim
+ num_layers = len(dimensions) - 1
+
+ self.in_features = dimensions[0]
+ self.out_features = dimensions[-1]
+
+ # Code
+ params = {}
+ graph = fx.Graph()
+ tracer = fx.proxy.GraphAppendingTracer(graph)
+
+ def Proxy(n):
+ return fx.Proxy(n, tracer=tracer)
+
+ features = Proxy(graph.placeholder("x"))
+ norm_from_last: float = 1.0
+
+ base = torch.nn.Module()
+
+ for layer, (h_in, h_out) in enumerate(zip(dimensions, dimensions[1:])):
+ # do dropout
+ if mlp_dropout_p > 0:
+ # only dropout if it will do something
+ # dropout before linear projection- https://stats.stackexchange.com/a/245137
+ features = Proxy(graph.call_module("_dropout", (features.node,)))
+
+ # make weights
+ w = torch.empty(h_in, h_out)
+
+ if mlp_initialization == "normal":
+ w.normal_()
+ elif mlp_initialization == "uniform":
+ # these values give < x^2 > = 1
+ w.uniform_(-math.sqrt(3), math.sqrt(3))
+ elif mlp_initialization == "orthogonal":
+ # this rescaling gives < x^2 > = 1
+ torch.nn.init.orthogonal_(w, gain=math.sqrt(max(w.shape)))
+ else:
+ raise NotImplementedError(
+ f"Invalid mlp_initialization {mlp_initialization}"
+ )
+
+ # generate code
+ params[f"_weight_{layer}"] = w
+ w = Proxy(graph.get_attr(f"_weight_{layer}"))
+ w = w * (
+ norm_from_last / math.sqrt(float(h_in))
+ ) # include any nonlinearity normalization from previous layers
+ features = torch.matmul(features, w)
+
+ if mlp_batchnorm:
+ # if we call batchnorm, do it after the nonlinearity
+ features = Proxy(graph.call_module(f"_bn_{layer}", (features.node,)))
+ setattr(base, f"_bn_{layer}", torch.nn.BatchNorm1d(h_out))
+
+ # generate nonlinearity code
+ if nonlinearity is not None and layer < num_layers - 1:
+ features = nonlinearity(features)
+ # add the normalization const in next layer
+ norm_from_last = nonlin_const
+
+ graph.output(features.node)
+
+ for pname, p in params.items():
+ setattr(base, pname, torch.nn.Parameter(p))
+
+ if mlp_dropout_p > 0:
+ # with normal dropout everything blows up
+ base._dropout = torch.nn.AlphaDropout(p=mlp_dropout_p)
+
+ self._codegen_register({"_forward": fx.GraphModule(base, graph)})
+
+ def forward(self, x):
+ return self._forward(x)
+
+class InitLayer(torch.nn.Module):
+ def __init__(
+ self,
+ # required params
+ idp,
+ num_types: int,
+ n_radial_basis: int,
+ r_max: float,
+ irreps_sh: o3.Irreps=None,
+ env_embed_multiplicity: int = 32,
+ # MLP parameters:
+ two_body_latent_kwargs={
+ "mlp_latent_dimensions": [128, 256, 512, 1024],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ env_embed_kwargs = {
+ "mlp_latent_dimensions": [],
+ "mlp_nonlinearity": None,
+ "mlp_initialization": "uniform"
+ },
+ # cutoffs
+ r_start_cos_ratio: float = 0.8,
+ PolynomialCutoff_p: float = 6,
+ cutoff_type: str = "polynomial",
+ device: Union[str, torch.device] = torch.device("cpu"),
+ dtype: Union[str, torch.dtype] = torch.float32,
+ ):
+ super(InitLayer, self).__init__()
+ SCALAR = o3.Irrep("0e")
+ self.num_types = num_types
+ if isinstance(r_max, float) or isinstance(r_max, int):
+ self.r_max = torch.tensor(r_max, device=device, dtype=dtype)
+ self.r_max_dict = None
+ elif isinstance(r_max, dict):
+ c_set = set(list(r_max.values()))
+ self.r_max = torch.tensor(max(list(r_max.values())), device=device, dtype=dtype)
+ if len(r_max) == 1 or len(c_set) == 1:
+ self.r_max_dict = None
+ else:
+ self.r_max_dict = {}
+ for k,v in r_max.items():
+ self.r_max_dict[k] = torch.tensor(v, device=device, dtype=dtype)
+ else:
+ raise TypeError("r_max should be either float, int or dict")
+
+ self.idp = idp
+ self.two_body_latent_kwargs = two_body_latent_kwargs
+ self.r_start_cos_ratio = r_start_cos_ratio
+ self.polynomial_cutoff_p = PolynomialCutoff_p
+ self.cutoff_type = cutoff_type
+ self.device = device
+ self.dtype = dtype
+ self.irreps_out = o3.Irreps([(env_embed_multiplicity, ir) for _, ir in irreps_sh])
+
+ assert all(mul==1 for mul, _ in irreps_sh)
+ # env_embed_irreps = o3.Irreps([(1, ir) for _, ir in irreps_sh])
+ assert (
+ irreps_sh[0].ir == SCALAR
+ ), "env_embed_irreps must start with scalars"
+
+ # Node invariants for center and neighbor (chemistry)
+ # Plus edge invariants for the edge (radius).
+ self.two_body_latent = ScalarMLPFunction(
+ mlp_input_dimension=(2 * num_types + n_radial_basis),
+ mlp_output_dimension=None,
+ **two_body_latent_kwargs,
+ )
+
+ self._env_weighter = Linear(
+ irreps_in=irreps_sh,
+ irreps_out=self.irreps_out,
+ internal_weights=False,
+ shared_weights=False,
+ path_normalization = "element", # if path normalization is element and input irreps has 1 mul, it should not have effect !
+ )
+
+ # self.bn = BatchNorm(
+ # irreps=self.irreps_out,
+ # affine=True,
+ # instance=False,
+ # normalization="component",
+ # )
+
+ self.env_embed_mlp = ScalarMLPFunction(
+ mlp_input_dimension=self.two_body_latent.out_features,
+ mlp_output_dimension=self._env_weighter.weight_numel,
+ **env_embed_kwargs,
+ )
+
+ self.bessel = BesselBasis(r_max=self.r_max, num_basis=n_radial_basis, trainable=True)
+
+
+
+ def forward(self, edge_index, bond_type, edge_sh, edge_length, node_one_hot):
+ edge_center = edge_index[0]
+ edge_neighbor = edge_index[1]
+
+ edge_invariants = self.bessel(edge_length)
+ node_invariants = node_one_hot
+
+ # Vectorized precompute per layer cutoffs
+ if self.r_max_dict is None:
+ if self.cutoff_type == "cosine":
+ cutoff_coeffs = cosine_cutoff(
+ edge_length,
+ self.r_max.reshape(-1),
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ ).flatten()
+
+ elif self.cutoff_type == "polynomial":
+ cutoff_coeffs = polynomial_cutoff(
+ edge_length, self.r_max.reshape(-1), p=self.polynomial_cutoff_p
+ ).flatten()
+
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+ else:
+ cutoff_coeffs = torch.zeros(edge_index.shape[1], dtype=self.dtype, device=self.device)
+
+ for bond, ty in self.idp.bond_to_type.items():
+ mask = bond_type == ty
+ index = mask.nonzero().squeeze(-1)
+
+ if mask.any():
+ iatom, jatom = bond.split("-")
+ if self.cutoff_type == "cosine":
+ c_coeff = cosine_cutoff(
+ edge_length[mask],
+ 0.5*(self.r_max_dict[iatom]+self.r_max_dict[jatom]),
+ r_start_cos_ratio=self.r_start_cos_ratio,
+ ).flatten()
+ elif self.cutoff_type == "polynomial":
+ c_coeff = polynomial_cutoff(
+ edge_length[mask],
+ 0.5*(self.r_max_dict[iatom]+self.r_max_dict[jatom]),
+ p=self.polynomial_cutoff_p
+ ).flatten()
+
+ else:
+ # This branch is unreachable (cutoff type is checked in __init__)
+ # But TorchScript doesn't know that, so we need to make it explicitly
+ # impossible to make it past so it doesn't throw
+ # "cutoff_coeffs_all is not defined in the false branch"
+ assert False, "Invalid cutoff type"
+
+ cutoff_coeffs = torch.index_copy(cutoff_coeffs, 0, index, c_coeff)
+
+ # Determine which edges are still in play
+ prev_mask = cutoff_coeffs > 0
+ active_edges = (cutoff_coeffs > 0).nonzero().squeeze(-1)
+
+ # Compute latents
+ latents = torch.zeros(
+ (edge_sh.shape[0], self.two_body_latent.out_features),
+ dtype=edge_sh.dtype,
+ device=edge_sh.device,
+ )
+
+ new_latents = self.two_body_latent(torch.cat([
+ node_invariants[edge_center],
+ node_invariants[edge_neighbor],
+ edge_invariants,
+ ], dim=-1)[prev_mask])
+
+ # Apply cutoff, which propagates through to everything else
+ new_latents = cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ latents = torch.index_copy(latents, 0, active_edges, new_latents)
+ weights = self.env_embed_mlp(latents[active_edges])
+
+ # embed initial edge
+ features = self._env_weighter(
+ edge_sh[prev_mask], weights
+ ) # features is edge_attr
+ # features = self.bn(features)
+
+ return latents, features, cutoff_coeffs, active_edges # the radial embedding x and the sperical hidden V
+
+class Layer(torch.nn.Module):
+ def __init__(
+ self,
+ # required params
+ num_types: int,
+ avg_num_neighbors: Optional[float] = None,
+ irreps_sh: o3.Irreps=None,
+ irreps_in: o3.Irreps=None,
+ irreps_out: o3.Irreps=None,
+ # general hyperparameters:
+ linear_after_env_embed: bool = False,
+ env_embed_multiplicity: int = 32,
+ # MLP parameters:
+ latent_kwargs={
+ "mlp_latent_dimensions": [128, 256, 512, 1024],
+ "mlp_nonlinearity": "silu",
+ "mlp_initialization": "uniform"
+ },
+ latent_in: int=1024,
+ latent_resnet: bool = True,
+ latent_resnet_update_ratios: Optional[List[float]] = None,
+ latent_resnet_update_ratios_learnable: bool = False,
+ ):
+ super().__init__()
+ SCALAR = o3.Irrep("0e")
+ self.latent_resnet = latent_resnet
+ self.avg_num_neighbors = avg_num_neighbors
+ self.linear_after_env_embed = linear_after_env_embed
+ self.irreps_in = irreps_in
+ self.irreps_out = irreps_out
+
+ assert all(mul==1 for mul, _ in irreps_sh)
+
+ # for normalization of env embed sums
+ # one per layer
+ self.register_buffer(
+ "env_sum_normalizations",
+ # dividing by sqrt(N)
+ torch.as_tensor(avg_num_neighbors).rsqrt(),
+ )
+
+ latent = functools.partial(ScalarMLPFunction, **latent_kwargs)
+
+ self.latents = None
+ self.env_embed_mlps = None
+ self.tps = None
+ self.linears = None
+ self.env_linears = None
+
+ # Prune impossible paths
+ self.irreps_out = o3.Irreps(
+ [
+ (mul, ir)
+ for mul, ir in self.irreps_out
+ if tp_path_exists(irreps_sh, irreps_in, ir)
+ ]
+ )
+
+ mul_irreps_sh = o3.Irreps([(env_embed_multiplicity, ir) for _, ir in irreps_sh])
+ self._env_weighter = Linear(
+ irreps_in=irreps_sh,
+ irreps_out=mul_irreps_sh,
+ internal_weights=False,
+ shared_weights=False,
+ path_normalization = "element",
+ )
+
+ # == Remove unneeded paths ==
+ #TODO: add the remove unseen paths
+
+ if self.linear_after_env_embed:
+ self.env_linears = Linear(
+ mul_irreps_sh,
+ mul_irreps_sh,
+ shared_weights=True,
+ internal_weights=True,
+ )
+ else:
+ self.env_linears = torch.nn.Identity()
+
+ # # Make TP
+ # tmp_i_out: int = 0
+ # instr = []
+ # n_scalar_outs: int = 0
+ # n_scalar_mul = []
+ # full_out_irreps = []
+ # for i_out, (mul_out, ir_out) in enumerate(self.irreps_out):
+ # for i_1, (mul1, ir_1) in enumerate(self.irreps_in): # what if feature_irreps_in has mul?
+ # for i_2, (mul2, ir_2) in enumerate(self._env_weighter.irreps_out):
+ # if ir_out in ir_1 * ir_2:
+ # if ir_out == SCALAR:
+ # n_scalar_outs += 1
+ # n_scalar_mul.append(mul2)
+ # # assert mul_out == mul1 == mul2
+ # instr.append((i_1, i_2, tmp_i_out, 'uvv', True))
+ # full_out_irreps.append((mul2, ir_out))
+ # assert full_out_irreps[-1][0] == mul2
+ # tmp_i_out += 1
+ # full_out_irreps = o3.Irreps(full_out_irreps)
+ # assert all(ir == SCALAR for _, ir in full_out_irreps[:n_scalar_outs])
+ # self.n_scalar_mul = sum(n_scalar_mul)
+
+ self.lin_pre = Linear(
+ irreps_in=self.irreps_in,
+ irreps_out=self.irreps_in,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # self.tp = TensorProduct(
+ # irreps_in1=o3.Irreps(
+ # [(mul, ir) for mul, ir in self.irreps_in]
+ # ),
+ # irreps_in2=o3.Irreps(
+ # [(mul, ir) for mul, ir in self._env_weighter.irreps_out]
+ # ),
+ # irreps_out=o3.Irreps(
+ # [(mul, ir) for mul, ir in full_out_irreps]
+ # ),
+ # irrep_normalization="component",
+ # instructions=instr,
+ # shared_weights=True,
+ # internal_weights=True,
+ # )
+ # build activation
+
+ irreps_scalar = o3.Irreps(str(self.irreps_out[0]))
+ irreps_gated = o3.Irreps([(mul, ir) for mul, ir in self.irreps_out if ir.l > 0]).simplify()
+ irreps_gates = o3.Irreps([(mul, (0,1)) for mul, _ in irreps_gated]).simplify()
+ act={1: torch.nn.functional.silu, -1: torch.tanh}
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+
+ self.activation = Gate(
+ irreps_scalar, [act[ir.p] for _, ir in irreps_scalar], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ self.tp = SeparateWeightTensorProduct(
+ irreps_in1=self._env_weighter.irreps_out+self.irreps_in+self._env_weighter.irreps_out,
+ irreps_in2=irreps_sh,
+ irreps_out=self.activation.irreps_in,
+ )
+
+ # self.sc = FullyConnectedTensorProduct(
+ # irreps_in,
+ # o3.Irreps(str(2*num_types)+"x0e"),
+ # self.irreps_out,
+ # shared_weights=True,
+ # internal_weights=True
+ # )
+
+ self.lin_post = Linear(
+ self.irreps_out,
+ self.irreps_out,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ self.bn = BatchNorm(
+ irreps=self.irreps_out,
+ affine=True,
+ instance=False,
+ normalization="component",
+ )
+
+ self.linear_res = Linear(
+ self.irreps_in,
+ self.irreps_out,
+ shared_weights=True,
+ internal_weights=True,
+ biases=True,
+ )
+
+ # we extract the scalars from the first irrep of the tp
+ # assert full_out_irreps[0].ir == SCALAR
+ # self.linears = Linear(
+ # irreps_in=full_out_irreps,
+ # irreps_out=self.activation.irreps_in,
+ # shared_weights=True,
+ # internal_weights=True,
+ # biases=True,
+ # )
+
+ # the embedded latent invariants from the previous layer(s)
+ # and the invariants extracted from the last layer's TP:
+ self.latents = latent(
+ mlp_input_dimension=latent_in+self.irreps_out[0].dim,
+ mlp_output_dimension=None,
+ )
+
+ # the env embed MLP takes the last latent's output as input
+ # and outputs enough weights for the env embedder
+ self.env_embed_mlps = ScalarMLPFunction(
+ mlp_input_dimension=latent_in,
+ mlp_latent_dimensions=[],
+ mlp_output_dimension=self._env_weighter.weight_numel,
+ )
+ # - layer resnet update weights -
+ if latent_resnet_update_ratios is None:
+ # We initialize to zeros, which under the sigmoid() become 0.5
+ # so 1/2 * layer_1 + 1/4 * layer_2 + ...
+ # note that the sigmoid of these are the factor _between_ layers
+ # so the first entry is the ratio for the latent resnet of the first and second layers, etc.
+ # e.g. if there are 3 layers, there are 2 ratios: l1:l2, l2:l3
+ latent_resnet_update_params = torch.zeros(1)
+ else:
+ latent_resnet_update_ratios = torch.as_tensor(
+ latent_resnet_update_ratios, dtype=torch.get_default_dtype()
+ )
+ assert latent_resnet_update_ratios > 0.0
+ assert latent_resnet_update_ratios < 1.0
+ latent_resnet_update_params = torch.special.logit(
+ latent_resnet_update_ratios
+ )
+ # The sigmoid is mostly saturated at ±6, keep it in a reasonable range
+ latent_resnet_update_params.clamp_(-6.0, 6.0)
+
+ if latent_resnet_update_ratios_learnable:
+ self._latent_resnet_update_params = torch.nn.Parameter(
+ latent_resnet_update_params
+ )
+ else:
+ self.register_buffer(
+ "_latent_resnet_update_params", latent_resnet_update_params
+ )
+
+ def forward(self, edge_index, edge_sh, atom_type, latents, features, cutoff_coeffs, active_edges):
+ # update V
+ # update X
+ # edge_index: [2, num_edges]
+ # irreps_sh: [num_edges, irreps_sh]
+ # latents: [num_edges, latent_in]
+ # fetures: [num_active_edges, in_irreps]
+ # cutoff_coeffs: [num_edges]
+ # active_edges: [num_active_edges]
+
+ edge_center = edge_index[0]
+ edge_neighbor = edge_index[1]
+
+ prev_mask = cutoff_coeffs > 0
+
+ # sc_features = self.sc(features, node_one_hot[edge_index].transpose(0,1).flatten(1,2)[active_edges])
+ # update V
+ weights = self.env_embed_mlps(latents[active_edges])
+
+ # Build the local environments
+ # This local environment should only be a sum over neighbors
+ # who are within the cutoff of the _current_ layer
+ # Those are the active edges, which are the only ones we
+ # have weights for (env_w) anyway.
+ # So we mask out the edges in the sum:
+ local_env_per_edge = scatter(
+ self._env_weighter(edge_sh[active_edges], weights),
+ edge_center[active_edges],
+ dim=0,
+ )
+
+ # currently, we have a sum over neighbors of constant number for each layer,
+ # the env_sum_normalization can be a scalar or list
+ # the different cutoff can be added in the future
+
+ if self.env_sum_normalizations.ndim < 1:
+ norm_const = self.env_sum_normalizations
+ else:
+ norm_const = self.env_sum_normalizations[atom_type.flatten()].unsqueeze(-1)
+
+ local_env_per_edge = local_env_per_edge * norm_const
+ local_env_per_edge = self.env_linears(local_env_per_edge)
+
+ # local_env_per_edge = torch.cat([local_env_per_edge[edge_center[active_edges]], local_env_per_edge[edge_neighbor[active_edges]]], dim=-1)
+ # local_env_per_edge = local_env_per_edge[edge_center[active_edges]]
+ # Now do the TP
+ # recursively tp current features with the environment embeddings
+ new_features = self.tp(
+ torch.cat(
+ [
+ local_env_per_edge[edge_center[active_edges]],
+ self.lin_pre(features),
+ local_env_per_edge[edge_neighbor[active_edges]]
+ ], dim=-1),
+ edge_sh[active_edges]) # full_out_irreps
+
+ new_features = self.activation(new_features)
+ # # do the linear
+ # new_features = self.linears(new_features)
+
+
+ # features has shape [N_edge, full_feature_out.dim]
+ # we know scalars are first
+ scalars = new_features[:, :self.irreps_out[0].dim]
+ assert len(scalars.shape) == 2
+
+ new_features = self.lin_post(new_features)
+
+ new_features = self.bn(new_features)
+
+ if self.latent_resnet:
+ update_coefficients = self._latent_resnet_update_params.sigmoid()
+ coefficient_old = torch.rsqrt(update_coefficients.square() + 1)
+ coefficient_new = update_coefficients * coefficient_old
+ features = coefficient_new * new_features + coefficient_old * self.linear_res(features)
+ else:
+ features = new_features
+
+ # update X
+ latent_inputs_to_cat = [
+ latents[active_edges],
+ scalars,
+ ]
+
+ new_latents = self.latents(torch.cat(latent_inputs_to_cat, dim=-1))
+ new_latents = cutoff_coeffs[active_edges].unsqueeze(-1) * new_latents
+ # At init, we assume new and old to be approximately uncorrelated
+ # Thus their variances add
+ # we always want the latent space to be normalized to variance = 1.0,
+ # because it is critical for learnability. Still, we want to preserve
+ # the _relative_ magnitudes of the current latent and the residual update
+ # to be controled by `this_layer_update_coeff`
+ # Solving the simple system for the two coefficients:
+ # a^2 + b^2 = 1 (variances add) & a * this_layer_update_coeff = b
+ # gives
+ # a = 1 / sqrt(1 + this_layer_update_coeff^2) & b = this_layer_update_coeff / sqrt(1 + this_layer_update_coeff^2)
+ # rsqrt is reciprocal sqrt
+ if self.latent_resnet:
+ update_coefficients = self._latent_resnet_update_params.sigmoid()
+ coefficient_old = torch.rsqrt(update_coefficients.square() + 1)
+ coefficient_new = update_coefficients * coefficient_old
+ latents = torch.index_add(
+ coefficient_old * latents,
+ 0,
+ active_edges,
+ coefficient_new * new_latents,
+ )
+ else:
+ latents = torch.index_copy(latents, 0, active_edges, new_latents)
+
+ return latents, features, cutoff_coeffs, active_edges
+
\ No newline at end of file
diff --git a/dptb/nn/embedding/emb.py b/dptb/nn/embedding/emb.py
new file mode 100644
index 00000000..85f976bd
--- /dev/null
+++ b/dptb/nn/embedding/emb.py
@@ -0,0 +1,28 @@
+import torch.nn as nn
+import torch
+from dptb.utils.register import Register
+
+"""this is the register class for descriptors
+
+all descriptors inplemendeted should be a instance of nn.Module class, and provide a forward function that
+takes AtomicData class as input, and give AtomicData class as output.
+
+"""
+class Embedding:
+ _register = Register()
+
+ def register(target):
+ return Embedding._register.register(target)
+
+ def __new__(cls, method: str, **kwargs):
+ if method in Embedding._register.keys():
+ return Embedding._register[method](**kwargs)
+ else:
+ raise Exception(f"Descriptor mode: {method} is not registered!")
+
+
+
+
+
+
+
diff --git a/dptb/nnsktb/__init__.py b/dptb/nn/embedding/from_deephe3/__init__.py
similarity index 100%
rename from dptb/nnsktb/__init__.py
rename to dptb/nn/embedding/from_deephe3/__init__.py
diff --git a/dptb/nn/embedding/from_deephe3/deephe3.py b/dptb/nn/embedding/from_deephe3/deephe3.py
new file mode 100644
index 00000000..1a8023da
--- /dev/null
+++ b/dptb/nn/embedding/from_deephe3/deephe3.py
@@ -0,0 +1,469 @@
+import warnings
+import os
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch_scatter import scatter
+from e3nn.nn import Gate
+from e3nn.o3 import Irreps, Linear, SphericalHarmonics, FullyConnectedTensorProduct
+import e3nn.o3 as o3
+from ...radial_basis import GaussianBasis
+from .e3module import SphericalBasis, sort_irreps, e3LayerNorm, e3ElementWise, SkipConnection, SeparateWeightTensorProduct, SelfTp
+from dptb.data import AtomicDataDict
+
+epsilon = 1e-8
+
+def tp_path_exists(irreps_in1, irreps_in2, ir_out):
+ irreps_in1 = o3.Irreps(irreps_in1).simplify()
+ irreps_in2 = o3.Irreps(irreps_in2).simplify()
+ ir_out = o3.Irrep(ir_out)
+
+ for _, ir1 in irreps_in1:
+ for _, ir2 in irreps_in2:
+ if ir_out in ir1 * ir2:
+ return True
+ return False
+
+def get_gate_nonlin(irreps_in1, irreps_in2, irreps_out,
+ act={1: torch.nn.functional.silu, -1: torch.tanh},
+ act_gates={1: torch.sigmoid, -1: torch.tanh}
+ ):
+ # get gate nonlinearity after tensor product
+ # irreps_in1 and irreps_in2 are irreps to be multiplied in tensor product
+ # irreps_out is desired irreps after gate nonlin
+ # notice that nonlin.irreps_out might not be exactly equal to irreps_out
+
+ irreps_scalars = Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l == 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ irreps_gated = Irreps([
+ (mul, ir)
+ for mul, ir in irreps_out
+ if ir.l > 0 and tp_path_exists(irreps_in1, irreps_in2, ir)
+ ]).simplify()
+ if irreps_gated.dim > 0:
+ if tp_path_exists(irreps_in1, irreps_in2, "0e"):
+ ir = "0e"
+ elif tp_path_exists(irreps_in1, irreps_in2, "0o"):
+ ir = "0o"
+ warnings.warn('Using odd representations as gates')
+ else:
+ raise ValueError(
+ f"irreps_in1={irreps_in1} times irreps_in2={irreps_in2} is unable to produce gates needed for irreps_gated={irreps_gated}")
+ else:
+ ir = None
+ irreps_gates = Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify()
+
+ gate_nonlin = Gate(
+ irreps_scalars, [act[ir.p] for _, ir in irreps_scalars], # scalar
+ irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates], # gates (scalars)
+ irreps_gated # gated tensors
+ )
+
+ return gate_nonlin
+
+
+class EquiConv(nn.Module):
+ def __init__(self, fc_len_in, irreps_in1, irreps_in2, irreps_out, norm='', nonlin=True,
+ act = {1: torch.nn.functional.silu, -1: torch.tanh},
+ act_gates = {1: torch.sigmoid, -1: torch.tanh}
+ ):
+ super(EquiConv, self).__init__()
+
+ irreps_in1 = Irreps(irreps_in1)
+ irreps_in2 = Irreps(irreps_in2)
+ irreps_out = Irreps(irreps_out)
+
+ self.nonlin = None
+ if nonlin:
+ self.nonlin = get_gate_nonlin(irreps_in1, irreps_in2, irreps_out, act, act_gates)
+ irreps_tp_out = self.nonlin.irreps_in
+ else:
+ irreps_tp_out = Irreps([(mul, ir) for mul, ir in irreps_out if tp_path_exists(irreps_in1, irreps_in2, ir)])
+
+ self.tp = SeparateWeightTensorProduct(irreps_in1, irreps_in2, irreps_tp_out)
+
+ if nonlin:
+ self.cfconv = e3ElementWise(self.nonlin.irreps_out)
+ self.irreps_out = self.nonlin.irreps_out
+ else:
+ self.cfconv = e3ElementWise(irreps_tp_out)
+ self.irreps_out = irreps_tp_out
+
+ # fully connected net to create tensor product weights
+ linear_act = nn.SiLU()
+ self.fc = nn.Sequential(nn.Linear(fc_len_in, 64),
+ linear_act,
+ nn.Linear(64, 64),
+ linear_act,
+ nn.Linear(64, self.cfconv.len_weight)
+ )
+
+ self.norm = None
+ if norm:
+ if norm == 'e3LayerNorm':
+ self.norm = e3LayerNorm(self.cfconv.irreps_in)
+ else:
+ raise ValueError(f'unknown norm: {norm}')
+
+ def forward(self, fea_in1, fea_in2, fea_weight, batch_edge):
+ z = self.tp(fea_in1, fea_in2)
+
+ if self.nonlin is not None:
+ z = self.nonlin(z)
+
+ weight = self.fc(fea_weight)
+ z = self.cfconv(z, weight)
+
+ if self.norm is not None:
+ z = self.norm(z, batch_edge)
+
+ # TODO self-connection here
+ return z
+
+
+class NodeUpdateBlock(nn.Module):
+ def __init__(self, num_species, fc_len_in, irreps_sh, irreps_in_node, irreps_out_node, irreps_in_edge,
+ act, act_gates, use_selftp=False, use_sc=True, concat=True, only_ij=False, nonlin=False, norm='e3LayerNorm', if_sort_irreps=False):
+ super(NodeUpdateBlock, self).__init__()
+ irreps_in_node = Irreps(irreps_in_node)
+ irreps_sh = Irreps(irreps_sh)
+ irreps_out_node = Irreps(irreps_out_node)
+ irreps_in_edge = Irreps(irreps_in_edge)
+
+ if concat:
+ irreps_in1 = irreps_in_node + irreps_in_node + irreps_in_edge
+ if if_sort_irreps:
+ self.sort = sort_irreps(irreps_in1)
+ irreps_in1 = self.sort.irreps_out
+ else:
+ irreps_in1 = irreps_in_node
+ irreps_in2 = irreps_sh
+
+ self.lin_pre = Linear(irreps_in=irreps_in_node, irreps_out=irreps_in_node, biases=True)
+
+ self.nonlin = None
+ if nonlin:
+ self.nonlin = get_gate_nonlin(irreps_in1, irreps_in2, irreps_out_node, act, act_gates)
+ irreps_conv_out = self.nonlin.irreps_in
+ conv_nonlin = False
+ else:
+ irreps_conv_out = irreps_out_node
+ conv_nonlin = True
+
+ self.conv = EquiConv(fc_len_in, irreps_in1, irreps_in2, irreps_conv_out, nonlin=conv_nonlin, act=act, act_gates=act_gates)
+ self.lin_post = Linear(irreps_in=self.conv.irreps_out, irreps_out=self.conv.irreps_out, biases=True)
+
+ if nonlin:
+ self.irreps_out = self.nonlin.irreps_out
+ else:
+ self.irreps_out = self.conv.irreps_out
+
+ self.sc = None
+ if use_sc:
+ self.sc = FullyConnectedTensorProduct(irreps_in_node, f'{num_species}x0e', self.conv.irreps_out)
+
+ self.norm = None
+ if norm:
+ if norm == 'e3LayerNorm':
+ self.norm = e3LayerNorm(self.irreps_out)
+ else:
+ raise ValueError(f'unknown norm: {norm}')
+
+ self.skip_connect = SkipConnection(irreps_in_node, self.irreps_out)
+
+ self.self_tp = None
+ if use_selftp:
+ self.self_tp = SelfTp(self.irreps_out, self.irreps_out)
+
+ self.irreps_in_node = irreps_in_node
+ self.use_sc = use_sc
+ self.concat = concat
+ self.only_ij = only_ij
+ self.if_sort_irreps = if_sort_irreps
+
+ def forward(self, node_fea, node_one_hot, edge_sh, edge_fea, edge_length_embedded, edge_index, batch, selfloop_edge, edge_length):
+
+ node_fea_old = node_fea
+
+ if self.use_sc:
+ node_self_connection = self.sc(node_fea, node_one_hot)
+
+ node_fea = self.lin_pre(node_fea)
+
+ index_i = edge_index[0]
+ index_j = edge_index[1]
+ if self.concat:
+ fea_in = torch.cat([node_fea[index_i], node_fea[index_j], edge_fea], dim=-1)
+ if self.if_sort_irreps:
+ fea_in = self.sort(fea_in)
+ edge_update = self.conv(fea_in, edge_sh, edge_length_embedded, batch[edge_index[0]])
+ else:
+ edge_update = self.conv(node_fea[index_j], edge_sh, edge_length_embedded, batch[edge_index[0]])
+
+ # sigma = 3
+ # n = 2
+ # edge_update = edge_update * torch.exp(- edge_length ** n / sigma ** n / 2).view(-1, 1)
+
+ node_fea = scatter(edge_update, index_i, dim=0, dim_size=node_fea.shape[0], reduce='add')
+ if self.only_ij:
+ node_fea = node_fea + scatter(edge_update[~selfloop_edge], index_j[~selfloop_edge], dim=0, dim_size=node_fea.shape[0], reduce='add')
+
+ node_fea = self.lin_post(node_fea)
+
+ if self.use_sc:
+ node_fea = node_fea + node_self_connection
+
+ if self.nonlin is not None:
+ node_fea = self.nonlin(node_fea)
+
+ if self.norm is not None:
+ node_fea = self.norm(node_fea, batch)
+
+ node_fea = self.skip_connect(node_fea_old, node_fea)
+
+ if self.self_tp is not None:
+ node_fea = self.self_tp(node_fea)
+
+ return node_fea
+
+
+class EdgeUpdateBlock(nn.Module):
+ def __init__(self, num_species, fc_len_in, irreps_sh, irreps_in_node, irreps_in_edge, irreps_out_edge,
+ act, act_gates, use_selftp=False, use_sc=True, init_edge=False, nonlin=False, norm='e3LayerNorm', if_sort_irreps=False):
+ super(EdgeUpdateBlock, self).__init__()
+ irreps_in_node = Irreps(irreps_in_node)
+ irreps_in_edge = Irreps(irreps_in_edge)
+ irreps_out_edge = Irreps(irreps_out_edge)
+
+ irreps_in1 = irreps_in_node + irreps_in_node + irreps_in_edge
+ if if_sort_irreps:
+ self.sort = sort_irreps(irreps_in1)
+ irreps_in1 = self.sort.irreps_out
+ irreps_in2 = irreps_sh
+
+ self.lin_pre = Linear(irreps_in=irreps_in_edge, irreps_out=irreps_in_edge, biases=True)
+
+ self.nonlin = None
+ self.lin_post = None
+ if nonlin:
+ self.nonlin = get_gate_nonlin(irreps_in1, irreps_in2, irreps_out_edge, act, act_gates)
+ irreps_conv_out = self.nonlin.irreps_in
+ conv_nonlin = False
+ else:
+ irreps_conv_out = irreps_out_edge
+ conv_nonlin = True
+
+ self.conv = EquiConv(fc_len_in, irreps_in1, irreps_in2, irreps_conv_out, nonlin=conv_nonlin, act=act, act_gates=act_gates)
+ self.lin_post = Linear(irreps_in=self.conv.irreps_out, irreps_out=self.conv.irreps_out, biases=True)
+
+ if use_sc:
+ self.sc = FullyConnectedTensorProduct(irreps_in_edge, f'{num_species**2}x0e', self.conv.irreps_out)
+
+ if nonlin:
+ self.irreps_out = self.nonlin.irreps_out
+ else:
+ self.irreps_out = self.conv.irreps_out
+
+ self.norm = None
+ if norm:
+ if norm == 'e3LayerNorm':
+ self.norm = e3LayerNorm(self.irreps_out)
+ else:
+ raise ValueError(f'unknown norm: {norm}')
+
+ self.skip_connect = SkipConnection(irreps_in_edge, self.irreps_out) # ! consider init_edge
+
+ self.self_tp = None
+ if use_selftp:
+ self.self_tp = SelfTp(self.irreps_out, self.irreps_out)
+
+ self.use_sc = use_sc
+ self.init_edge = init_edge
+ self.if_sort_irreps = if_sort_irreps
+ self.irreps_in_edge = irreps_in_edge
+
+ def forward(self, node_fea, edge_one_hot, edge_sh, edge_fea, edge_length_embedded, edge_index, batch):
+
+ if not self.init_edge:
+ edge_fea_old = edge_fea
+ if self.use_sc:
+ edge_self_connection = self.sc(edge_fea, edge_one_hot)
+ edge_fea = self.lin_pre(edge_fea)
+
+ index_i = edge_index[0]
+ index_j = edge_index[1]
+ fea_in = torch.cat([node_fea[index_i], node_fea[index_j], edge_fea], dim=-1)
+ if self.if_sort_irreps:
+ fea_in = self.sort(fea_in)
+ edge_fea = self.conv(fea_in, edge_sh, edge_length_embedded, batch[edge_index[0]])
+
+ edge_fea = self.lin_post(edge_fea)
+
+ if self.use_sc:
+ edge_fea = edge_fea + edge_self_connection
+
+ if self.nonlin is not None:
+ edge_fea = self.nonlin(edge_fea)
+
+ if self.norm is not None:
+ edge_fea = self.norm(edge_fea, batch[edge_index[0]])
+
+ if not self.init_edge:
+ edge_fea = self.skip_connect(edge_fea_old, edge_fea)
+
+ if self.self_tp is not None:
+ edge_fea = self.self_tp(edge_fea)
+
+ return edge_fea
+
+
+class Net(nn.Module):
+ def __init__(self, num_species, # spherical basis irreps
+ irreps_embed_node, irreps_sh, irreps_mid_node, irreps_post_node, irreps_out_node,
+ irreps_edge_init, irreps_mid_edge, irreps_post_edge, irreps_out_edge,
+ num_block, r_max, use_sc=True, no_parity=False, use_sbf=True, selftp=False, edge_upd=True,
+ only_ij=False, num_basis=128,
+ act={1: torch.nn.functional.silu, -1: torch.tanh},
+ act_gates={1: torch.sigmoid, -1: torch.tanh},
+ if_sort_irreps=False):
+
+ if no_parity:
+ for irreps in (irreps_embed_node, irreps_edge_init, irreps_sh, irreps_mid_node,
+ irreps_post_node, irreps_out_node,irreps_mid_edge, irreps_post_edge, irreps_out_edge,):
+ for _, ir in Irreps(irreps):
+ assert ir.p == 1, 'Ignoring parity but requiring representations with odd parity in net'
+
+ super(Net, self).__init__()
+ self.num_species = num_species
+ self.only_ij = only_ij
+
+ irreps_embed_node = Irreps(irreps_embed_node)
+ assert irreps_embed_node == Irreps(f'{irreps_embed_node.dim}x0e')
+ self.embedding = Linear(irreps_in=f"{num_species}x0e", irreps_out=irreps_embed_node) # node embedding
+
+ # edge embedding for tensor product weight
+ # self.basis = BesselBasis(r_max, num_basis=num_basis, trainable=False)
+ # self.cutoff = PolynomialCutoff(r_max, p=6)
+ self.basis = GaussianBasis(start=0.0, stop=r_max, n_gaussians=num_basis, trainable=False)
+
+ # distance expansion to initialize edge feature
+ irreps_edge_init = Irreps(irreps_edge_init)
+ assert irreps_edge_init == Irreps(f'{irreps_edge_init.dim}x0e')
+ self.distance_expansion = GaussianBasis(
+ start=0.0, stop=6.0, n_gaussians=irreps_edge_init.dim, trainable=False
+ )
+
+ if use_sbf:
+ self.sh = SphericalBasis(irreps_sh, r_max)
+ else:
+ self.sh = SphericalHarmonics(
+ irreps_out=irreps_sh,
+ normalize=True,
+ normalization='component',
+ )
+ self.use_sbf = use_sbf
+ if no_parity:
+ irreps_sh = Irreps([(mul, (ir.l, 1)) for mul, ir in Irreps(irreps_sh)])
+ self.irreps_sh = irreps_sh
+
+ # self.edge_update_block_init = EdgeUpdateBlock(num_basis, irreps_sh, self.embedding.irreps_out, None, irreps_mid_edge, act, act_gates, False, init_edge=True)
+ irreps_node_prev = self.embedding.irreps_out
+ irreps_edge_prev = irreps_edge_init
+
+ self.node_update_blocks = nn.ModuleList([])
+ self.edge_update_blocks = nn.ModuleList([])
+ for index_block in range(num_block):
+ if index_block == num_block - 1:
+ node_update_block = NodeUpdateBlock(num_species, num_basis, irreps_sh, irreps_node_prev, irreps_post_node, irreps_edge_prev, act, act_gates, use_selftp=selftp, use_sc=use_sc, only_ij=only_ij, if_sort_irreps=if_sort_irreps)
+ edge_update_block = EdgeUpdateBlock(num_species, num_basis, irreps_sh, node_update_block.irreps_out, irreps_edge_prev, irreps_post_edge, act, act_gates, use_selftp=selftp, use_sc=use_sc, if_sort_irreps=if_sort_irreps)
+ else:
+ node_update_block = NodeUpdateBlock(num_species, num_basis, irreps_sh, irreps_node_prev, irreps_mid_node, irreps_edge_prev, act, act_gates, use_selftp=False, use_sc=use_sc, only_ij=only_ij, if_sort_irreps=if_sort_irreps)
+ edge_update_block = None
+ if edge_upd:
+ edge_update_block = EdgeUpdateBlock(num_species, num_basis, irreps_sh, node_update_block.irreps_out, irreps_edge_prev, irreps_mid_edge, act, act_gates, use_selftp=False, use_sc=use_sc, if_sort_irreps=if_sort_irreps)
+ irreps_node_prev = node_update_block.irreps_out
+ if edge_update_block is not None:
+ irreps_edge_prev = edge_update_block.irreps_out
+ self.node_update_blocks.append(node_update_block)
+ self.edge_update_blocks.append(edge_update_block)
+
+ irreps_out_edge = Irreps(irreps_out_edge)
+ for _, ir in irreps_out_edge:
+ assert ir in irreps_edge_prev, f'required ir {ir} in irreps_out_edge cannot be produced by convolution in the last edge update block ({edge_update_block.irreps_in_edge} -> {edge_update_block.irreps_out})'
+
+ self.irreps_out_node = irreps_out_node
+ self.irreps_out_edge = irreps_out_edge
+ self.lin_node = Linear(irreps_in=irreps_node_prev, irreps_out=irreps_out_node, biases=True)
+ self.lin_edge = Linear(irreps_in=irreps_edge_prev, irreps_out=irreps_out_edge, biases=True)
+
+ def forward(self, data):
+ node_one_hot = F.one_hot(data[AtomicDataDict.ATOM_TYPE_KEY].flatten(), num_classes=self.num_species).type(torch.get_default_dtype())
+ edge_one_hot = F.one_hot(self.num_species * data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[data[AtomicDataDict.EDGE_INDEX_KEY][0]] + data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[data[AtomicDataDict.EDGE_INDEX_KEY][1]],
+ num_classes=self.num_species**2).type(torch.get_default_dtype()) # ! might not be good if dataset has many elements
+ # env_one_hot = F.one_hot(self.num_species * data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[data[AtomicDataDict.ENV_INDEX_KEY][0]] + data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[data[AtomicDataDict.ENV_INDEX_KEY][1]],
+ # num_classes=self.num_species**2).type(torch.get_default_dtype()) # ! might not be good if dataset has many elements
+
+ node_fea = self.embedding(node_one_hot)
+
+ edge_length = data[AtomicDataDict.EDGE_LENGTH_KEY]
+ edge_vec = torch.cat([edge_length.reshape(-1, 1), data[AtomicDataDict.EDGE_VECTORS_KEY][:, [1, 2, 0]]], dim=-1) # (y, z, x) order
+ # env_length = data[AtomicDataDict.ENV_LENGTH_KEY]
+ # env_vec = torch.cat([env_length.reshape(-1, 1), data[AtomicDataDict.ENV_VECTORS_KEY][:, [1, 2, 0]]], dim=-1) # (y, z, x) order
+
+ if self.use_sbf:
+ edge_sh = self.sh(edge_length, edge_vec)
+ # env_sh = self.sh(env_length, env_vec)
+ else:
+ edge_sh = self.sh(edge_vec).type(torch.get_default_dtype())
+ # env_sh = self.sh(env_vec).type(torch.get_default_dtype())
+ # edge_length_embedded = (self.basis(data["edge_attr"][:, 0] + epsilon) * self.cutoff(data["edge_attr"][:, 0])[:, None]).type(torch.get_default_dtype())
+ edge_length_embedded = self.basis(edge_length)
+ # env_length_embedded = self.basis(env_length)
+
+ selfloop_edge = None
+ if self.only_ij:
+ selfloop_edge = edge_length < 1e-7
+
+ # edge_fea = self.edge_update_block_init(node_fea, edge_sh, None, edge_length_embedded, data["edge_index"])
+ edge_fea = self.distance_expansion(edge_length).type(torch.get_default_dtype())
+ # env_fea = self.distance_expansion(env_length).type(torch.get_default_dtype())
+ for node_update_block, edge_update_block in zip(self.node_update_blocks, self.edge_update_blocks):
+ node_fea = node_update_block(node_fea, node_one_hot, edge_sh, edge_fea, edge_length_embedded, data[AtomicDataDict.EDGE_INDEX_KEY], data[AtomicDataDict.BATCH_KEY], selfloop_edge, edge_length)
+ if edge_update_block is not None:
+ edge_fea = edge_update_block(node_fea, edge_one_hot, edge_sh, edge_fea, edge_length_embedded, data[AtomicDataDict.EDGE_INDEX_KEY], data[AtomicDataDict.BATCH_KEY])
+ # env_fea = edge_update_block(node_fea, env_one_hot, env_sh, env_fea, env_length_embedded, data[AtomicDataDict.ENV_INDEX_KEY], data[AtomicDataDict.BATCH_KEY])
+
+ node_fea = self.lin_node(node_fea)
+ edge_fea = self.lin_edge(edge_fea)
+
+ return node_fea, edge_fea
+
+ def __repr__(self):
+ info = '===== DeepH-E3 model structure: ====='
+ if self.use_sbf:
+ info += f'\nusing spherical bessel basis: {self.irreps_sh}'
+ else:
+ info += f'\nusing spherical harmonics: {self.irreps_sh}'
+ for index, (nupd, eupd) in enumerate(zip(self.node_update_blocks, self.edge_update_blocks)):
+ info += f'\n=== layer {index} ==='
+ info += f'\nnode update: ({nupd.irreps_in_node} -> {nupd.irreps_out})'
+ if eupd is not None:
+ info += f'\nedge update: ({eupd.irreps_in_edge} -> {eupd.irreps_out})'
+ info += '\n=== output ==='
+ info += f'\noutput node: ({self.irreps_out_node})'
+ info += f'\noutput edge: ({self.irreps_out_edge})'
+
+ return info
+
+ def analyze_tp(self, path):
+ os.makedirs(path, exist_ok=True)
+ for index, (nupd, eupd) in enumerate(zip(self.node_update_blocks, self.edge_update_blocks)):
+ fig, ax = nupd.conv.tp.visualize()
+ fig.savefig(os.path.join(path, f'node_update_{index}.png'))
+ fig.clf()
+ fig, ax = eupd.conv.tp.visualize()
+ fig.savefig(os.path.join(path, f'edge_update_{index}.png'))
+ fig.clf()
diff --git a/dptb/nn/embedding/from_deephe3/e3module.py b/dptb/nn/embedding/from_deephe3/e3module.py
new file mode 100644
index 00000000..60de7112
--- /dev/null
+++ b/dptb/nn/embedding/from_deephe3/e3module.py
@@ -0,0 +1,367 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch_scatter import scatter
+from torch_geometric.utils import degree
+from scipy.optimize import brentq
+from scipy import special as sp
+from e3nn.util.jit import compile_mode
+from e3nn.o3 import Irrep, Irreps, wigner_3j, matrix_to_angles, Linear, FullyConnectedTensorProduct, TensorProduct, SphericalHarmonics
+from e3nn.nn import Extract
+import numpy as np
+from typing import Union
+import e3nn.o3 as o3
+from ...cutoff import polynomial_cutoff
+import sympy as sym
+
+
+def spherical_bessel_formulas(n):
+ """
+ Computes the sympy formulas for the spherical bessel functions up to order n (excluded)
+ """
+ x = sym.symbols('x')
+
+ f = [sym.sin(x)/x]
+ a = sym.sin(x)/x
+ for i in range(1, n):
+ b = sym.diff(a, x)/x
+ f += [sym.simplify(b*(-x)**i)]
+ a = sym.simplify(b)
+ return f
+
+def Jn(r, n):
+ """
+ numerical spherical bessel functions of order n
+ """
+ return np.sqrt(np.pi/(2*r)) * sp.jv(n+0.5, r)
+
+
+def Jn_zeros(n, k):
+ """
+ Compute the first k zeros of the spherical bessel functions up to order n (excluded)
+ """
+ zerosj = np.zeros((n, k), dtype="float32")
+ zerosj[0] = np.arange(1, k + 1) * np.pi
+ points = np.arange(1, k + n) * np.pi
+ racines = np.zeros(k + n - 1, dtype="float32")
+ for i in range(1, n):
+ for j in range(k + n - 1 - i):
+ foo = brentq(Jn, points[j], points[j + 1], (i,))
+ racines[j] = foo
+ points = racines
+ zerosj[i][:k] = racines[:k]
+
+ return zerosj
+
+def bessel_basis(n, k):
+ """
+ Compute the sympy formulas for the normalized and rescaled spherical bessel functions up to
+ order n (excluded) and maximum frequency k (excluded).
+ """
+
+ zeros = Jn_zeros(n, k)
+ normalizer = []
+ for order in range(n):
+ normalizer_tmp = []
+ for i in range(k):
+ normalizer_tmp += [0.5*Jn(zeros[order, i], order+1)**2]
+ normalizer_tmp = 1/np.array(normalizer_tmp)**0.5
+ normalizer += [normalizer_tmp]
+
+ f = spherical_bessel_formulas(n)
+ x = sym.symbols('x')
+ bess_basis = []
+ for order in range(n):
+ bess_basis_tmp = []
+ for i in range(k):
+ bess_basis_tmp += [sym.simplify(normalizer[order]
+ [i]*f[order].subs(x, zeros[order, i]*x))]
+ bess_basis += [bess_basis_tmp]
+ return bess_basis
+
+class sort_irreps(torch.nn.Module):
+ def __init__(self, irreps_in):
+ super().__init__()
+ irreps_in = Irreps(irreps_in)
+ sorted_irreps = irreps_in.sort()
+
+ irreps_out_list = [((mul, ir),) for mul, ir in sorted_irreps.irreps]
+ instructions = [(i,) for i in sorted_irreps.inv]
+ self.extr = Extract(irreps_in, irreps_out_list, instructions)
+
+ irreps_in_list = [((mul, ir),) for mul, ir in irreps_in]
+ instructions_inv = [(i,) for i in sorted_irreps.p]
+ self.extr_inv = Extract(sorted_irreps.irreps, irreps_in_list, instructions_inv)
+
+ self.irreps_in = irreps_in
+ self.irreps_out = sorted_irreps.irreps.simplify()
+
+ def forward(self, x):
+ r'''irreps_in -> irreps_out'''
+ extracted = self.extr(x)
+ return torch.cat(extracted, dim=-1)
+
+ def inverse(self, x):
+ r'''irreps_out -> irreps_in'''
+ extracted_inv = self.extr_inv(x)
+ return torch.cat(extracted_inv, dim=-1)
+
+
+
+
+
+class e3LayerNorm(nn.Module):
+ def __init__(self, irreps_in, eps=1e-5, affine=True, normalization='component', subtract_mean=True, divide_norm=False):
+ super().__init__()
+
+ self.irreps_in = Irreps(irreps_in)
+ self.eps = eps
+
+ if affine:
+ ib, iw = 0, 0
+ weight_slices, bias_slices = [], []
+ for mul, ir in irreps_in:
+ if ir.is_scalar(): # bias only to 0e
+ bias_slices.append(slice(ib, ib + mul))
+ ib += mul
+ else:
+ bias_slices.append(None)
+ weight_slices.append(slice(iw, iw + mul))
+ iw += mul
+ self.weight = nn.Parameter(torch.ones([iw]))
+ self.bias = nn.Parameter(torch.zeros([ib]))
+ self.bias_slices = bias_slices
+ self.weight_slices = weight_slices
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+
+ self.subtract_mean = subtract_mean
+ self.divide_norm = divide_norm
+ assert normalization in ['component', 'norm']
+ self.normalization = normalization
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.weight is not None:
+ self.weight.data.fill_(1)
+ # nn.init.uniform_(self.weight)
+ if self.bias is not None:
+ self.bias.data.fill_(0)
+ # nn.init.uniform_(self.bias)
+
+ def forward(self, x: torch.Tensor, batch: torch.Tensor = None):
+ # input x must have shape [num_node(edge), dim]
+ # if first dimension of x is node index, then batch should be batch.batch
+ # if first dimension of x is edge index, then batch should be batch.batch[batch.edge_index[0]]
+
+ if batch is None:
+ batch = torch.full([x.shape[0]], 0, dtype=torch.int64)
+
+ # from torch_geometric.nn.norm.LayerNorm
+
+ batch_size = int(batch.max()) + 1
+ batch_degree = degree(batch, batch_size, dtype=torch.int64).clamp_(min=1).to(dtype=x.dtype)
+
+ out = []
+ ix = 0
+ for index, (mul, ir) in enumerate(self.irreps_in):
+ field = x[:, ix: ix + mul * ir.dim].reshape(-1, mul, ir.dim) # [node, mul, repr]
+
+ # compute and subtract mean
+ if self.subtract_mean or ir.l == 0: # do not subtract mean for l>0 irreps if subtract_mean=False
+ mean = scatter(field, batch, dim=0, dim_size=batch_size,
+ reduce='add').mean(dim=1, keepdim=True) / batch_degree[:, None, None] # scatter_mean does not support complex number
+ field = field - mean[batch]
+
+ # compute and divide norm
+ if self.divide_norm or ir.l == 0: # do not divide norm for l>0 irreps if divide_norm=False
+ norm = scatter(field.abs().pow(2), batch, dim=0, dim_size=batch_size,
+ reduce='mean').mean(dim=[1,2], keepdim=True) # add abs here to deal with complex numbers
+ if self.normalization == 'norm':
+ norm = norm * ir.dim
+ field = field / (norm.sqrt()[batch] + self.eps)
+
+ # affine
+ if self.weight is not None:
+ weight = self.weight[self.weight_slices[index]]
+ field = field * weight[None, :, None]
+ if self.bias is not None and ir.is_scalar():
+ bias = self.bias[self.bias_slices[index]]
+ field = field + bias[None, :, None]
+
+ out.append(field.reshape(-1, mul * ir.dim))
+ ix += mul * ir.dim
+
+ out = torch.cat(out, dim=-1)
+
+ return out
+
+class e3ElementWise:
+ def __init__(self, irreps_in):
+ self.irreps_in = Irreps(irreps_in)
+
+ len_weight = 0
+ for mul, ir in self.irreps_in:
+ len_weight += mul
+
+ self.len_weight = len_weight
+
+ def __call__(self, x: torch.Tensor, weight: torch.Tensor):
+ # x should have shape [edge/node, channels]
+ # weight should have shape [edge/node, self.len_weight]
+
+ ix = 0
+ iw = 0
+ out = []
+ for mul, ir in self.irreps_in:
+ field = x[:, ix: ix + mul * ir.dim]
+ field = field.reshape(-1, mul, ir.dim)
+ field = field * weight[:, iw: iw + mul][:, :, None]
+ field = field.reshape(-1, mul * ir.dim)
+
+ ix += mul * ir.dim
+ iw += mul
+ out.append(field)
+
+ return torch.cat(out, dim=-1)
+
+
+class SkipConnection(nn.Module):
+ def __init__(self, irreps_in, irreps_out, is_complex=False):
+ super().__init__()
+ irreps_in = Irreps(irreps_in)
+ irreps_out = Irreps(irreps_out)
+ self.sc = None
+ if irreps_in == irreps_out:
+ self.sc = None
+ else:
+ self.sc = Linear(irreps_in=irreps_in, irreps_out=irreps_out)
+
+ def forward(self, old, new):
+ if self.sc is not None:
+ old = self.sc(old)
+
+ return old + new
+
+
+class SelfTp(nn.Module):
+ def __init__(self, irreps_in, irreps_out, **kwargs):
+ '''z_i = W'_{ij}x_j W''_{ik}x_k (k>=j)'''
+ super().__init__()
+
+ assert not kwargs.pop('internal_weights', False) # internal weights must be True
+ assert kwargs.pop('shared_weights', True) # shared weights must be false
+
+ irreps_in = Irreps(irreps_in)
+ irreps_out = Irreps(irreps_out)
+
+ instr_tp = []
+ weights1, weights2 = [], []
+ for i1, (mul1, ir1) in enumerate(irreps_in):
+ for i2 in range(i1, len(irreps_in)):
+ mul2, ir2 = irreps_in[i2]
+ for i_out, (mul_out, ir3) in enumerate(irreps_out):
+ if ir3 in ir1 * ir2:
+ weights1.append(nn.Parameter(torch.randn(mul1, mul_out)))
+ weights2.append(nn.Parameter(torch.randn(mul2, mul_out)))
+ instr_tp.append((i1, i2, i_out, 'uvw', True, 1.0))
+
+ self.tp = TensorProduct(irreps_in, irreps_in, irreps_out, instr_tp, internal_weights=False, shared_weights=True, **kwargs)
+
+ self.weights1 = nn.ParameterList(weights1)
+ self.weights2 = nn.ParameterList(weights2)
+
+ def forward(self, x):
+ weights = []
+ for weight1, weight2 in zip(self.weights1, self.weights2):
+ weight = weight1[:, None, :] * weight2[None, :, :]
+ weights.append(weight.view(-1))
+ weights = torch.cat(weights)
+ return self.tp(x, x, weights)
+
+@compile_mode("script")
+class SeparateWeightTensorProduct(nn.Module):
+ def __init__(self, irreps_in1: Union[str, o3.Irreps], irreps_in2: Union[str, o3.Irreps], irreps_out: Union[str, o3.Irreps], **kwargs):
+ '''z_i = W'_{ij}x_j W''_{ik}y_k'''
+ super(SeparateWeightTensorProduct, self).__init__()
+
+ assert not kwargs.pop('internal_weights', False) # internal weights must be True
+ assert kwargs.pop('shared_weights', True) # shared weights must be false
+
+ irreps_in1 = Irreps(irreps_in1)
+ irreps_in2 = Irreps(irreps_in2)
+ irreps_out = Irreps(irreps_out)
+ self.irreps_in1 = irreps_in1
+ self.irreps_in2 = irreps_in2
+ self.irreps_out = irreps_out
+
+ instr_tp = []
+ weights1, weights2 = [], []
+ for i1, (mul1, ir1) in enumerate(irreps_in1):
+ for i2, (mul2, ir2) in enumerate(irreps_in2):
+ for i_out, (mul_out, ir3) in enumerate(irreps_out):
+ if ir3 in ir1 * ir2:
+ weights1.append(nn.Parameter(torch.randn(mul1, mul_out)))
+ weights2.append(nn.Parameter(torch.randn(mul2, mul_out)))
+ instr_tp.append((i1, i2, i_out, 'uvw', True, 1.0))
+
+ self.tp = TensorProduct(irreps_in1, irreps_in2, irreps_out, instr_tp, internal_weights=False, shared_weights=True, **kwargs)
+
+ self.weights1 = nn.ParameterList(weights1)
+ self.weights2 = nn.ParameterList(weights2)
+
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor):
+ weights = []
+ for weight1, weight2 in zip(self.weights1, self.weights2):
+ weight = weight1[:, None, :] * weight2[None, :, :]
+ weights.append(weight.view(-1))
+ weights = torch.cat(weights)
+ return self.tp(x1, x2, weights)
+
+
+class SphericalBasis(nn.Module):
+ def __init__(self, target_irreps, rcutoff, eps=1e-7, dtype=torch.get_default_dtype()):
+ super().__init__()
+
+ target_irreps = Irreps(target_irreps)
+
+ self.sh = SphericalHarmonics(
+ irreps_out=target_irreps,
+ normalize=True,
+ normalization='component',
+ )
+
+ max_order = max(map(lambda x: x[1].l, target_irreps)) # maximum angular momentum l
+ max_freq = max(map(lambda x: x[0], target_irreps)) # maximum multiplicity
+
+ basis = bessel_basis(max_order + 1, max_freq)
+ lambdify_torch = {
+ # '+': torch.add,
+ # '-': torch.sub,
+ # '*': torch.mul,
+ # '/': torch.div,
+ # '**': torch.pow,
+ 'sin': torch.sin,
+ 'cos': torch.cos
+ }
+ x = sym.symbols('x')
+ funcs = []
+ for mul, ir in target_irreps:
+ for freq in range(mul):
+ funcs.append(sym.lambdify([x], basis[ir.l][freq], [lambdify_torch]))
+
+ self.bessel_funcs = funcs
+ self.multiplier = e3ElementWise(target_irreps)
+ self.dtype = dtype
+ self.cutoff = polynomial_cutoff
+ self.register_buffer('rcutoff', torch.Tensor([rcutoff]))
+ self.irreps_out = target_irreps
+ self.register_buffer('eps', torch.Tensor([eps]))
+
+ def forward(self, length, direction):
+ # direction should be in y, z, x order
+ sh = self.sh(direction).type(self.dtype)
+ sbf = torch.stack([f((length + self.eps) / self.rcutoff) for f in self.bessel_funcs], dim=-1)
+ return self.multiplier(sh, sbf) * self.cutoff(x=length, r_max=self.rcutoff, p=6).flatten()[:, None]
\ No newline at end of file
diff --git a/dptb/nn/embedding/identity.py b/dptb/nn/embedding/identity.py
new file mode 100644
index 00000000..16a2d076
--- /dev/null
+++ b/dptb/nn/embedding/identity.py
@@ -0,0 +1,24 @@
+import torch
+from typing import Optional, Tuple, Union
+from dptb.data import AtomicDataDict
+from dptb.nn.embedding.emb import Embedding
+
+@Embedding.register("none")
+class Identity(torch.nn.Module):
+ def __init__(
+ self,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ):
+ super(Identity, self).__init__(Identity, dtype, device)
+
+ def forward(self, data: AtomicDataDict) -> AtomicDataDict:
+ return data
+
+ @property
+ def out_edge_dim(self):
+ return 0
+
+ @property
+ def out_note_dim(self):
+ return 0
\ No newline at end of file
diff --git a/dptb/nn/embedding/mpnn.py b/dptb/nn/embedding/mpnn.py
new file mode 100644
index 00000000..0628f1b6
--- /dev/null
+++ b/dptb/nn/embedding/mpnn.py
@@ -0,0 +1,316 @@
+from torch_geometric.nn.conv import MessagePassing
+import torch
+from typing import Optional, Tuple, Union
+from dptb.data import AtomicDataDict
+from dptb.nn.embedding.emb import Embedding
+from ..base import ResNet, FFN
+from torch.nn import Linear
+import torch.nn as nn
+import torch.nn.functional as F
+from dptb.utils.constants import dtype_dict
+from ..type_encode.one_hot import OneHotAtomEncoding
+from ..cutoff import polynomial_cutoff
+from ..radial_basis import BesselBasis
+from torch_runstats.scatter import scatter
+
+def get_neuron_config(nl):
+ n = len(nl)
+ if n % 2 == 0:
+ d_out = nl[-1]
+ nl = nl[:-1]
+ config = []
+ for i in range(1,len(nl)-1, 2):
+ config.append({'in_features': nl[i-1], 'hidden_features': nl[i], 'out_features': nl[i+1]})
+
+ if n % 2 == 0:
+ config.append({'in_features': nl[-1], 'out_features': d_out})
+
+ return config
+
+@Embedding.register("mpnn")
+class MPNN(torch.nn.Module):
+ def __init__(
+ self,
+ r_max:Union[float, torch.Tensor],
+ p:Union[int, torch.LongTensor],
+ n_basis: Union[int, torch.LongTensor, None]=None,
+ n_node: Union[int, torch.LongTensor, None]=None,
+ n_edge: Union[int, torch.LongTensor, None]=None,
+ n_atom: int=1,
+ n_layer: int=1,
+ node_net: dict={},
+ edge_net: dict={},
+ if_exp: bool=False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu")):
+
+ super(MPNN, self).__init__()
+
+ self.n_node = n_node
+ self.n_edge = n_edge
+ if isinstance(r_max, float):
+ self.r_max = torch.tensor(r_max, dtype=dtype, device=device)
+ else:
+ self.r_max = r_max
+
+ self.p = p
+ self.layers = torch.nn.ModuleList([])
+ for _ in range(n_layer):
+ self.layers.append(
+ CGConvLayer(
+ r_max=self.r_max,
+ p=p,
+ n_edge=n_edge,
+ n_node=n_node,
+ node_net=node_net,
+ edge_net=edge_net,
+ dtype=dtype,
+ device=device,
+ if_exp=if_exp,
+ )
+ )
+
+ self.onehot = OneHotAtomEncoding(num_types=n_atom, set_features=False)
+ self.node_emb = torch.nn.Linear(n_atom, n_node)
+ edge_net["config"] = get_neuron_config([2*n_node+n_basis]+edge_net["neurons"]+[n_edge])
+ self.edge_emb = ResNet(**edge_net, device=device, dtype=dtype)
+ self.bessel = BesselBasis(r_max=r_max, num_basis=n_basis, trainable=True)
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ data = self.onehot(data)
+ data = AtomicDataDict.with_env_vectors(data, with_lengths=True)
+ data = AtomicDataDict.with_edge_vectors(data, with_lengths=True)
+
+ node_features = self.node_emb(data[AtomicDataDict.NODE_ATTRS_KEY])
+ env_features = self.edge_emb(torch.cat([node_features[data[AtomicDataDict.ENV_INDEX_KEY][0]], node_features[data[AtomicDataDict.ENV_INDEX_KEY][1]], self.bessel(data[AtomicDataDict.ENV_LENGTH_KEY])], dim=-1))
+ edge_features = self.edge_emb(torch.cat([node_features[data[AtomicDataDict.EDGE_INDEX_KEY][0]], node_features[data[AtomicDataDict.EDGE_INDEX_KEY][1]], self.bessel(data[AtomicDataDict.EDGE_LENGTH_KEY])], dim=-1))
+
+ for layer in self.layers:
+ node_features, env_features, edge_features = layer(
+ env_index=data[AtomicDataDict.ENV_INDEX_KEY],
+ edge_index=data[AtomicDataDict.EDGE_INDEX_KEY],
+ env_emb=env_features,
+ edge_emb=edge_features,
+ node_emb=node_features,
+ env_length=data[AtomicDataDict.ENV_LENGTH_KEY],
+ edge_length=data[AtomicDataDict.EDGE_LENGTH_KEY],
+ )
+ data[AtomicDataDict.NODE_FEATURES_KEY] = node_features
+
+ data[AtomicDataDict.EDGE_FEATURES_KEY] = edge_features
+
+ return data
+
+ @property
+ def out_edge_dim(self):
+ return self.n_edge
+
+ @property
+ def out_node_dim(self):
+ return self.n_node
+
+
+class MPNNLayer(MessagePassing):
+ def __init__(
+ self,
+ r_max:Union[float, torch.Tensor],
+ p:Union[int, torch.LongTensor],
+ n_edge: int,
+ n_node: int,
+ node_net: dict={},
+ edge_net: dict={},
+ aggr="mean",
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"), **kwargs):
+
+ super(MPNNLayer, self).__init__(aggr=aggr, **kwargs)
+
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+
+ if isinstance(r_max, float):
+ self.r_max = torch.tensor(r_max, dtype=dtype, device=device)
+ else:
+ self.r_max = r_max
+
+ self.p = p
+
+ edge_net["config"] = get_neuron_config([2*n_node+n_edge]+edge_net["neurons"]+[n_edge])
+ self.mlp_edge = ResNet(**edge_net, device=device, dtype=dtype)
+ node_net["config"] = get_neuron_config([2*n_node+n_edge]+node_net["neurons"]+[n_node])
+ self.mlp_node = ResNet(**node_net, dtype=dtype, device=device)
+
+ self.node_layer_norm = torch.nn.LayerNorm(n_node, elementwise_affine=True)
+
+ self.device = device
+ self.dtype = dtype
+
+ def forward(self, edge_index, env_index, node_emb, env_emb, edge_emb):
+
+ z_ik = torch.cat([node_emb[env_index[0]], node_emb[env_index[1]], env_emb], dim=-1)
+ node_emb = node_emb + self.propagate(env_index, z_ik=z_ik)
+
+ env_emb = self.mlp_edge(torch.cat([node_emb[env_index[0]], env_emb, node_emb[env_index[1]]], dim=-1))
+ edge_emb = self.mlp_edge(torch.cat([node_emb[edge_index[0]], edge_emb, node_emb[edge_index[1]]], dim=-1))
+
+ return node_emb, env_emb, edge_emb
+
+ def message(self, z_ik):
+
+ return self.mlp_node(z_ik)
+
+ def update(self, aggr_out):
+ """_summary_
+
+ Parameters
+ ----------
+ aggr_out : The output of the aggregation layer, which is the mean of the message vectors as size [N, D, 3]
+
+ Returns
+ -------
+ _type_
+ _description_
+ """
+
+ aggr_out = aggr_out.reshape(aggr_out.shape[0], -1)
+ return self.node_layer_norm(aggr_out) # [N, D*D]
+
+class CGConvLayer(MessagePassing):
+ def __init__(
+ self,
+ r_max:Union[float, torch.Tensor],
+ p:Union[int, torch.LongTensor],
+ n_edge: int,
+ n_node: int,
+ aggr="add",
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ if_exp: bool=False,
+ **kwargs):
+
+ super(CGConvLayer, self).__init__(aggr=aggr, **kwargs)
+
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+
+ if isinstance(r_max, float):
+ self.r_max = torch.tensor(r_max, dtype=dtype, device=device)
+ else:
+ self.r_max = r_max
+
+ self.p = p
+ self.if_exp = if_exp
+
+ self.lin_edge_f = Linear(2*n_node+n_edge, n_edge, device=device, dtype=dtype)
+ self.lin_edge_s = Linear(2*n_node+n_edge, n_edge, device=device, dtype=dtype)
+ self.lin_node_f = Linear(2*n_node+n_edge, n_node, dtype=dtype, device=device)
+ self.lin_node_s = Linear(2*n_node+n_edge, n_node, dtype=dtype, device=device)
+
+ self.node_layer_norm = torch.nn.LayerNorm(n_node, elementwise_affine=True)
+
+ self.device = device
+ self.dtype = dtype
+
+ def forward(self, edge_index, env_index, node_emb, env_emb, edge_emb, env_length, edge_length):
+ z_ik = torch.cat([node_emb[env_index[0]], node_emb[env_index[1]], env_emb], dim=-1)
+ node_emb = node_emb + self.propagate(env_index, z_ik=z_ik, env_length=env_length)
+
+ env_feature_in = torch.cat([node_emb[env_index[0]], env_emb, node_emb[env_index[1]]], dim=-1)
+ env_emb = self.lin_edge_f(env_feature_in).sigmoid() * \
+ F.softplus(self.lin_edge_s(env_feature_in))
+ if self.if_exp:
+ sigma = 3
+ n = 2
+ env_emb = env_emb * torch.exp(-env_length ** n / sigma ** n / 2).view(-1, 1)
+
+ edge_feature_in = torch.cat([node_emb[edge_index[0]], edge_emb, node_emb[edge_index[1]]], dim=-1)
+ edge_emb = self.lin_edge_f(edge_feature_in).sigmoid() * \
+ F.softplus(self.lin_edge_s(edge_feature_in))
+ if self.if_exp:
+ sigma = 3
+ n = 2
+ edge_emb = edge_emb * torch.exp(-edge_length ** n / sigma ** n / 2).view(-1, 1)
+
+ return node_emb, env_emb, edge_emb
+
+
+ def message(self, z_ik, env_length) -> torch.Tensor:
+ out = self.lin_node_f(z_ik).sigmoid() * F.softplus(self.lin_node_s(z_ik))
+ if self.if_exp:
+ sigma = 3
+ n = 2
+ out = out * torch.exp(-env_length ** n / sigma ** n / 2).view(-1, 1)
+ return self.node_layer_norm(out)
+
+
+
+# class CGConv(MessagePassing):
+# def __init__(self, channels: Union[int, Tuple[int, int]], dim: int = 0,
+# aggr: str = 'add', normalization: str = None,
+# bias: bool = True, if_exp: bool = False, **kwargs):
+# super(CGConv, self).__init__(aggr=aggr, flow="source_to_target", **kwargs)
+# self.channels = channels
+# self.dim = dim
+# self.normalization = normalization
+# self.if_exp = if_exp
+
+# if isinstance(channels, int):
+# channels = (channels, channels)
+
+# self.lin_f = nn.Linear(sum(channels) + dim, channels[1], bias=bias)
+# self.lin_s = nn.Linear(sum(channels) + dim, channels[1], bias=bias)
+# if self.normalization == 'BatchNorm':
+# self.bn = nn.BatchNorm1d(channels[1], track_running_stats=True)
+# elif self.normalization == 'LayerNorm':
+# self.ln = LayerNorm(channels[1])
+# elif self.normalization == 'PairNorm':
+# self.pn = PairNorm(channels[1])
+# elif self.normalization == 'InstanceNorm':
+# self.instance_norm = InstanceNorm(channels[1])
+# elif self.normalization is None:
+# pass
+# else:
+# raise ValueError('Unknown normalization function: {}'.format(normalization))
+
+# self.reset_parameters()
+
+# def reset_parameters(self):
+# self.lin_f.reset_parameters()
+# self.lin_s.reset_parameters()
+# if self.normalization == 'BatchNorm':
+# self.bn.reset_parameters()
+
+# def forward(self, x: Union[torch.Tensor, PairTensor], edge_index: Adj,
+# edge_attr: OptTensor, env_index, env_attr, batch, distance, size: Size = None) -> torch.Tensor:
+# """"""
+# if isinstance(x, torch.Tensor):
+# x: PairTensor = (x, x)
+
+# # propagate_type: (x: PairTensor, edge_attr: OptTensor)
+# out = self.propagate(edge_index, x=x, edge_attr=edge_attr, distance=distance, size=size)
+# if self.normalization == 'BatchNorm':
+# out = self.bn(out)
+# elif self.normalization == 'LayerNorm':
+# out = self.ln(out, batch)
+# elif self.normalization == 'PairNorm':
+# out = self.pn(out, batch)
+# elif self.normalization == 'InstanceNorm':
+# out = self.instance_norm(out, batch)
+# out += x[1]
+# return out
+
+# def message(self, x_i, x_j, edge_attr: OptTensor, distance) -> torch.Tensor:
+# z = torch.cat([x_i, x_j, edge_attr], dim=-1)
+# out = self.lin_f(z).sigmoid() * F.softplus(self.lin_s(z))
+# if self.if_exp:
+# sigma = 3
+# n = 2
+# out = out * torch.exp(-distance ** n / sigma ** n / 2).view(-1, 1)
+# return out
+
+# def __repr__(self):
+# return '{}({}, dim={})'.format(self.__class__.__name__, self.channels, self.dim)
\ No newline at end of file
diff --git a/dptb/nn/embedding/se2.py b/dptb/nn/embedding/se2.py
new file mode 100644
index 00000000..6616bb2c
--- /dev/null
+++ b/dptb/nn/embedding/se2.py
@@ -0,0 +1,199 @@
+from torch_geometric.nn import MessagePassing
+from torch_geometric.nn import Aggregation
+import torch
+from typing import Optional, Tuple, Union
+from dptb.data import AtomicDataDict
+from dptb.nn.embedding.emb import Embedding
+from ..base import ResNet
+from dptb.utils.constants import dtype_dict
+from ..type_encode.one_hot import OneHotAtomEncoding
+
+def get_neuron_config(nl):
+ n = len(nl)
+ if n % 2 == 0:
+ d_out = nl[-1]
+ nl = nl[:-1]
+ config = []
+ for i in range(1,len(nl)-1, 2):
+ config.append({'in_features': nl[i-1], 'hidden_features': nl[i], 'out_features': nl[i+1]})
+
+ if n % 2 == 0:
+ config.append({'in_features': nl[-1], 'out_features': d_out})
+
+ return config
+
+@Embedding.register("se2")
+class SE2Descriptor(torch.nn.Module):
+ def __init__(
+ self,
+ rs: Union[float, torch.Tensor],
+ rc:Union[float, torch.Tensor],
+ n_axis: Union[int, torch.LongTensor, None]=None,
+ n_atom: int=1,
+ radial_net: dict={},
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ) -> None:
+ """
+ a demo input
+ se2_config = {
+ "rs": 3.0,
+ "rc": 4.0,
+ "n_axis": 4,
+ "n_atom": 2,
+ "radial_embedding": {
+ "neurons": [10,20,30],
+ "activation": "tanh",
+ "if_batch_normalized": False
+ },
+ "dtype": "float32",
+ "device": "cpu"
+ }
+ """
+
+ super(SE2Descriptor, self).__init__()
+ self.onehot = OneHotAtomEncoding(num_types=n_atom, set_features=False)
+ self.descriptor = _SE2Descriptor(rs=rs, rc=rc, n_atom=n_atom, radial_net=radial_net, n_axis=n_axis, dtype=dtype, device=device)
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ """_summary_
+
+ Parameters
+ ----------
+ data : _type_
+ _description_
+
+ Returns
+ -------
+ _type_
+ _description_
+ """
+ data = self.onehot(data)
+ data = AtomicDataDict.with_env_vectors(data, with_lengths=True)
+ data = AtomicDataDict.with_edge_vectors(data, with_lengths=True)
+
+ data[AtomicDataDict.NODE_FEATURES_KEY], data[AtomicDataDict.EDGE_FEATURES_KEY] = self.descriptor(
+ data[AtomicDataDict.ENV_VECTORS_KEY],
+ data[AtomicDataDict.NODE_ATTRS_KEY],
+ data[AtomicDataDict.ENV_INDEX_KEY],
+ data[AtomicDataDict.EDGE_INDEX_KEY],
+ data[AtomicDataDict.EDGE_LENGTH_KEY],
+ )
+
+ return data
+
+ @property
+ def out_edge_dim(self):
+ return self.descriptor.n_out + 1
+
+ @property
+ def out_node_dim(self):
+ return self.descriptor.n_out
+
+
+
+
+class SE2Aggregation(Aggregation):
+ def forward(self, x: torch.Tensor, index: torch.LongTensor, **kwargs):
+ """_summary_
+
+ Parameters
+ ----------
+ x : tensor of size (N, d), where d dimension looks like (emb(s(r)), \hat{x}, \hat{y}, \hat{z})
+ The is the embedding of the env_vectors
+ index : _type_
+ _description_
+
+ Returns
+ -------
+ _type_
+ _description_
+ """
+ direct_vec = x[:, -3:]
+ x = x[:,:-3].unsqueeze(-1) * direct_vec.unsqueeze(1) # [N_env, D, 3]
+ return self.reduce(x, index, reduce="mean", dim=0) # [N_atom, D, 3] following the orders of atom index.
+
+
+class _SE2Descriptor(MessagePassing):
+ def __init__(
+ self,
+ rs: Union[float, torch.Tensor],
+ rc:Union[float, torch.Tensor],
+ n_axis: Union[int, torch.LongTensor, None]=None,
+ aggr: SE2Aggregation=SE2Aggregation(),
+ radial_net: dict={},
+ n_atom: int=1,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"), **kwargs):
+
+ super(_SE2Descriptor, self).__init__(aggr=aggr, **kwargs)
+
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+
+
+ radial_net["config"] = get_neuron_config([2*n_atom+1]+radial_net["neurons"])
+ self.embedding_net = ResNet(**radial_net, device=device, dtype=dtype)
+ if isinstance(rs, float):
+ self.rs = torch.tensor(rs, dtype=dtype, device=device)
+ else:
+ self.rs = rs
+ if isinstance(rc, float):
+ self.rc = torch.tensor(rc, dtype=dtype, device=device)
+ else:
+ self.rc = rc
+
+ assert len(self.rc.flatten()) == 1 and len(self.rs.flatten()) == 1
+ assert self.rs < self.rc
+ self.n_axis = n_axis
+ self.device = device
+ self.dtype = dtype
+ if n_axis == None:
+ self.n_axis = radial_net["neurons"][-1]
+ self.n_out = self.n_axis * radial_net["neurons"][-1]
+
+ def forward(self, env_vectors, atom_attr, env_index, edge_index, edge_length):
+ n_env = env_index.shape[1]
+ env_attr = atom_attr[env_index].transpose(1,0).reshape(n_env,-1)
+ out_node = self.propagate(env_index, env_vectors=env_vectors, env_attr=env_attr) # [N_atom, D, 3]
+ out_edge = self.edge_updater(edge_index, node_descriptor=out_node, edge_length=edge_length) # [N_edge, D*D]
+
+ return out_node, out_edge
+
+ def message(self, env_vectors, env_attr):
+ rij = env_vectors.norm(dim=-1, keepdim=True)
+ snorm = self.smooth(rij, self.rs, self.rc)
+ env_vectors = snorm * env_vectors / rij
+ return torch.cat([self.embedding_net(torch.cat([snorm, env_attr], dim=-1)), env_vectors], dim=-1) # [N_env, D_emb + 3]
+
+ def update(self, aggr_out):
+ """_summary_
+
+ Parameters
+ ----------
+ aggr_out : The output of the aggregation layer, which is the mean of the message vectors as size [N, D, 3]
+
+ Returns
+ -------
+ _type_
+ _description_
+ """
+ out = torch.bmm(aggr_out, aggr_out.transpose(1, 2))[:,:,:self.n_axis].flatten(start_dim=1, end_dim=2)
+ out = out - out.mean(1, keepdim=True)
+ out = out / out.norm(dim=1, keepdim=True)
+ return out # [N, D*D]
+
+ def edge_update(self, edge_index, node_descriptor, edge_length):
+ return torch.cat([node_descriptor[edge_index[0]] + node_descriptor[edge_index[1]], 1/edge_length.reshape(-1,1)], dim=-1) # [N_edge, D*D]
+
+ def smooth(self, r: torch.Tensor, rs: torch.Tensor, rc: torch.Tensor):
+ r_ = torch.zeros_like(r)
+ r_[r AtomicDataDict.Type:
+ data = self.h2k(data)
+ if self.overlap:
+ data = self.s2k(data)
+ chklowt = torch.linalg.cholesky(data[self.s_out_field])
+ chklowtinv = torch.linalg.inv(chklowt)
+ Heff = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj())
+ else:
+ Heff = data[self.h_out_field]
+
+ data[self.out_field] = torch.linalg.eigvalsh(Heff)
+
+ return data
diff --git a/dptb/nn/graph_mixin.py b/dptb/nn/graph_mixin.py
new file mode 100644
index 00000000..253cba26
--- /dev/null
+++ b/dptb/nn/graph_mixin.py
@@ -0,0 +1,119 @@
+import random
+from typing import Dict, Tuple, Callable, Any, Sequence, Union, Mapping, Optional
+from collections import OrderedDict
+
+import torch
+
+from e3nn import o3
+
+from dptb.data import AtomicDataDict
+from dptb.utils import instantiate
+
+
+class GraphModuleMixin:
+ r"""Mixin parent class for ``torch.nn.Module``s that act on and return ``AtomicDataDict.Type`` graph data.
+
+ All such classes should call ``_init_irreps`` in their ``__init__`` functions with information on the data fields they expect, require, and produce, as well as their corresponding irreps.
+ """
+
+ def _init_irreps(
+ self,
+ irreps_in: Dict[str, Any] = {},
+ my_irreps_in: Dict[str, Any] = {},
+ required_irreps_in: Sequence[str] = [],
+ irreps_out: Dict[str, Any] = {},
+ ):
+ """Setup the expected data fields and their irreps for this graph module.
+
+ ``None`` is a valid irreps in the context for anything that is invariant but not well described by an ``e3nn.o3.Irreps``. An example are edge indexes in a graph, which are invariant but are integers, not ``0e`` scalars.
+
+ Args:
+ irreps_in (dict): maps names of all input fields from previous modules or
+ data to their corresponding irreps
+ my_irreps_in (dict): maps names of fields to the irreps they must have for
+ this graph module. Will be checked for consistancy with ``irreps_in``
+ required_irreps_in: sequence of names of fields that must be present in
+ ``irreps_in``, but that can have any irreps.
+ irreps_out (dict): mapping names of fields that are modified/output by
+ this graph module to their irreps.
+ """
+ # Coerce
+ irreps_in = {} if irreps_in is None else irreps_in
+ irreps_in = AtomicDataDict._fix_irreps_dict(irreps_in)
+ # positions are *always* 1o, and always present
+ if AtomicDataDict.POSITIONS_KEY in irreps_in:
+ if irreps_in[AtomicDataDict.POSITIONS_KEY] != o3.Irreps("1x1o"):
+ raise ValueError(
+ f"Positions must have irreps 1o, got instead `{irreps_in[AtomicDataDict.POSITIONS_KEY]}`"
+ )
+ irreps_in[AtomicDataDict.POSITIONS_KEY] = o3.Irreps("1o")
+ # edges are also always present
+ if AtomicDataDict.EDGE_INDEX_KEY in irreps_in:
+ if irreps_in[AtomicDataDict.EDGE_INDEX_KEY] is not None:
+ raise ValueError(
+ f"Edge indexes must have irreps None, got instead `{irreps_in[AtomicDataDict.EDGE_INDEX_KEY]}`"
+ )
+ irreps_in[AtomicDataDict.EDGE_INDEX_KEY] = None
+
+ my_irreps_in = AtomicDataDict._fix_irreps_dict(my_irreps_in)
+
+ irreps_out = AtomicDataDict._fix_irreps_dict(irreps_out)
+ # Confirm compatibility:
+ # with my_irreps_in
+ for k in my_irreps_in:
+ if k in irreps_in and irreps_in[k] != my_irreps_in[k]:
+ raise ValueError(
+ f"The given input irreps {irreps_in[k]} for field '{k}' is incompatible with this configuration {type(self)}; should have been {my_irreps_in[k]}"
+ )
+ # with required_irreps_in
+ for k in required_irreps_in:
+ if k not in irreps_in:
+ raise ValueError(
+ f"This {type(self)} requires field '{k}' to be in irreps_in"
+ )
+ # Save stuff
+ self.irreps_in = irreps_in
+ # The output irreps of any graph module are whatever inputs it has, overwritten with whatever outputs it has.
+ new_out = irreps_in.copy()
+ new_out.update(irreps_out)
+ self.irreps_out = new_out
+
+ def _add_independent_irreps(self, irreps: Dict[str, Any]):
+ """
+ Insert some independent irreps that need to be exposed to the self.irreps_in and self.irreps_out.
+ The terms that have already appeared in the irreps_in will be removed.
+
+ Args:
+ irreps (dict): maps names of all new fields
+ """
+
+ irreps = {
+ key: irrep for key, irrep in irreps.items() if key not in self.irreps_in
+ }
+ irreps_in = AtomicDataDict._fix_irreps_dict(irreps)
+ irreps_out = AtomicDataDict._fix_irreps_dict(
+ {key: irrep for key, irrep in irreps.items() if key not in self.irreps_out}
+ )
+ self.irreps_in.update(irreps_in)
+ self.irreps_out.update(irreps_out)
+
+ def _make_tracing_inputs(self, n):
+ # We impliment this to be able to trace graph modules
+ out = []
+ for _ in range(n):
+ batch = random.randint(1, 4)
+ # TODO: handle None case
+ # TODO: do only required inputs
+ # TODO: dummy input if empty?
+ out.append(
+ {
+ "forward": (
+ {
+ k: i.randn(batch, -1)
+ for k, i in self.irreps_in.items()
+ if i is not None
+ },
+ )
+ }
+ )
+ return out
\ No newline at end of file
diff --git a/dptb/nn/hamiltonian.py b/dptb/nn/hamiltonian.py
new file mode 100644
index 00000000..8693a22b
--- /dev/null
+++ b/dptb/nn/hamiltonian.py
@@ -0,0 +1,394 @@
+"""
+This file refactor the SK and E3 Rotation in dptb/hamiltonian/transform_se3.py], it will take input of AtomicDataDict.Type
+perform rotation from irreducible matrix element / sk parameters in EDGE/NODE FEATURE, and output the atomwise/ pairwise hamiltonian
+as the new EDGE/NODE FEATURE. The rotation should also be a GNN module and speed uptable by JIT. The HR2HK should also be included here.
+The indexmapping should ne passed here.
+"""
+
+import torch
+from e3nn.o3 import wigner_3j, Irrep, xyz_to_angles, Irrep
+from dptb.utils.constants import h_all_types, anglrMId
+from typing import Tuple, Union, Dict
+from dptb.data.transforms import OrbitalMapper
+from dptb.data import AtomicDataDict
+import re
+from torch_runstats.scatter import scatter
+
+#TODO: 1. jit acceleration 2. GPU support 3. rotate AB and BA bond together.
+
+# The `E3Hamiltonian` class is a PyTorch module that represents a Hamiltonian for a system with a
+# given basis and can perform forward computations on input data.
+
+class E3Hamiltonian(torch.nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ decompose: bool = False,
+ edge_field: str = AtomicDataDict.EDGE_FEATURES_KEY,
+ node_field: str = AtomicDataDict.NODE_FEATURES_KEY,
+ overlap: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ) -> None:
+
+ super(E3Hamiltonian, self).__init__()
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.overlap = overlap
+ self.dtype = dtype
+ self.device = device
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb")
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ self.basis = self.idp.basis
+ self.cgbasis = {}
+ self.decompose = decompose
+ self.edge_field = edge_field
+ self.node_field = node_field
+
+ # initialize the CG basis
+ self.idp.get_orbpairtype_maps()
+ orbpairtypes = self.idp.orbpairtype_maps.keys()
+ for orbpair in orbpairtypes:
+ self._initialize_CG_basis(orbpair)
+
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ """
+ The forward function takes in atomic data and performs computations on the edge and node features
+ based on the decompose flag. It performs the following operations:
+ decompose = True:
+ - the function will read the EDGE and NODE features and take them as hamiltonian blocks, the
+ block will be decomposed into reduced matrix element that is irrelevant to the direction.
+ decompose = False:
+ - the function will read the EDGE and NODE features and take them as reduced matrix element, the
+ function will transform the reduced matrix element into hamiltonian blocks with directional dependence.
+
+ :param data: The `data` parameter is a dictionary that contains atomic data. It has the following
+ keys:
+ :type data: AtomicDataDict.Type
+ :return: the updated `data` dictionary.
+ """
+
+ assert data[self.edge_field].shape[1] == self.idp.reduced_matrix_element
+ if not self.overlap:
+ assert data[self.node_field].shape[1] == self.idp.reduced_matrix_element
+
+ n_edge = data[AtomicDataDict.EDGE_INDEX_KEY].shape[1]
+ n_node = data[AtomicDataDict.NODE_FEATURES_KEY].shape[0]
+
+ data = AtomicDataDict.with_edge_vectors(data, with_lengths=True)
+
+ if not self.decompose:
+ # The EDGE_FEATURES_KEY and NODE_FAETURE_KEY are the reduced matrix elements
+
+ # compute hopping blocks
+ for opairtype in self.idp.orbpairtype_maps.keys():
+ # currently, "a-b" and "b-a" orbital pair are computed seperately, it is able to combined further
+ # for better performance
+ l1, l2 = anglrMId[opairtype[0]], anglrMId[opairtype[2]]
+ n_rme = (2*l1+1) * (2*l2+1) # number of reduced matrix element
+ rme = data[self.edge_field][:, self.idp.orbpairtype_maps[opairtype]]
+ rme = rme.reshape(n_edge, -1, n_rme)
+ rme = rme.transpose(1,2) # shape (N, n_rme, n_pair)
+
+ HR = torch.sum(self.cgbasis[opairtype][None,:,:,:,None] * \
+ rme[:,None, None, :, :], dim=-2) # shape (N, 2l1+1, 2l2+1, n_pair)
+
+ # rotation
+ # angle = xyz_to_angles(data[AtomicDataDict.EDGE_VECTORS_KEY][:,[1,2,0]]) # (tensor(N), tensor(N))
+ # rot_mat_L = Irrep(int(l1), 1).D_from_angles(angle[0], angle[1], torch.tensor(0., dtype=self.dtype, device=self.device)) # tensor(N, 2l1+1, 2l1+1)
+ # rot_mat_R = Irrep(int(l2), 1).D_from_angles(angle[0], angle[1], torch.tensor(0., dtype=self.dtype, device=self.device)) # tensor(N, 2l2+1, 2l2+1)
+ # HR = torch.einsum("nlm, nmoq, nko -> nqlk", rot_mat_L, H_z, rot_mat_R).reshape(n_edge, -1) # shape (N, n_pair * n_rme)
+ HR = HR.permute(0,3,1,2).reshape(n_edge, -1)
+ data[self.edge_field][:, self.idp.orbpairtype_maps[opairtype]] = HR
+
+ # compute onsite blocks
+ if not self.overlap:
+ for opairtype in self.idp.orbpairtype_maps.keys():
+ # currently, "a-b" and "b-a" orbital pair are computed seperately, it is able to combined further
+ # for better performance
+ l1, l2 = anglrMId[opairtype[0]], anglrMId[opairtype[2]]
+
+ n_rme = (2*l1+1) * (2*l2+1) # number of reduced matrix element
+ rme = data[self.node_field][:, self.idp.orbpairtype_maps[opairtype]]
+ rme = rme.reshape(n_node, -1, n_rme)
+ rme = rme.transpose(1,2) # shape (N, n_rme, n_pair)
+
+ HR = torch.sum(self.cgbasis[opairtype][None,:,:,:,None] * \
+ rme[:,None, None, :, :], dim=-2) # shape (N, 2l1+1, 2l2+1, n_pair)
+ HR = HR.permute(0,3,1,2).reshape(n_node, -1)
+
+ # the onsite block does not have rotation
+ data[self.node_field][:, self.idp.orbpairtype_maps[opairtype]] = HR
+
+ else:
+ for opairtype in self.idp.orbpairtype_maps.keys():
+ l1, l2 = anglrMId[opairtype[0]], anglrMId[opairtype[2]]
+ nL, nR = 2*l1+1, 2*l2+1
+ HR = data[self.edge_field][:, self.idp.orbpairtype_maps[opairtype]]
+ HR = HR.reshape(n_edge, -1, nL, nR) # shape (N, n_pair, nL, nR)
+
+ # rotation
+ # angle = xyz_to_angles(data[AtomicDataDict.EDGE_VECTORS_KEY][:,[1,2,0]]) # (tensor(N), tensor(N))
+ # rot_mat_L = Irrep(int(l1), 1).D_from_angles(angle[0], angle[1], torch.tensor(0., dtype=self.dtype, device=self.device)) # tensor(N, 2l1+1, 2l1+1)
+ # rot_mat_R = Irrep(int(l2), 1).D_from_angles(angle[0], angle[1], torch.tensor(0., dtype=self.dtype, device=self.device)) # tensor(N, 2l2+1, 2l2+1)
+ # H_z = torch.einsum("nml, nqmo, nok -> nlkq", rot_mat_L, HR, rot_mat_R) # shape (N, nL, nR, n_pair)
+
+ HR = HR.permute(0,2,3,1) # shape (N, nL, nR, n_pair)
+ rme = torch.sum(self.cgbasis[opairtype][None,:,:,:,None] * \
+ HR[:,:,:,None,:], dim=(1,2)) # shape (N, n_rme, n_pair)
+ rme = rme.transpose(1,2).reshape(n_edge, -1)
+
+ data[self.edge_field][:, self.idp.orbpairtype_maps[opairtype]] = rme
+
+ if not self.overlap:
+ for opairtype in self.idp.orbpairtype_maps.keys():
+ # currently, "a-b" and "b-a" orbital pair are computed seperately, it is able to combined further
+ # for better performance
+ l1, l2 = anglrMId[opairtype[0]], anglrMId[opairtype[2]]
+ nL, nR = 2*l1+1, 2*l2+1 # number of reduced matrix element
+ HR = data[self.node_field][:, self.idp.orbpairtype_maps[opairtype]]
+ HR = HR.reshape(n_node, -1, nL, nR).permute(0,2,3,1)# shape (N, nL, nR, n_pair)
+
+ rme = torch.sum(self.cgbasis[opairtype][None,:,:,:,None] * \
+ HR[:,:,:,None,:], dim=(1,2)) # shape (N, n_rme, n_pair)
+ rme = rme.transpose(1,2).reshape(n_node, -1)
+
+ # the onsite block doesnot have rotation
+ data[self.node_field][:, self.idp.orbpairtype_maps[opairtype]] = rme
+
+ return data
+
+ def _initialize_CG_basis(self, pairtype: str):
+ """
+ The function initializes a Clebsch-Gordan basis for a given pair type.
+
+ :param pairtype: The parameter "pairtype" is a string that represents a pair of angular momentum
+ quantum numbers. It is expected to have a length of 3, where the first and third characters
+ represent the angular momentum quantum numbers of two particles, and the second character
+ represents the type of interaction between the particles
+ :type pairtype: str
+ :return: the CG basis, which is a tensor containing the Clebsch-Gordan coefficients for the given
+ pairtype.
+ """
+ self.cgbasis.setdefault(pairtype, None)
+
+ l1, l2 = anglrMId[pairtype[0]], anglrMId[pairtype[2]]
+
+ cg = []
+ for l_ird in range(abs(l2-l1), l2+l1+1):
+ cg.append(wigner_3j(int(l1), int(l2), int(l_ird), dtype=self.dtype, device=self.device) * (2*l_ird+1)**0.5)
+
+ cg = torch.cat(cg, dim=-1)
+ self.cgbasis[pairtype] = cg
+
+ return cg
+
+
+class SKHamiltonian(torch.nn.Module):
+ # transform SK parameters to SK hamiltonian with E3 CG basis, strain is included.
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp_sk: Union[OrbitalMapper, None]=None,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ edge_field: str = AtomicDataDict.EDGE_FEATURES_KEY,
+ node_field: str = AtomicDataDict.NODE_FEATURES_KEY,
+ onsite: bool = False,
+ strain: bool = False,
+ **kwargs,
+ ) -> None:
+ super(SKHamiltonian, self).__init__()
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.dtype = dtype
+ self.device = device
+ self.onsite = onsite
+
+ if basis is not None:
+ self.idp_sk = OrbitalMapper(basis, method="sktb", device=device)
+ if idp_sk is not None:
+ assert idp_sk.basis == self.idp_sk.basis, "The basis of idp and basis should be the same."
+ else:
+ assert idp_sk is not None, "Either basis or idp should be provided."
+ self.idp_sk = idp_sk
+ # initilize a e3 indexmapping to help putting the orbital wise blocks into atom-pair wise format
+ self.idp = OrbitalMapper(self.idp_sk.basis, method="e3tb", device=device)
+ self.basis = self.idp.basis
+ self.cgbasis = {}
+ self.strain = strain
+ self.edge_field = edge_field
+ self.node_field = node_field
+
+ self.idp_sk.get_orbpair_maps()
+ self.idp_sk.get_skonsite_maps()
+ self.idp.get_orbpair_maps()
+
+ pairtypes = self.idp_sk.orbpairtype_maps.keys()
+ for pairtype in pairtypes:
+ self._initialize_CG_basis(pairtype)
+
+ self.sk2irs = {
+ 's-s': torch.tensor([[1.]], dtype=self.dtype, device=self.device),
+ 's-p': torch.tensor([[1.]], dtype=self.dtype, device=self.device),
+ 's-d': torch.tensor([[1.]], dtype=self.dtype, device=self.device),
+ 'p-s': torch.tensor([[1.]], dtype=self.dtype, device=self.device),
+ 'p-p': torch.tensor([
+ [3**0.5/3,2/3*3**0.5],[6**0.5/3,-6**0.5/3]
+ ], dtype=self.dtype, device=self.device
+ ),
+ 'p-d':torch.tensor([
+ [(2/5)**0.5,(6/5)**0.5],[(3/5)**0.5,-2/5**0.5]
+ ], dtype=self.dtype, device=self.device
+ ),
+ 'd-s':torch.tensor([[1.]], dtype=self.dtype, device=self.device),
+ 'd-p':torch.tensor([
+ [(2/5)**0.5,(6/5)**0.5],
+ [(3/5)**0.5,-2/5**0.5]
+ ], dtype=self.dtype, device=self.device
+ ),
+ 'd-d':torch.tensor([
+ [5**0.5/5, 2*5**0.5/5, 2*5**0.5/5],
+ [2*(1/14)**0.5,2*(1/14)**0.5,-4*(1/14)**0.5],
+ [3*(2/35)**0.5,-4*(2/35)**0.5,(2/35)**0.5]
+ ], dtype=self.dtype, device=self.device
+ )
+ }
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ # transform sk parameters to irreducible matrix element
+
+ assert data[self.edge_field].shape[1] == self.idp_sk.reduced_matrix_element
+ if self.onsite:
+ assert data[self.node_field].shape[1] == self.idp_sk.n_onsite_Es
+ n_node = data[self.node_field].shape[0]
+
+ n_edge = data[self.edge_field].shape[0]
+
+
+ edge_features = data[self.edge_field].clone()
+ data[self.edge_field] = torch.zeros((n_edge, self.idp.reduced_matrix_element), dtype=self.dtype, device=self.device)
+
+ # for hopping blocks
+ for opairtype in self.idp_sk.orbpairtype_maps.keys():
+ l1, l2 = anglrMId[opairtype[0]], anglrMId[opairtype[2]]
+ n_skp = min(l1, l2)+1 # number of reduced matrix element
+ skparam = edge_features[:, self.idp_sk.orbpairtype_maps[opairtype]].reshape(n_edge, -1, n_skp)
+ rme = skparam @ self.sk2irs[opairtype].T # shape (N, n_pair, n_rme)
+ rme = rme.transpose(1,2) # shape (N, n_rme, n_pair)
+
+ H_z = torch.sum(self.cgbasis[opairtype][None,:,:,:,None] * \
+ rme[:,None, None, :, :], dim=-2) # shape (N, 2l1+1, 2l2+1, n_pair)
+
+ # rotation
+ # when get the angle, the xyz vector should be transformed to yzx.
+ angle = xyz_to_angles(data[AtomicDataDict.EDGE_VECTORS_KEY][:,[1,2,0]]) # (tensor(N), tensor(N))
+ # The roataion matrix is SO3 rotation, therefore Irreps(l,1), is used here.
+ rot_mat_L = Irrep(int(l1), 1).D_from_angles(angle[0].cpu(), angle[1].cpu(), torch.tensor(0., dtype=self.dtype)).to(self.device) # tensor(N, 2l1+1, 2l1+1)
+ rot_mat_R = Irrep(int(l2), 1).D_from_angles(angle[0].cpu(), angle[1].cpu(), torch.tensor(0., dtype=self.dtype)).to(self.device) # tensor(N, 2l2+1, 2l2+1)
+
+ # Here The string to control einsum is important, the order of the index should be the same as the order of the tensor
+ # H_z = torch.einsum("nlm, nmoq, nko -> nqlk", rot_mat_L, H_z, rot_mat_R) # shape (N, n_pair, 2l1+1, 2l2+1)
+ HR = torch.einsum("nlm, nmoq, nko -> nqlk", rot_mat_L, H_z, rot_mat_R).reshape(n_edge, -1) # shape (N, n_pair * 2l2+1 * 2l2+1)
+
+ if l1 < l2:
+ HR = HR * (-1)**(l1+l2)
+
+ data[self.edge_field][:, self.idp.orbpairtype_maps[opairtype]] = HR
+
+ # compute onsite blocks
+ if self.onsite:
+ node_feature = data[self.node_field].clone()
+ data[self.node_field] = torch.zeros(n_node, self.idp.reduced_matrix_element, dtype=self.dtype, device=self.device)
+
+ for otype in self.idp_sk.skonsite_maps.keys():
+ # currently, "a-b" and "b-a" orbital pair are computed seperately, it is able to combined further
+ # for better performance
+
+ l = anglrMId[re.findall(r"[a-z]", otype)[0]]
+
+ skparam = node_feature[:, self.idp_sk.skonsite_maps[otype]].reshape(n_node, -1, 1)
+ HR = torch.eye(2*l+1, dtype=self.dtype, device=self.device)[None, None, :, :] * skparam[:,:, None, :] # shape (N, n_pair, 2l1+1, 2l2+1)
+ # the onsite block doesnot have rotation
+
+ data[self.node_field][:, self.idp.orbpair_maps[otype+"-"+otype]] = HR.reshape(n_node, -1)
+
+ # compute if strain effect is included
+ # this is a little wired operation, since it acting on somekind of a edge(strain env) feature, and summed up to return a node feature.
+ if self.strain:
+ n_onsitenv = len(data[AtomicDataDict.ONSITENV_FEATURES_KEY])
+ for opairtype in self.idp.orbpairtype_maps.keys(): # save all env direction and pair direction like sp and ps, but only get sp
+ l1, l2 = anglrMId[opairtype[0]], anglrMId[opairtype[2]]
+ # opairtype = opair[1]+"-"+opair[4]
+ n_skp = min(l1, l2)+1 # number of reduced matrix element
+ skparam = data[AtomicDataDict.ONSITENV_FEATURES_KEY][:, self.idp_sk.orbpairtype_maps[opairtype]].reshape(n_onsitenv, -1, n_skp)
+ rme = skparam @ self.sk2irs[opairtype].T # shape (N, n_pair, n_rme)
+ rme = rme.transpose(1,2) # shape (N, n_rme, n_pair)
+
+ H_z = torch.sum(self.cgbasis[opairtype][None,:,:,:,None] * \
+ rme[:,None, None, :, :], dim=-2) # shape (N, 2l1+1, 2l2+1, n_pair)
+
+ angle = xyz_to_angles(data[AtomicDataDict.ONSITENV_VECTORS_KEY][:,[1,2,0]]) # (tensor(N), tensor(N))
+ rot_mat_L = Irrep(int(l1), 1).D_from_angles(angle[0].cpu(), angle[1].cpu(), torch.tensor(0., dtype=self.dtype)).to(self.device) # tensor(N, 2l1+1, 2l1+1)
+ rot_mat_R = Irrep(int(l2), 1).D_from_angles(angle[0].cpu(), angle[1].cpu(), torch.tensor(0., dtype=self.dtype)).to(self.device) # tensor(N, 2l2+1, 2l2+1)
+
+ HR = torch.einsum("nlm, nmoq, nko -> nqlk", rot_mat_L, H_z, rot_mat_R) # shape (N, n_pair, 2l1+1, 2l2+1)
+
+ HR = scatter(src=HR, index=data[AtomicDataDict.ONSITENV_INDEX_KEY][0], dim=0, reduce="sum") # shape (n_node, n_pair, 2l1+1, 2l2+1)
+ # A-B o1-o2 (A-B o2-o1)= (B-A o1-o2)
+
+ data[self.node_field][:, self.idp.orbpairtype_maps[opairtype]] += HR.flatten(1, len(HR.shape)-1) # the index type [node/pair] should align with the index of for loop
+
+ return data
+
+ def _initialize_CG_basis(self, pairtype: str):
+ """
+ The function initializes a Clebsch-Gordan basis for a given pair type.
+
+ :param pairtype: The parameter "pairtype" is a string that represents a pair of angular momentum
+ quantum numbers. It is expected to have a length of 3, where the first and third characters
+ represent the angular momentum quantum numbers of two particles, and the second character
+ represents the type of interaction between the particles
+ :type pairtype: str
+ :return: the CG basis, which is a tensor containing the Clebsch-Gordan coefficients for the given
+ pairtype.
+ """
+ self.cgbasis.setdefault(pairtype, None)
+
+ irs_index = {
+ 's-s': [0],
+ 's-p': [1],
+ 's-d': [2],
+ 'p-s': [1],
+ 'p-p': [0,6],
+ 'p-d': [1,11],
+ 'd-s': [2],
+ 'd-p': [1,11],
+ 'd-d': [0,6,20]
+ }
+
+ l1, l2 = anglrMId[pairtype[0]], anglrMId[pairtype[2]]
+
+ cg = []
+ for l_ird in range(abs(l2-l1), l2+l1+1):
+ cg.append(wigner_3j(int(l1), int(l2), int(l_ird), dtype=self.dtype, device=self.device) * (2*l_ird+1)**0.5)
+
+ cg = torch.cat(cg, dim=-1)[:,:,irs_index[pairtype]]
+ self.cgbasis[pairtype] = cg
+
+ return cg
\ No newline at end of file
diff --git a/dptb/nn/hr2hk.py b/dptb/nn/hr2hk.py
new file mode 100644
index 00000000..3dfb7ad3
--- /dev/null
+++ b/dptb/nn/hr2hk.py
@@ -0,0 +1,120 @@
+import torch
+from dptb.utils.constants import h_all_types, anglrMId, atomic_num_dict, atomic_num_dict_r
+from typing import Tuple, Union, Dict
+from dptb.data.transforms import OrbitalMapper
+from dptb.data import AtomicDataDict
+import re
+
+class HR2HK(torch.nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ edge_field: str = AtomicDataDict.EDGE_FEATURES_KEY,
+ node_field: str = AtomicDataDict.NODE_FEATURES_KEY,
+ out_field: str = AtomicDataDict.HAMILTONIAN_KEY,
+ overlap: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ):
+ super(HR2HK, self).__init__()
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ self.dtype = dtype
+ self.device = device
+ self.overlap = overlap
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb", device=self.device)
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ assert idp.method == "e3tb", "The method of idp should be e3tb."
+ self.idp = idp
+
+ self.basis = self.idp.basis
+ self.idp.get_orbpair_maps()
+
+ self.edge_field = edge_field
+ self.node_field = node_field
+ self.out_field = out_field
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+
+ # construct bond wise hamiltonian block from obital pair wise node/edge features
+ # we assume the edge feature have the similar format as the node feature, which is reduced from orbitals index oj-oi with j>i
+
+ orbpair_hopping = data[self.edge_field]
+ orbpair_onsite = data.get(self.node_field)
+ bondwise_hopping = torch.zeros((len(orbpair_hopping), self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.dtype, device=self.device)
+ bondwise_hopping.to(self.device)
+ bondwise_hopping.type(self.dtype)
+ onsite_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb,), dtype=self.dtype, device=self.device)
+
+ ist = 0
+ for i,iorb in enumerate(self.idp.full_basis):
+ jst = 0
+ li = anglrMId[re.findall(r"[a-zA-Z]+", iorb)[0]]
+ for j,jorb in enumerate(self.idp.full_basis):
+ orbpair = iorb + "-" + jorb
+ lj = anglrMId[re.findall(r"[a-zA-Z]+", jorb)[0]]
+
+ # constructing hopping blocks
+ if iorb == jorb:
+ factor = 0.5
+ else:
+ factor = 1.0
+
+ if i <= j:
+ bondwise_hopping[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_hopping[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
+
+
+ # constructing onsite blocks
+ if self.overlap:
+ if iorb == jorb:
+ onsite_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = factor * torch.eye(2*li+1, dtype=self.dtype, device=self.device).reshape(1, 2*li+1, 2*lj+1).repeat(onsite_block.shape[0], 1, 1)
+ else:
+ if i <= j:
+ onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
+
+ jst += 2*lj+1
+ ist += 2*li+1
+ self.onsite_block = onsite_block
+ self.bondwise_hopping = bondwise_hopping
+
+
+ # R2K procedure can be done for all kpoint at once.
+ all_norb = self.idp.atom_norb[data[AtomicDataDict.ATOM_TYPE_KEY]].sum()
+ block = torch.zeros(data[AtomicDataDict.KPOINT_KEY].shape[0], all_norb, all_norb, dtype=self.dtype, device=self.device)
+ block = torch.complex(block, torch.zeros_like(block))
+
+ atom_id_to_indices = {}
+ ist = 0
+ for i, oblock in enumerate(onsite_block):
+ mask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[i]]
+ masked_oblock = oblock[mask][:,mask]
+ block[:,ist:ist+masked_oblock.shape[0],ist:ist+masked_oblock.shape[1]] = masked_oblock.squeeze(0)
+ atom_id_to_indices[i] = slice(ist, ist+masked_oblock.shape[0])
+ ist += masked_oblock.shape[0]
+
+ for i, hblock in enumerate(bondwise_hopping):
+ iatom = data[AtomicDataDict.EDGE_INDEX_KEY][0][i]
+ jatom = data[AtomicDataDict.EDGE_INDEX_KEY][1][i]
+ iatom_indices = atom_id_to_indices[int(iatom)]
+ jatom_indices = atom_id_to_indices[int(jatom)]
+ imask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[iatom]]
+ jmask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[jatom]]
+ masked_hblock = hblock[imask][:,jmask]
+
+ block[:,iatom_indices,jatom_indices] += masked_hblock.squeeze(0).type_as(block) * \
+ torch.exp(-1j * 2 * torch.pi * (data[AtomicDataDict.KPOINT_KEY] @ data[AtomicDataDict.EDGE_CELL_SHIFT_KEY][i])).reshape(-1,1,1)
+
+ block = block + block.transpose(1,2).conj()
+ block = block.contiguous()
+
+ data[self.out_field] = block
+
+ return data
+
\ No newline at end of file
diff --git a/dptb/nn/nnsk.py b/dptb/nn/nnsk.py
new file mode 100644
index 00000000..4617ce0f
--- /dev/null
+++ b/dptb/nn/nnsk.py
@@ -0,0 +1,548 @@
+"""The file doing the process from the fitting net output sk formula parameters in node/edge feature to the tight binding two centre integrals parameters in node/edge feature.
+in: Data
+out Data
+
+basically a map from a matrix parameters to edge/node features, or strain mode's environment edge features
+"""
+
+import torch
+from dptb.utils.constants import h_all_types, anglrMId
+from typing import Tuple, Union, Dict
+from dptb.data.transforms import OrbitalMapper
+from dptb.data import AtomicDataDict
+import numpy as np
+import torch.nn as nn
+from .sktb import OnsiteFormula, bond_length_list, HoppingFormula
+from dptb.utils.constants import atomic_num_dict_r, atomic_num_dict
+from dptb.nn.hamiltonian import SKHamiltonian
+from dptb.utils.tools import j_loader
+
+class NNSK(torch.nn.Module):
+ name = "nnsk"
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp_sk: Union[OrbitalMapper, None]=None,
+ onsite: Dict={"method": "none"},
+ hopping: Dict={"method": "powerlaw", "rs":6.0, "w": 0.2},
+ overlap: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ transform: bool = True,
+ freeze: bool = False,
+ push: Dict=None,
+ std: float = 0.01,
+ **kwargs,
+ ) -> None:
+
+ super(NNSK, self).__init__()
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ self.dtype = dtype
+ self.device = device
+
+ if basis is not None:
+ self.idp_sk = OrbitalMapper(basis, method="sktb", device=self.device)
+ if idp_sk is not None:
+ assert idp_sk.basis == self.idp_sk.basis, "The basis of idp and basis should be the same."
+ else:
+ assert idp_sk is not None, "Either basis or idp should be provided."
+ self.idp_sk = idp_sk
+
+ self.transform = transform
+ self.basis = self.idp_sk.basis
+ self.idp_sk.get_orbpair_maps()
+ self.idp_sk.get_skonsite_maps()
+ self.onsite_options = onsite
+ self.hopping_options = hopping
+ self.push = push
+ self.model_options = {
+ "nnsk":{
+ "onsite": onsite,
+ "hopping": hopping,
+ "freeze": freeze,
+ "push": push,
+ }
+ }
+
+ self.count_push = 0
+
+ # init_onsite, hopping, overlap formula
+
+ self.onsite_fn = OnsiteFormula(idp=self.idp_sk, functype=self.onsite_options["method"], dtype=dtype, device=device)
+ self.hopping_fn = HoppingFormula(functype=self.hopping_options["method"])
+ if overlap:
+ self.overlap_fn = HoppingFormula(functype=self.hopping_options["method"], overlap=True)
+
+ # init_param
+ #
+ hopping_param = torch.empty([len(self.idp_sk.bond_types), self.idp_sk.reduced_matrix_element, self.hopping_fn.num_paras], dtype=self.dtype, device=self.device)
+ nn.init.normal_(hopping_param, mean=0.0, std=std)
+ self.hopping_param = torch.nn.Parameter(hopping_param)
+ if overlap:
+ overlap_param = torch.empty([len(self.idp_sk.bond_types), self.idp_sk.reduced_matrix_element, self.hopping_fn.num_paras], dtype=self.dtype, device=self.device)
+ nn.init.normal_(overlap_param, mean=0.0, std=std)
+ self.overlap_param = torch.nn.Parameter(overlap_param)
+
+ if self.onsite_options["method"] == "strain":
+ self.onsite_param = None
+ elif self.onsite_options["method"] == "none":
+ self.onsite_param = None
+ elif self.onsite_options["method"] in ["NRL", "uniform"]:
+ onsite_param = torch.empty([len(self.idp_sk.type_names), self.idp_sk.n_onsite_Es, self.onsite_fn.num_paras], dtype=self.dtype, device=self.device)
+ nn.init.normal_(onsite_param, mean=0.0, std=std)
+ self.onsite_param = torch.nn.Parameter(onsite_param)
+ else:
+ raise NotImplementedError(f"The onsite method {self.onsite_options['method']} is not implemented.")
+
+ if self.onsite_options["method"] == "strain":
+ # AB [ss, sp, sd, ps, pp, pd, ds, dp, dd]
+ # AA [...]
+ # but need to map to all pairs and all orbital pairs like AB, AA, BB, BA for [ss, sp, sd, ps, pp, pd, ds, dp, dd]
+ # with this map: BA[sp, sd] = AB[ps, ds]
+ strain_param = torch.empty([len(self.idp_sk.bond_types), self.idp_sk.reduced_matrix_element, self.hopping_fn.num_paras], dtype=self.dtype, device=self.device)
+ nn.init.normal_(strain_param, mean=0.0, std=std)
+ self.strain_param = torch.nn.Parameter(strain_param)
+ # symmetrize the env for same atomic spices
+
+ self.hamiltonian = SKHamiltonian(idp_sk=self.idp_sk, onsite=True, dtype=self.dtype, device=self.device, strain=hasattr(self, "strain_param"))
+ if overlap:
+ self.overlap = SKHamiltonian(idp_sk=self.idp_sk, onsite=False, edge_field=AtomicDataDict.EDGE_OVERLAP_KEY, node_field=AtomicDataDict.NODE_OVERLAP_KEY, dtype=self.dtype, device=self.device)
+ self.idp = self.hamiltonian.idp
+
+ if freeze:
+ for (name, param) in self.named_parameters():
+ param.requires_grad = False
+
+ def push_decay(self, rs_thr: float=0., rc_thr: float=0., w_thr: float=0., period:int=100):
+ """Push the soft cutoff function
+
+ Parameters
+ ----------
+ rs_thr : float
+ the threshold step to push the rs
+ w_thr : float
+ the threshold step to push the w
+ """
+
+
+ if self.count_push // period > 0:
+ if abs(rs_thr) > 0:
+ self.hopping_options["rs"] += rs_thr
+ if abs(w_thr) > 0:
+ self.hopping_options["w"] += w_thr
+ if abs(rc_thr) > 0:
+ self.hopping_options["rc"] += rc_thr
+
+ self.model_options["nnsk"]["hopping"] = self.hopping_options
+
+ self.count_push = 0
+ else:
+ self.count_push += 1
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ # get the env and bond from the data
+ # calculate the sk integrals
+ # calculate the onsite
+ # calculate the hopping
+ # calculate the overlap
+ # return the data with updated edge/node features
+
+ # map the parameters to the edge/node/env features
+
+ # compute integrals from parameters using hopping and onsite clas
+
+ # symmetrize the bond for same atomic spices
+ # reflect_keys = np.array(list(self.idp_sk.pair_maps.keys()), dtype="str").reshape(len(self.idp_sk.full_basis), len(self.idp_sk.full_basis)).transpose(1,0).reshape(-1)
+ # params = 0.5 * self.hopping_param.data[self.idp_sk.transform_reduced_bond(torch.tensor(list(self.idp_sk._valid_set)), torch.tensor(list(self.idp_sk._valid_set)))]
+ # reflect_params = torch.zeros_like(params)
+ # for k, k_r in zip(self.idp_sk.pair_maps.keys(), reflect_keys):
+ # reflect_params[:,self.idp_sk.pair_maps[k],:] += params[:,self.idp_sk.pair_maps[k_r],:]
+ # self.hopping_param.data[self.idp_sk.transform_reduced_bond(torch.tensor(list(self.idp_sk._valid_set)), torch.tensor(list(self.idp_sk._valid_set)))] = \
+ # reflect_params + params
+
+ # if hasattr(self, "overlap"):
+ # params = 0.5 * self.overlap_param.data[self.idp_sk.transform_reduced_bond(torch.tensor(list(self.idp_sk._valid_set)), torch.tensor(list(self.idp_sk._valid_set)))]
+ # reflect_params = torch.zeros_like(params)
+ # for k, k_r in zip(self.idp_sk.pair_maps.keys(), reflect_keys):
+ # reflect_params[:,self.idp_sk.pair_maps[k],:] += params[:,self.idp_sk.pair_maps[k_r],:]
+ # self.overlap_param.data[self.idp_sk.transform_reduced_bond(torch.tensor(list(self.idp_sk._valid_set)), torch.tensor(list(self.idp_sk._valid_set)))] = \
+ # reflect_params + params
+
+ # # in strain case, all env pair need to be symmetrized
+ # if self.onsite_fn.functype == "strain":
+ # params = 0.5 * self.strain_param.data
+ # reflect_params = torch.zeros_like(params)
+ # for k, k_r in zip(self.idp_sk.pair_maps.keys(), reflect_keys):
+ # reflect_params[:,self.idp_sk.pair_maps[k],:] += params[:,self.idp_sk.pair_maps[k_r],:]
+ # self.strain_param.data = reflect_params + params
+
+ if self.push is not None:
+ if abs(self.push.get("rs_thr")) + abs(self.push.get("rc_thr")) + abs(self.push.get("w_thr")) > 0:
+ self.push_decay(**self.push)
+
+ reflective_bonds = np.array([self.idp_sk.bond_to_type["-".join(self.idp_sk.type_to_bond[i].split("-")[::-1])] for i in range(len(self.idp_sk.bond_types))])
+ params = self.hopping_param.data
+ reflect_params = params[reflective_bonds]
+ for k in self.idp_sk.orbpair_maps.keys():
+ iorb, jorb = k.split("-")
+ if iorb == jorb:
+ # This is to keep the symmetry of the hopping parameters for the same orbital pairs
+ # As-Bs = Bs-As; we need to do this because for different orbital pairs, we only have one set of parameters,
+ # eg. we only have As-Bp and Bs-Ap, but not Ap-Bs and Bp-As; and we will use Ap-Bs = Bs-Ap and Bp-As = As-Bp to calculate the hopping integral
+ self.hopping_param.data[:,self.idp_sk.orbpair_maps[k],:] = 0.5 * (params[:,self.idp_sk.orbpair_maps[k],:] + reflect_params[:,self.idp_sk.orbpair_maps[k],:])
+ if hasattr(self, "overlap"):
+ params = self.overlap_param.data
+ reflect_params = params[reflective_bonds]
+ for k in self.idp_sk.orbpair_maps.keys():
+ iorb, jorb = k.split("-")
+ if iorb == jorb:
+ self.overlap_param.data[:,self.idp_sk.orbpair_maps[k],:] = 0.5 * (params[:,self.idp_sk.orbpair_maps[k],:] + reflect_params[:,self.idp_sk.orbpair_maps[k],:])
+
+
+ data = AtomicDataDict.with_edge_vectors(data, with_lengths=True)
+
+ # edge_number = data[AtomicDataDict.ATOMIC_NUMBERS_KEY][data[AtomicDataDict.EDGE_INDEX_KEY]].reshape(2, -1)
+ # edge_index = self.idp_sk.transform_reduced_bond(*edge_number)
+ edge_index = data[AtomicDataDict.EDGE_TYPE_KEY].flatten() # it is bond_type index, transform it to reduced bond index
+ edge_number = self.idp_sk.untransform_bond(edge_index).T
+ edge_index = self.idp_sk.transform_bond(*edge_number)
+
+ # the edge number is the atomic number of the two atoms in the bond.
+ # The bond length list is actually the nucli radius (unit of angstrom) at the atomic number.
+ # now this bond length list is only available for the first 83 elements.
+ assert (edge_number <= 83).all(), "The bond length list is only available for the first 83 elements."
+
+ r0 = 0.5*bond_length_list.type(self.dtype).to(self.device)[edge_number-1].sum(0)
+
+ data[AtomicDataDict.EDGE_FEATURES_KEY] = self.hopping_fn.get_skhij(
+ rij=data[AtomicDataDict.EDGE_LENGTH_KEY],
+ paraArray=self.hopping_param[edge_index], # [N_edge, n_pairs, n_paras],
+ **self.hopping_options,
+ r0=r0
+ ) # [N_edge, n_pairs]
+
+ if hasattr(self, "overlap"):
+ equal_orbpair = torch.zeros(self.idp_sk.reduced_matrix_element, dtype=self.dtype, device=self.device)
+ for orbpair_key, slices in self.idp_sk.orbpair_maps.items():
+ if orbpair_key.split("-")[0] == orbpair_key.split("-")[1]:
+ equal_orbpair[slices] = 1.0
+ # this paraconst is to make sure the overlap between the same orbital pairs of the save atom is 1.0
+ # this is taken from the formula of NRL-TB.
+ # the overlap tag now is only designed to be used in the NRL-TB case. In the future, we may need to change this.
+ paraconst = edge_number[0].eq(edge_number[1]).float().view(-1, 1) * equal_orbpair.unsqueeze(0)
+
+ data[AtomicDataDict.EDGE_OVERLAP_KEY] = self.overlap_fn.get_sksij(
+ rij=data[AtomicDataDict.EDGE_LENGTH_KEY],
+ paraArray=self.overlap_param[edge_index],
+ paraconst=paraconst,
+ **self.hopping_options,
+ r0=r0,
+ )
+
+ atomic_numbers = self.idp_sk.untransform_atom(data[AtomicDataDict.ATOM_TYPE_KEY].flatten())
+ if self.onsite_fn.functype == "NRL":
+ data = AtomicDataDict.with_env_vectors(data, with_lengths=True)
+ data[AtomicDataDict.NODE_FEATURES_KEY] = self.onsite_fn.get_skEs(
+ # atomic_numbers=data[AtomicDataDict.ATOMIC_NUMBERS_KEY],
+ atomic_numbers=atomic_numbers,
+ onsitenv_index=data[AtomicDataDict.ONSITENV_INDEX_KEY],
+ onsitenv_length=data[AtomicDataDict.ONSITENV_LENGTH_KEY],
+ nn_onsite_paras=self.onsite_param,
+ **self.onsite_options,
+ )
+ else:
+ data[AtomicDataDict.NODE_FEATURES_KEY] = self.onsite_fn.get_skEs(
+ atomic_numbers=atomic_numbers,
+ nn_onsite_paras=self.onsite_param
+ )
+
+ # if hasattr(self, "overlap"):
+ # data[AtomicDataDict.NODE_OVERLAP_KEY] = torch.ones_like(data[AtomicDataDict.NODE_OVERLAP_KEY])
+
+ # compute strain
+ if self.onsite_fn.functype == "strain":
+ data = AtomicDataDict.with_onsitenv_vectors(data, with_lengths=True)
+ onsitenv_number = self.idp_sk.untransform_atom(data[AtomicDataDict.ATOM_TYPE_KEY].flatten())[data[AtomicDataDict.ONSITENV_INDEX_KEY]].reshape(2, -1)
+ onsitenv_index = self.idp_sk.transform_bond(*onsitenv_number)
+ # reflect_index = self.idp_sk.transform_bond(*onsitenv_number.flip(0))
+ # onsitenv_index[onsitenv_index<0] = reflect_index[onsitenv_index<0] + len(self.idp_sk.reduced_bond_types)
+ # reflect_params = torch.zeros_like(self.strain_param)
+ # for k, k_r in zip(self.idp_sk.pair_maps.keys(), reflect_keys):
+ # reflect_params[:,self.idp_sk.pair_maps[k],:] += self.strain_param[:,self.idp_sk.pair_maps[k_r],:]
+ # onsitenv_params = torch.cat([self.strain_param,
+ # reflect_params], dim=0)
+
+ r0 = 0.5*bond_length_list.type(self.dtype).to(self.device)[onsitenv_number-1].sum(0)
+ assert (edge_number <= 83).all(), "The bond length list is only available for the first 83 elements."
+ onsitenv_params = self.hopping_fn.get_skhij(
+ rij=data[AtomicDataDict.ONSITENV_LENGTH_KEY],
+ paraArray=self.strain_param[onsitenv_index], # [N_edge, n_pairs, n_paras],
+ r0=r0,
+ **self.onsite_options,
+ ) # [N_edge, n_pairs]
+
+ data[AtomicDataDict.ONSITENV_FEATURES_KEY] = onsitenv_params
+
+ # sk param to hamiltonian and overlap
+ if self.transform:
+ data = self.hamiltonian(data)
+ if hasattr(self, "overlap"):
+ data = self.overlap(data)
+
+ return data
+
+ @classmethod
+ def from_reference(
+ cls,
+ checkpoint: str,
+ basis: Dict[str, Union[str, list]]=None,
+ onsite: Dict=None,
+ hopping: Dict=None,
+ overlap: bool=None,
+ dtype: Union[str, torch.dtype]=None,
+ device: Union[str, torch.device]=None,
+ push: Dict=None,
+ freeze: bool = None,
+ std: float = 0.01,
+ **kwargs,
+ ):
+ # the mapping from the parameters of the ref_model and the current model can be found using
+ # reference model's idp and current idp
+
+ common_options = {
+ "dtype": dtype,
+ "device": device,
+ "basis": basis,
+ "overlap": overlap,
+ }
+
+ nnsk = {
+ "onsite": onsite,
+ "hopping": hopping,
+ "freeze": freeze,
+ "push": push,
+ "std": std
+ }
+
+
+ if checkpoint.split(".")[-1] == "json":
+ for k,v in common_options.items():
+ assert v is not None, f"You need to provide {k} when you are initializing a model from a json file."
+ for k,v in nnsk.items():
+ assert v is not None, f"You need to provide {k} when you are initializing a model from a json file."
+
+ v1_model = j_loader(checkpoint)
+ model = cls._from_model_v1(
+ v1_model=v1_model,
+ **nnsk,
+ **common_options,
+ )
+
+ del v1_model
+
+ else:
+ f = torch.load(checkpoint, map_location=device)
+ for k,v in common_options.items():
+ if v is None:
+ common_options[k] = f["config"]["common_options"][k]
+ for k,v in nnsk.items():
+ if v is None:
+ nnsk[k] = f["config"]["model_options"]["nnsk"][k]
+
+ model = cls(**common_options, **nnsk)
+
+ if f["config"]["common_options"]["basis"] == common_options["basis"] and \
+ f["config"]["model_options"] == model.model_options:
+ model.load_state_dict(f["model_state_dict"])
+ else:
+ #TODO: handle the situation when ref_model config is not the same as the current model
+ # load hopping
+ ref_idp = OrbitalMapper(f["config"]["common_options"]["basis"], method="sktb")
+ idp = OrbitalMapper(common_options["basis"], method="sktb")
+
+ ref_idp.get_orbpair_maps()
+ idp.get_orbpair_maps()
+
+
+ params = f["model_state_dict"]["hopping_param"]
+ for bond in ref_idp.bond_types:
+ if bond in idp.bond_types:
+ iasym, jasym = bond.split("-")
+ for ref_forbpair in ref_idp.orbpair_maps.keys():
+ rfiorb, rfjorb = ref_forbpair.split("-")
+ riorb, rjorb = ref_idp.full_basis_to_basis[iasym][rfiorb], ref_idp.full_basis_to_basis[jasym][rfjorb]
+ fiorb, fjorb = idp.basis_to_full_basis[iasym].get(riorb), idp.basis_to_full_basis[jasym].get(rjorb)
+ if fiorb is not None and fjorb is not None:
+ sli = idp.orbpair_maps.get(f"{fiorb}-{fjorb}")
+ b = bond
+ if sli is None:
+ sli = idp.orbpair_maps.get(f"{fjorb}-{fiorb}")
+ b = f"{jasym}-{iasym}"
+ model.hopping_param.data[idp.bond_to_type[b],sli] = \
+ params[ref_idp.bond_to_type[b],ref_idp.orbpair_maps[ref_forbpair]]
+
+ # load overlap
+ if hasattr(model, "overlap_param") and f["model_state_dict"].get("overlap_param") != None:
+ params = f["model_state_dict"]["overlap_param"]
+ for bond in ref_idp.bond_types:
+ if bond in idp.bond_types:
+ iasym, jasym = bond.split("-")
+ for ref_forbpair in ref_idp.orbpair_maps.keys():
+ rfiorb, rfjorb = ref_forbpair.split("-")
+ riorb, rjorb = ref_idp.full_basis_to_basis[iasym][rfiorb], ref_idp.full_basis_to_basis[jasym][rfjorb]
+ fiorb, fjorb = idp.basis_to_full_basis[iasym].get(riorb), idp.basis_to_full_basis[jasym].get(rjorb)
+ if fiorb is not None and fjorb is not None:
+ sli = idp.orbpair_maps.get(f"{fiorb}-{fjorb}")
+ b = bond
+ if sli is None:
+ sli = idp.orbpair_maps.get(f"{fjorb}-{fiorb}")
+ b = f"{jasym}-{iasym}"
+ model.overlap_param.data[idp.bond_to_type[b],sli] = \
+ params[ref_idp.bond_to_type[b],ref_idp.orbpair_maps[ref_forbpair]]
+
+ # load onsite
+ if model.onsite_param != None and f["model_state_dict"].get("onsite_param") != None:
+ params = f["model_state_dict"]["onsite_param"]
+ ref_idp.get_skonsite_maps()
+ idp.get_skonsite_maps()
+ for asym in ref_idp.type_names:
+ if asym in idp.type_names:
+ for ref_forb in ref_idp.skonsite_maps.keys():
+ rorb = ref_idp.full_basis_to_basis[asym][ref_forb]
+ forb = idp.basis_to_full_basis[asym].get(rorb)
+ if forb is not None:
+ model.onsite_param.data[idp.chemical_symbol_to_type[asym],idp.skonsite_maps[forb]] = \
+ params[ref_idp.chemical_symbol_to_type[asym],ref_idp.skonsite_maps[ref_forb]]
+
+ # load strain
+ if hasattr(model, "strain_param") and f["model_state_dict"].get("strain_param") != None:
+ params = f["model_state_dict"]["strain_param"]
+ for bond in ref_idp.bond_types:
+ if bond in idp.bond_types:
+ iasym, jasym = bond.split("-")
+ for ref_forbpair in ref_idp.orbpair_maps.keys():
+ rfiorb, rfjorb = ref_forbpair.split("-")
+ riorb, rjorb = ref_idp.full_basis_to_basis[iasym][rfiorb], ref_idp.full_basis_to_basis[jasym][rfjorb]
+ fiorb, fjorb = idp.basis_to_full_basis[iasym].get(riorb), idp.basis_to_full_basis[jasym].get(rjorb)
+ if fiorb is not None and fjorb is not None:
+ sli = idp.orbpair_maps.get(f"{fiorb}-{fjorb}")
+ b = bond
+ if sli is None:
+ sli = idp.orbpair_maps.get(f"{fjorb}-{fiorb}")
+ b = f"{jasym}-{iasym}"
+ model.strain_param.data[idp.bond_to_type[b], sli] = \
+ params[ref_idp.bond_to_type[b],ref_idp.orbpair_maps[ref_forbpair]]
+
+ del f
+
+ if freeze:
+ for (name, param) in model.named_parameters():
+ param.requires_grad = False
+ else:
+ param.requires_grad = True # in case initilizing some frozen checkpoint while with current freeze setted as False
+
+ return model
+
+ @classmethod
+ def _from_model_v1(
+ cls,
+ v1_model: dict,
+ basis: Dict[str, Union[str, list]]=None,
+ idp_sk: Union[OrbitalMapper, None]=None,
+ onsite: Dict={"method": "none"},
+ hopping: Dict={"method": "powerlaw", "rs":6.0, "w": 0.2},
+ overlap: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ std: float = 0.01,
+ ):
+ # could support json file and .pth file checkpoint of nnsk
+
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ dtype = dtype
+ device = device
+
+ if basis is not None:
+ assert idp_sk is None
+ idp_sk = OrbitalMapper(basis, method="sktb")
+ else:
+ assert idp_sk is not None
+
+
+ basis = idp_sk.basis
+ idp_sk.get_orbpair_maps()
+ idp_sk.get_skonsite_maps()
+
+ nnsk_model = cls(basis=basis, idp_sk=idp_sk, dtype=dtype, device=device, onsite=onsite, hopping=hopping, overlap=overlap, std=std)
+
+ onsite_param = v1_model["onsite"]
+ hopping_param = v1_model["hopping"]
+
+ assert len(hopping) > 0, "The hopping parameters should be provided."
+
+ # load hopping params
+ for orbpair, skparam in hopping_param.items():
+ skparam = torch.tensor(skparam, dtype=dtype, device=device)
+ skparam[0] *= 13.605662285137 * 2
+ iasym, jasym, iorb, jorb, num = list(orbpair.split("-"))
+ num = int(num)
+ ian, jan = torch.tensor(atomic_num_dict[iasym]), torch.tensor(atomic_num_dict[jasym])
+ fiorb, fjorb = idp_sk.basis_to_full_basis[iasym][iorb], idp_sk.basis_to_full_basis[jasym][jorb]
+
+
+ if idp_sk.full_basis.index(fiorb) <= idp_sk.full_basis.index(fjorb):
+ nline = idp_sk.transform_bond(iatomic_numbers=ian, jatomic_numbers=jan)
+ nidx = idp_sk.orbpair_maps[f"{fiorb}-{fjorb}"].start + num
+ else:
+ nline = idp_sk.transform_bond(iatomic_numbers=jan, jatomic_numbers=ian)
+ nidx = idp_sk.orbpair_maps[f"{fjorb}-{fiorb}"].start + num
+
+ nnsk_model.hopping_param.data[nline, nidx] = skparam
+ if ian != jan and fiorb == fjorb:
+ nline = idp_sk.transform_bond(iatomic_numbers=jan, jatomic_numbers=ian)
+ nnsk_model.hopping_param.data[nline, nidx] = skparam
+
+ # load onsite params, differently with onsite mode
+ if onsite["method"] == "strain":
+ for orbpair, skparam in onsite_param.items():
+ skparam = torch.tensor(skparam, dtype=dtype, device=device)
+ skparam[0] *= 13.605662285137 * 2
+ iasym, jasym, iorb, jorb, num = list(orbpair.split("-"))
+ num = int(num)
+ ian, jan = torch.tensor(atomic_num_dict[iasym]), torch.tensor(atomic_num_dict[jasym])
+
+ fiorb, fjorb = idp_sk.basis_to_full_basis[iasym][iorb], idp_sk.basis_to_full_basis[iasym][jorb]
+
+ nline = idp_sk.transform_bond(iatomic_numbers=ian, jatomic_numbers=jan)
+ if idp_sk.full_basis.index(fiorb) <= idp_sk.full_basis.index(fjorb):
+ nidx = idp_sk.orbpair_maps[f"{fiorb}-{fjorb}"].start + num
+ else:
+ nidx = idp_sk.orbpair_maps[f"{fjorb}-{fiorb}"].start + num
+
+ nnsk_model.strain_param.data[nline, nidx] = skparam
+
+ # if ian == jan:
+ # nidx = idp_sk.pair_maps[f"{fjorb}-{fiorb}"].start + num
+ # nnsk_model.strain_param.data[nline, nidx] = skparam
+
+ elif onsite["method"] == "none":
+ pass
+ else:
+ for orbon, skparam in onsite_param.items():
+ skparam = torch.tensor(skparam, dtype=dtype, device=device)
+ skparam *= 13.605662285137 * 2
+ iasym, iorb, num = list(orbon.split("-"))
+ num = int(num)
+ ian = torch.tensor(atomic_num_dict[iasym])
+ fiorb = idp_sk.basis_to_full_basis[iasym][iorb]
+
+ nline = idp_sk.transform_atom(atomic_numbers=ian)
+ nidx = idp_sk.skonsite_maps[fiorb].start + num
+
+ nnsk_model.onsite_param.data[nline, nidx] = skparam
+
+ return nnsk_model
+
diff --git a/dptb/nn/norm.py b/dptb/nn/norm.py
new file mode 100644
index 00000000..2f22af10
--- /dev/null
+++ b/dptb/nn/norm.py
@@ -0,0 +1,181 @@
+import torch
+from torch import nn
+
+from e3nn import o3
+from e3nn.util.jit import compile_mode
+from torch_scatter import scatter_mean
+
+@compile_mode("unsupported")
+class TypeNorm(nn.Module):
+ """Batch normalization for orthonormal representations
+
+ It normalizes by the norm of the representations.
+ Note that the norm is invariant only for orthonormal representations.
+ Irreducible representations `wigner_D` are orthonormal.
+
+ Parameters
+ ----------
+ irreps : `o3.Irreps`
+ representation
+
+ eps : float
+ avoid division by zero when we normalize by the variance
+
+ momentum : float
+ momentum of the running average
+
+ affine : bool
+ do we have weight and bias parameters
+
+ reduce : {'mean', 'max'}
+ method used to reduce
+
+ """
+
+ def __init__(self, irreps, eps=1e-5, momentum=0.1, affine=True, num_type=1, reduce="mean", normalization="component"):
+ super().__init__()
+
+ self.irreps = o3.Irreps(irreps)
+ self.eps = eps
+ self.momentum = momentum
+ self.affine = affine
+ self.num_type = num_type
+
+ num_scalar = sum(mul for mul, ir in self.irreps if ir.is_scalar())
+ num_features = self.irreps.num_irreps
+
+ self.register_buffer("running_mean", torch.zeros(num_type, num_scalar))
+ self.register_buffer("running_var", torch.ones(num_type, num_features))
+
+ if affine:
+ self.weight = nn.Parameter(torch.ones(num_type, num_features))
+ self.bias = nn.Parameter(torch.zeros(num_type, num_scalar))
+ else:
+ self.register_parameter("weight", None)
+ self.register_parameter("bias", None)
+
+ assert isinstance(reduce, str), "reduce should be passed as a string value"
+ assert reduce in ["mean", "max"], "reduce needs to be 'mean' or 'max'"
+ self.reduce = reduce
+
+ assert normalization in ["norm", "component"], "normalization needs to be 'norm' or 'component'"
+ self.normalization = normalization
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} ({self.irreps}, eps={self.eps}, momentum={self.momentum})"
+
+ def _roll_avg(self, curr, update):
+ mask = (update.norm(dim=-1) > 1e-7)
+ out = curr.clone()
+ out[mask] = (1 - self.momentum) * curr[mask] + self.momentum * update[mask].detach()
+ return out
+
+
+ def forward(self, input, input_type):
+ """evaluate
+
+ Parameters
+ ----------
+ input : `torch.Tensor`
+ tensor of shape ``(batch, ..., irreps.dim)``
+ input_type : `torch.Tensor`
+ tensor of shape ``(batch)``
+
+ Returns
+ -------
+ `torch.Tensor`
+ tensor of shape ``(batch, ..., irreps.dim)``
+ """
+
+ batch, *size, dim = input.shape
+ input = input.reshape(batch, -1, dim) # [batch, sample, stacked features]
+
+ if self.training:
+ new_means = []
+ new_vars = []
+
+ fields = []
+ ix = 0
+ irm = 0
+ irv = 0
+ iw = 0
+ ib = 0
+
+ for mul, ir in self.irreps:
+ d = ir.dim
+ field = input[:, :, ix : ix + mul * d] # [batch, sample, mul * repr]
+ ix += mul * d
+
+ # [batch, sample, mul, repr]
+ field = field.reshape(batch, -1, mul, d)
+
+ if ir.is_scalar(): # scalars
+ if self.training:
+ field_mean = field.mean(1).reshape(batch, mul) # [batch, mul]
+ field_mean = scatter_mean(field_mean, input_type, dim=0, dim_size=self.num_type) # [num_type, mul]
+ new_means.append(self._roll_avg(self.running_mean[:, irm : irm + mul], field_mean))
+ else:
+ field_mean = self.running_mean[:, irm : irm + mul]
+ irm += mul
+
+ # [batch, sample, mul, repr]
+ field = field - field_mean.reshape(-1, 1, mul, 1)[input_type]
+
+ if self.training:
+ if self.normalization == "norm":
+ field_norm = field.pow(2).sum(3) # [batch, sample, mul]
+ elif self.normalization == "component":
+ field_norm = field.pow(2).mean(3) # [batch, sample, mul]
+ else:
+ raise ValueError("Invalid normalization option {}".format(self.normalization))
+
+ if self.reduce == "mean":
+ field_norm = field_norm.mean(1) # [batch, mul]
+ elif self.reduce == "max":
+ field_norm = field_norm.max(1).values # [batch, mul]
+ else:
+ raise ValueError("Invalid reduce option {}".format(self.reduce))
+
+ field_norm = scatter_mean(field_norm, input_type, dim=0, dim_size=self.num_type) # [num_type, mul]
+ new_vars.append(self._roll_avg(self.running_var[:, irv : irv + mul], field_norm))
+ else:
+ field_norm = self.running_var[:, irv : irv + mul]
+ irv += mul
+
+ field_norm = (field_norm + self.eps).pow(-0.5) # [(batch,) mul]
+
+ if self.affine:
+ weight = self.weight[:, iw : iw + mul] # [mul]
+ iw += mul
+
+ field_norm = field_norm * weight # [num_type, mul]
+
+ field = field * field_norm.reshape(-1, 1, mul, 1)[input_type] # [batch, sample, mul, repr]
+
+ if self.affine and ir.is_scalar(): # scalars
+ bias = self.bias[:, ib : ib + mul] # [mul]
+ ib += mul
+ field += bias.reshape(-1, 1, mul, 1)[input_type] # [batch, sample, mul, repr]
+
+ fields.append(field.reshape(batch, -1, mul * d)) # [batch, sample, mul * repr]
+
+ if ix != dim:
+ fmt = "`ix` should have reached input.size(-1) ({}), but it ended at {}"
+ msg = fmt.format(dim, ix)
+ raise AssertionError(msg)
+
+ if self.training:
+ assert irm == self.running_mean.size(-1)
+ assert irv == self.running_var.size(-1)
+ if self.affine:
+ assert iw == self.weight.size(-1)
+ assert ib == self.bias.size(-1)
+
+ if self.training:
+ if len(new_means) > 0:
+ torch.cat(new_means, dim=-1, out=self.running_mean)
+ if len(new_vars) > 0:
+ torch.cat(new_vars, dim=-1, out=self.running_var)
+
+ output = torch.cat(fields, dim=2) # [batch, sample, stacked features]
+ return output.reshape(batch, *size, dim)
\ No newline at end of file
diff --git a/dptb/nn/radial_basis.py b/dptb/nn/radial_basis.py
new file mode 100644
index 00000000..9278ce08
--- /dev/null
+++ b/dptb/nn/radial_basis.py
@@ -0,0 +1,140 @@
+from typing import Optional
+import math
+
+import torch
+
+from torch import nn
+
+from e3nn.math import soft_one_hot_linspace
+from e3nn.util.jit import compile_mode
+
+
+@compile_mode("trace")
+class e3nn_basis(nn.Module):
+ r_max: float
+ r_min: float
+ e3nn_basis_name: str
+ num_basis: int
+
+ def __init__(
+ self,
+ r_max: float,
+ r_min: Optional[float] = None,
+ e3nn_basis_name: str = "gaussian",
+ num_basis: int = 8,
+ ):
+ super().__init__()
+ self.r_max = r_max
+ self.r_min = r_min if r_min is not None else 0.0
+ self.e3nn_basis_name = e3nn_basis_name
+ self.num_basis = num_basis
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return soft_one_hot_linspace(
+ x,
+ start=self.r_min,
+ end=self.r_max,
+ number=self.num_basis,
+ basis=self.e3nn_basis_name,
+ cutoff=True,
+ )
+
+ def _make_tracing_inputs(self, n: int):
+ return [{"forward": (torch.randn(5, 1),)} for _ in range(n)]
+
+
+class BesselBasis(nn.Module):
+ r_max: float
+ prefactor: float
+
+ def __init__(self, r_max, num_basis=8, trainable=True):
+ r"""Radial Bessel Basis, as proposed in DimeNet: https://arxiv.org/abs/2003.03123
+
+
+ Parameters
+ ----------
+ r_max : float
+ Cutoff radius
+
+ num_basis : int
+ Number of Bessel Basis functions
+
+ trainable : bool
+ Train the :math:`n \pi` part or not.
+ """
+ super(BesselBasis, self).__init__()
+
+ self.trainable = trainable
+ self.num_basis = num_basis
+
+ self.r_max = float(r_max)
+ self.prefactor = 2.0 / self.r_max
+
+ bessel_weights = (
+ torch.linspace(start=1.0, end=num_basis, steps=num_basis) * math.pi
+ )
+ if self.trainable:
+ self.bessel_weights = nn.Parameter(bessel_weights)
+ else:
+ self.register_buffer("bessel_weights", bessel_weights)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Evaluate Bessel Basis for input x.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input
+ """
+ numerator = torch.sin(self.bessel_weights * x.unsqueeze(-1) / self.r_max)
+
+ return self.prefactor * (numerator / x.unsqueeze(-1))
+
+
+def gaussian_smearing(distances, offset, widths, centered=False):
+ if not centered:
+ # compute width of Gaussian functions (using an overlap of 1 STDDEV)
+ coeff = -0.5 / torch.pow(widths, 2)
+ # Use advanced indexing to compute the individual components
+ diff = distances[..., None] - offset
+ else:
+ # if Gaussian functions are centered, use offsets to compute widths
+ coeff = -0.5 / torch.pow(offset, 2)
+ # if Gaussian functions are centered, no offset is subtracted
+ diff = distances[..., None]
+ # compute smear distance values
+ gauss = torch.exp(coeff * torch.pow(diff, 2))
+ return gauss
+
+
+class GaussianBasis(nn.Module):
+ def __init__(
+ self, start=0.0, stop=5.0, n_gaussians=50, centered=False, trainable=False
+ ):
+ super(GaussianBasis, self).__init__()
+ # compute offset and width of Gaussian functions
+ offset = torch.linspace(start, stop, n_gaussians)
+ widths = torch.Tensor((offset[1] - offset[0]) * torch.ones_like(offset)) # FloatTensor
+ if trainable:
+ self.width = nn.Parameter(widths)
+ self.offsets = nn.Parameter(offset)
+ else:
+ self.register_buffer("width", widths)
+ self.register_buffer("offsets", offset)
+ self.centered = centered
+
+ def forward(self, distances):
+ """Compute smeared-gaussian distance values.
+
+ Args:
+ distances (torch.Tensor): interatomic distance values of
+ (N_b x N_at x N_nbh) shape.
+
+ Returns:
+ torch.Tensor: layer output of (N_b x N_at x N_nbh x N_g) shape.
+
+ """
+ return gaussian_smearing(
+ distances, self.offsets, self.width, centered=self.centered
+ )
diff --git a/dptb/nn/rescale.py b/dptb/nn/rescale.py
new file mode 100644
index 00000000..168e8351
--- /dev/null
+++ b/dptb/nn/rescale.py
@@ -0,0 +1,526 @@
+import math
+import torch
+from torch_runstats.scatter import scatter
+from dptb.data import _keys
+import logging
+from typing import Optional, List, Union
+import torch.nn.functional
+from e3nn.o3 import Linear
+from e3nn.util.jit import compile_mode
+from dptb.data import AtomicDataDict
+import e3nn.o3 as o3
+
+class PerSpeciesScaleShift(torch.nn.Module):
+ """Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters.
+
+ Args:
+ field: the per-atom field to scale/shift.
+ num_types: the number of types in the model.
+ shifts: the initial shifts to use, one per atom type.
+ scales: the initial scales to use, one per atom type.
+ arguments_in_dataset_units: if ``True``, says that the provided shifts/scales are in dataset
+ units (in which case they will be rescaled appropriately by any global rescaling later
+ applied to the model); if ``False``, the provided shifts/scales will be used without modification.
+
+ For example, if identity shifts/scales of zeros and ones are provided, this should be ``False``.
+ But if scales/shifts computed from the training data are used, and are thus in dataset units,
+ this should be ``True``.
+ out_field: the output field; defaults to ``field``.
+ """
+
+ field: str
+ out_field: str
+ scales_trainble: bool
+ shifts_trainable: bool
+ has_scales: bool
+ has_shifts: bool
+
+ def __init__(
+ self,
+ field: str,
+ num_types: int,
+ shifts: Optional[List[float]],
+ scales: Optional[List[float]],
+ out_field: Optional[str] = None,
+ scales_trainable: bool = False,
+ shifts_trainable: bool = False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.num_types = num_types
+ self.field = field
+ self.out_field = f"shifted_{field}" if out_field is None else out_field
+
+ self.has_shifts = shifts is not None
+ if shifts is not None:
+ shifts = torch.as_tensor(shifts, dtype=torch.get_default_dtype())
+ if len(shifts.reshape([-1])) == 1:
+ shifts = torch.ones(num_types) * shifts
+ assert shifts.shape == (num_types,), f"Invalid shape of shifts {shifts}"
+ self.shifts_trainable = shifts_trainable
+ if shifts_trainable:
+ self.shifts = torch.nn.Parameter(shifts)
+ else:
+ self.register_buffer("shifts", shifts)
+
+ self.has_scales = scales is not None
+ if scales is not None:
+ scales = torch.as_tensor(scales, dtype=torch.get_default_dtype())
+ if len(scales.reshape([-1])) == 1:
+ scales = torch.ones(num_types) * scales
+ assert scales.shape == (num_types,), f"Invalid shape of scales {scales}"
+ self.scales_trainable = scales_trainable
+ if scales_trainable:
+ self.scales = torch.nn.Parameter(scales)
+ else:
+ self.register_buffer("scales", scales)
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+
+ if not (self.has_scales or self.has_shifts):
+ return data
+
+ species_idx = data[AtomicDataDict.ATOM_TYPE_KEY]
+ in_field = data[self.field]
+ assert len(in_field) == len(
+ species_idx
+ ), "in_field doesnt seem to have correct per-atom shape"
+ if self.has_scales:
+ in_field = self.scales[species_idx].view(-1, 1) * in_field
+ if self.has_shifts:
+ in_field = self.shifts[species_idx].view(-1, 1) + in_field
+ data[self.out_field] = in_field
+ return data
+
+ # def update_for_rescale(self, rescale_module):
+ # if hasattr(rescale_module, "related_scale_keys"):
+ # if self.out_field not in rescale_module.related_scale_keys:
+ # return
+ # if self.arguments_in_dataset_units and rescale_module.has_scale:
+ # logging.debug(
+ # f"PerSpeciesScaleShift's arguments were in dataset units; rescaling:\n "
+ # f"Original scales: {TypeMapper.format(self.scales, self.type_names) if self.has_scales else 'n/a'} "
+ # f"shifts: {TypeMapper.format(self.shifts, self.type_names) if self.has_shifts else 'n/a'}"
+ # )
+ # with torch.no_grad():
+ # if self.has_scales:
+ # self.scales.div_(rescale_module.scale_by)
+ # if self.has_shifts:
+ # self.shifts.div_(rescale_module.scale_by)
+ # logging.debug(
+ # f" New scales: {TypeMapper.format(self.scales, self.type_names) if self.has_scales else 'n/a'} "
+ # f"shifts: {TypeMapper.format(self.shifts, self.type_names) if self.has_shifts else 'n/a'}"
+ # )
+
+class PerEdgeSpeciesScaleShift(torch.nn.Module):
+ """Sum edgewise energies.
+
+ Includes optional per-species-pair edgewise energy scales.
+ """
+
+ field: str
+ out_field: str
+ scales_trainble: bool
+ shifts_trainable: bool
+ has_scales: bool
+ has_shifts: bool
+
+ def __init__(
+ self,
+ field: str,
+ num_types: int,
+ shifts: Optional[List[float]],
+ scales: Optional[List[float]],
+ out_field: Optional[str] = None,
+ scales_trainable: bool = False,
+ shifts_trainable: bool = False,
+ **kwargs,
+ ):
+ """Sum edges into nodes."""
+ super(PerEdgeSpeciesScaleShift, self).__init__()
+ self.num_types = num_types
+ self.field = field
+ self.out_field = f"shifted_{field}" if out_field is None else out_field
+
+ self.has_shifts = shifts is not None
+ self.has_scales = scales is not None
+ if scales is not None:
+ scales = torch.as_tensor(scales, dtype=torch.get_default_dtype())
+ if len(scales.reshape([-1])) == 1:
+ scales = torch.ones(num_types, num_types) * scales
+ assert scales.shape == (num_types, num_types,), f"Invalid shape of scales {scales}"
+ self.scales_trainable = scales_trainable
+ if scales_trainable:
+ self.scales = torch.nn.Parameter(scales)
+ else:
+ self.register_buffer("scales", scales)
+
+ if shifts is not None:
+ shifts = torch.as_tensor(shifts, dtype=torch.get_default_dtype())
+ if len(shifts.reshape([-1])) == 1:
+ shifts = torch.ones(num_types, num_types) * shifts
+ assert shifts.shape == (num_types, num_types,), f"Invalid shape of shifts {shifts}"
+ self.shifts_trainable = shifts_trainable
+ if shifts_trainable:
+ self.shifts = torch.nn.Parameter(shifts)
+ else:
+ self.register_buffer("shifts", shifts)
+
+
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+
+ if not (self.has_scales or self.has_shifts):
+ return data
+
+ edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0]
+ edge_neighbor = data[AtomicDataDict.EDGE_INDEX_KEY][1]
+
+ species_idx = data[AtomicDataDict.ATOM_TYPE_KEY].flatten()
+ center_species = species_idx[edge_center]
+ neighbor_species = species_idx[edge_neighbor]
+ in_field = data[self.field]
+
+ assert len(in_field) == len(
+ edge_center
+ ), "in_field doesnt seem to have correct per-edge shape"
+
+
+ if self.has_scales:
+ in_field = self.scales[center_species, neighbor_species].view(-1, 1) * in_field
+ if self.has_shifts:
+ in_field = self.shifts[center_species, neighbor_species].view(-1, 1) + in_field
+
+ data[self.out_field] = in_field
+
+ return data
+
+class E3PerEdgeSpeciesScaleShift(torch.nn.Module):
+ """Sum edgewise energies.
+
+ Includes optional per-species-pair edgewise energy scales.
+ """
+
+ field: str
+ out_field: str
+ scales_trainble: bool
+ shifts_trainable: bool
+ has_scales: bool
+ has_shifts: bool
+
+ def __init__(
+ self,
+ field: str,
+ num_types: int,
+ irreps_in,
+ shifts: Optional[torch.Tensor],
+ scales: Optional[torch.Tensor],
+ out_field: Optional[str] = None,
+ scales_trainable: bool = False,
+ shifts_trainable: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ):
+ """Sum edges into nodes."""
+ super(E3PerEdgeSpeciesScaleShift, self).__init__()
+ self.num_types = num_types
+ self.field = field
+ self.out_field = f"shifted_{field}" if out_field is None else out_field
+ self.irreps_in = irreps_in
+ self.num_scalar = 0
+ self.device = device
+ self.dtype = dtype
+ self.shift_index = []
+ self.scale_index = []
+
+ start = 0
+ start_scalar = 0
+ for mul, ir in irreps_in:
+ if str(ir) == "0e":
+ self.num_scalar += mul
+ self.shift_index += list(range(start_scalar, start_scalar + mul))
+ start_scalar += mul
+ else:
+ self.shift_index += [-1] * mul * ir.dim
+
+ for _ in range(mul):
+ self.scale_index += [start] * ir.dim
+ start += 1
+
+ self.shift_index = torch.as_tensor(self.shift_index, dtype=torch.long, device=device)
+ self.scale_index = torch.as_tensor(self.scale_index, dtype=torch.long, device=device)
+
+ self.has_shifts = shifts is not None
+ self.has_scales = scales is not None
+ if scales is not None:
+ scales = torch.as_tensor(scales, dtype=self.dtype, device=device)
+ if len(scales.reshape(-1)) == 1:
+ scales = scales * torch.ones(num_types*num_types, self.irreps_in.num_irreps, dtype=self.dtype, device=self.device)
+ assert scales.shape == (num_types*num_types, self.irreps_in.num_irreps), f"Invalid shape of scales {scales}"
+ self.scales_trainable = scales_trainable
+ if scales_trainable:
+ self.scales = torch.nn.Parameter(scales)
+ else:
+ self.register_buffer("scales", scales)
+
+ if shifts is not None:
+ shifts = torch.as_tensor(shifts, dtype=self.dtype, device=device)
+ if len(shifts.reshape(-1)) == 1:
+ shifts = shifts * torch.ones(num_types*num_types, self.num_scalar, dtype=self.dtype, device=self.device)
+ assert shifts.shape == (num_types*num_types, self.num_scalar), f"Invalid shape of shifts {shifts}"
+ self.shifts_trainable = shifts_trainable
+ if shifts_trainable:
+ self.shifts = torch.nn.Parameter(shifts)
+ else:
+ self.register_buffer("shifts", shifts)
+
+ def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None):
+ self.has_scales = scales is not None or self.has_scales
+ if scales is not None:
+ assert scales.shape == (self.num_types*self.num_types, self.irreps_in.num_irreps), f"Invalid shape of scales {scales}"
+ if self.scales_trainable:
+ self.scales = torch.nn.Parameter(scales)
+ else:
+ self.register_buffer("scales", scales)
+
+ self.has_shifts = shifts is not None or self.has_shifts
+ if shifts is not None:
+ assert shifts.shape == (self.num_types*self.num_types, self.num_scalar), f"Invalid shape of shifts {shifts}"
+ if self.shifts_trainable:
+ self.shifts = torch.nn.Parameter(shifts)
+ else:
+ self.register_buffer("shifts", shifts)
+
+
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+
+ if not (self.has_scales or self.has_shifts):
+ return data
+
+ edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0]
+
+ species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()
+ in_field = data[self.field]
+
+ assert len(in_field) == len(
+ edge_center
+ ), "in_field doesnt seem to have correct per-edge shape"
+
+ if self.has_scales:
+ in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
+ if self.has_shifts:
+ shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
+ in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]
+
+ data[self.out_field] = in_field
+
+ return data
+
+class E3PerSpeciesScaleShift(torch.nn.Module):
+ """Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters.
+
+ Args:
+ field: the per-atom field to scale/shift.
+ num_types: the number of types in the model.
+ shifts: the initial shifts to use, one per atom type.
+ scales: the initial scales to use, one per atom type.
+ arguments_in_dataset_units: if ``True``, says that the provided shifts/scales are in dataset
+ units (in which case they will be rescaled appropriately by any global rescaling later
+ applied to the model); if ``False``, the provided shifts/scales will be used without modification.
+
+ For example, if identity shifts/scales of zeros and ones are provided, this should be ``False``.
+ But if scales/shifts computed from the training data are used, and are thus in dataset units,
+ this should be ``True``.
+ out_field: the output field; defaults to ``field``.
+ """
+
+ field: str
+ out_field: str
+ scales_trainble: bool
+ shifts_trainable: bool
+ has_scales: bool
+ has_shifts: bool
+
+ def __init__(
+ self,
+ field: str,
+ num_types: int,
+ irreps_in,
+ shifts: Optional[torch.Tensor],
+ scales: Optional[torch.Tensor],
+ out_field: Optional[str] = None,
+ scales_trainable: bool = False,
+ shifts_trainable: bool = False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ):
+ super().__init__()
+ self.num_types = num_types
+ self.field = field
+ self.out_field = f"shifted_{field}" if out_field is None else out_field
+ self.irreps_in = irreps_in
+ self.num_scalar = 0
+ self.shift_index = []
+ self.scale_index = []
+ self.dtype = dtype
+ self.device = device
+
+ start = 0
+ start_scalar = 0
+ for mul, ir in irreps_in:
+ # only the scalar irreps can be shifted
+ # all the irreps can be scaled
+ if str(ir) == "0e":
+ self.num_scalar += mul
+ self.shift_index += list(range(start_scalar, start_scalar + mul))
+ start_scalar += mul
+ else:
+ self.shift_index += [-1] * mul * ir.dim
+ for _ in range(mul):
+ self.scale_index += [start] * ir.dim
+ start += 1
+
+ self.shift_index = torch.as_tensor(self.shift_index, dtype=torch.long, device=device)
+ self.scale_index = torch.as_tensor(self.scale_index, dtype=torch.long, device=device)
+
+ self.has_shifts = shifts is not None
+ if shifts is not None:
+ shifts = torch.as_tensor(shifts, dtype=self.dtype, device=device)
+ if len(shifts.reshape([-1])) == 1:
+ shifts = torch.ones(num_types, self.num_scalar, dtype=dtype, device=device) * shifts
+ assert shifts.shape == (num_types,self.num_scalar), f"Invalid shape of shifts {shifts}"
+ self.shifts_trainable = shifts_trainable
+ if shifts_trainable:
+ self.shifts = torch.nn.Parameter(shifts)
+ else:
+ self.register_buffer("shifts", shifts)
+
+ self.has_scales = scales is not None
+ if scales is not None:
+ scales = torch.as_tensor(scales, dtype=torch.get_default_dtype())
+ if len(scales.reshape([-1])) == 1:
+ scales = torch.ones(num_types, self.irreps_in.num_irreps, dtype=dtype, device=device) * scales
+ assert scales.shape == (num_types,self.irreps_in.num_irreps), f"Invalid shape of scales {scales}"
+ self.scales_trainable = scales_trainable
+ if scales_trainable:
+ self.scales = torch.nn.Parameter(scales)
+ else:
+ self.register_buffer("scales", scales)
+
+ def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None):
+ self.has_scales = scales is not None or self.has_scales
+ if scales is not None:
+ assert scales.shape == (self.num_types, self.irreps_in.num_irreps), f"Invalid shape of scales {scales}"
+ if self.scales_trainable:
+ self.scales = torch.nn.Parameter(scales)
+ else:
+ self.register_buffer("scales", scales)
+
+ self.has_shifts = shifts is not None or self.has_shifts
+ if shifts is not None:
+ assert shifts.shape == (self.num_types, self.num_scalar), f"Invalid shape of shifts {shifts}"
+ if self.shifts_trainable:
+ self.shifts = torch.nn.Parameter(shifts)
+ else:
+ self.register_buffer("shifts", shifts)
+
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+
+ if not (self.has_scales or self.has_shifts):
+ return data
+
+ species_idx = data[AtomicDataDict.ATOM_TYPE_KEY].flatten()
+ in_field = data[self.field]
+ assert len(in_field) == len(
+ species_idx
+ ), "in_field doesnt seem to have correct per-atom shape"
+ if self.has_scales:
+ in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
+ if self.has_shifts:
+ shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
+ in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]
+ data[self.out_field] = in_field
+ return data
+
+
+@compile_mode("script")
+class E3ElementLinear(torch.nn.Module):
+ """Sum edgewise energies.
+ Includes optional per-species-pair edgewise energy scales.
+ """
+
+ weight_numel: int
+
+ def __init__(
+ self,
+ irreps_in: o3.Irreps,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ):
+ super(E3ElementLinear, self).__init__()
+ self.irreps_in = irreps_in
+ self.num_scalar = 0
+ self.device = device
+ self.dtype = dtype
+ self.shift_index = []
+ self.scale_index = []
+
+ count_scales= 0
+ count_shift = 0
+ for mul, ir in irreps_in:
+ if str(ir) == "0e":
+ self.num_scalar += mul
+ self.shift_index += list(range(count_shift, count_shift + mul))
+ count_shift += mul
+ else:
+ self.shift_index += [-1] * mul * ir.dim
+
+ for _ in range(mul):
+ self.scale_index += [count_scales] * ir.dim
+ count_scales += 1
+
+ self.shift_index = torch.as_tensor(self.shift_index, dtype=torch.int64, device=self.device)
+ self.scale_index = torch.as_tensor(self.scale_index, dtype=torch.int64, device=self.device)
+
+ self.weight_numel = irreps_in.num_irreps + self.num_scalar
+ assert count_scales + count_shift == self.weight_numel
+ self.num_scales = count_scales
+ self.num_shifts = count_shift
+
+ def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]=None):
+
+ scales = weights[:, :self.num_scales] if weights is not None else None
+ if weights is not None:
+ if weights.shape[1] > self.num_scales:
+ shifts = weights[:, self.num_scales:]
+ else:
+ shifts = None
+ else:
+ shifts = None
+
+ if scales is not None:
+ assert len(scales) == len(
+ x
+ ), "in_field doesnt seem to have correct shape as scales"
+ x = scales[:,self.scale_index].reshape(x.shape[0], -1) * x
+ else:
+ x = x
+
+ if shifts is not None:
+ assert len(shifts) == len(
+ x
+ ), "in_field doesnt seem to have correct shape as shifts"
+
+ # bias = torch.zeros_like(x)
+ # bias[:, self.shift_index.ge(0)] = shifts[:,self.shift_index[self.shift_index.ge(0)]].reshape(-1, self.num_scalar)
+ # x = x + bias
+ x[:, self.shift_index.ge(0)] = shifts[:,self.shift_index[self.shift_index.ge(0)]].reshape(-1, self.num_scalar) + x[:, self.shift_index.ge(0)]
+ else:
+ x = x
+
+ return x
\ No newline at end of file
diff --git a/dptb/nn/sktb/__init__.py b/dptb/nn/sktb/__init__.py
new file mode 100644
index 00000000..d7a1eeef
--- /dev/null
+++ b/dptb/nn/sktb/__init__.py
@@ -0,0 +1,11 @@
+from .hopping import HoppingFormula
+from .onsite import OnsiteFormula
+from .bondlengthDB import bond_length_list
+
+
+
+__all__ = [
+ 'HoppingFormula',
+ 'OnsiteFormula',
+ 'bond_length_list',
+]
\ No newline at end of file
diff --git a/dptb/nn/sktb/bondlengthDB.py b/dptb/nn/sktb/bondlengthDB.py
new file mode 100644
index 00000000..029b879e
--- /dev/null
+++ b/dptb/nn/sktb/bondlengthDB.py
@@ -0,0 +1,36 @@
+# Onsite energies database, loaded from GAPW lda potentials. stored as
+# A dictionary of dictionaries. The first dictionary is the element name, and the
+# second dictionary is the orbital name. The orbital name is the key, and the value is the onsite energy.
+import torch
+
+#
+# Contains the elements as follows:
+
+# AtomSymbol=[
+# 'H', 'He',
+# 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
+# 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar',
+# 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr',
+# 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', , 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
+# 'Cs', 'Ba', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Rn'
+# ]
+
+element = ["H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
+ "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
+ "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Lu", "Hf", "Ta",
+ "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "Ra", "Th"]
+
+bond_length_list = torch.tensor([1.6,1.4,5.0,3.4,3.0,3.2,3.4,3.1,2.7,3.2,5.9,5.0,5.9,4.4,4.0,3.9,
+ 3.8,4.5,6.5,4.9,5.1,4.2,4.3,4.7,3.6,3.7,3.3,3.7,5.2,4.6,5.9,4.5,4.4,
+ 4.5,4.3,4.8,9.1,6.9,5.7,5.2,5.2,4.3,4.1,4.1,4.0,4.4,6.5,5.4,4.8,4.7,
+ 5.2,5.2,6.2,5.2,10.6,7.7,7.4,5.9,5.2,4.8,4.2,4.2,4.0,3.9,3.8,4.8,6.7,
+ 7.3,5.7,5.8])
+
+bond_length = {
+ 'H': 1.6, 'He': 1.4, 'Li': 5.0, 'Be': 3.4, 'B': 3.0, 'C': 3.2, 'N': 3.4, 'O': 3.1, 'F': 2.7, 'Ne': 3.2, 'Na': 5.9, 'Mg': 5.0,
+ 'Al': 5.9, 'Si': 4.4, 'P': 4.0, 'S': 3.9, 'Cl': 3.8, 'Ar': 4.5, 'K': 6.5, 'Ca': 4.9, 'Sc': 5.1, 'Ti': 4.2, 'V': 4.3, 'Cr': 4.7,
+ 'Mn': 3.6, 'Fe': 3.7, 'Co': 3.3, 'Ni': 3.7, 'Cu': 5.2, 'Zn': 4.6, 'Ga': 5.9, 'Ge': 4.5, 'As': 4.4, 'Se': 4.5, 'Br': 4.3, 'Kr': 4.8,
+ 'Rb': 9.1, 'Sr': 6.9, 'Y': 5.7, 'Zr': 5.2, 'Nb': 5.2, 'Mo': 4.3, 'Tc': 4.1, 'Ru': 4.1, 'Rh': 4.0, 'Pd': 4.4, 'Ag': 6.5, 'Cd': 5.4,
+ 'In': 4.8, 'Sn': 4.7, 'Sb': 5.2, 'Te': 5.2, 'I': 6.2, 'Xe': 5.2, 'Cs': 10.6, 'Ba': 7.7, 'La': 7.4, 'Lu': 5.9, 'Hf': 5.2, 'Ta': 4.8,
+ 'W': 4.2, 'Re': 4.2, 'Os': 4.0, 'Ir': 3.9, 'Pt': 3.8, 'Au': 4.8, 'Hg': 6.7, 'Tl': 7.3, 'Pb': 5.7, 'Bi': 5.8, 'Po': 5.5, 'Ra': 7.0,
+ 'Th': 6.2}
diff --git a/dptb/nn/sktb/hopping.py b/dptb/nn/sktb/hopping.py
new file mode 100644
index 00000000..9d4fe0e4
--- /dev/null
+++ b/dptb/nn/sktb/hopping.py
@@ -0,0 +1,222 @@
+# define the integrals formula.
+import torch
+from abc import ABC, abstractmethod
+from dptb.nn.sktb.bondlengthDB import bond_length_list
+
+class BaseHopping(ABC):
+ def __init__(self) -> None:
+ pass
+
+ @abstractmethod
+ def get_skhij(self, rij, **kwargs):
+ '''This is a wrap function for a self-defined formula of sk integrals. one can easily modify it into whatever form they want.
+
+ Returns
+ -------
+ The function defined by type is called to cal skhij and returned.
+
+ '''
+ pass
+
+class HoppingFormula(BaseHopping):
+ num_paras_dict = {
+ 'varTang96': 4,
+ 'powerlaw': 2,
+ 'NRL0': 4,
+ "NRL1": 4,
+ 'custom': None,
+ }
+
+ def __init__(self, functype='varTang96',overlap=False) -> None:
+ super(HoppingFormula, self).__init__()
+ # one can modify this by add his own formula with the name functype to deifine num of pars.
+ self.overlap = overlap
+ if functype == 'varTang96':
+ assert hasattr(self, 'varTang96')
+
+ elif functype == 'powerlaw':
+ assert hasattr(self, 'powerlaw')
+
+ elif functype in ['NRL0', "NRL1"]:
+ assert hasattr(self, 'NRL_HOP')
+ if overlap:
+ assert hasattr(self, 'NRL_OVERLAP')
+
+ elif functype =='custom':
+ # the functype custom, is for user to define their own formula.
+ # just modify custom to the name of your formula.
+ # and define the funnction self.custom(rij, paraArray, **kwargs)
+ assert hasattr(self, 'custom')
+ else:
+ raise ValueError('No such formula')
+
+ self.functype = functype
+ self.num_paras = self.num_paras_dict[functype]
+
+
+ def get_skhij(self, rij, **kwargs):
+ '''This is a wrap function for a self-defined formula of sk integrals. one can easily modify it into whatever form they want.
+
+ Returns
+ -------
+ The function defined by functype is called to cal skhij and returned.
+
+ '''
+
+ if self.functype == 'varTang96':
+ return self.varTang96(rij=rij, **kwargs)
+ elif self.functype == 'powerlaw':
+ return self.powerlaw(rij=rij, **kwargs)
+ elif self.functype.startswith('NRL'):
+ return self.NRL_HOP(rij=rij, **kwargs)
+ else:
+ raise ValueError('No such formula')
+
+ def get_sksij(self,rij,**kwargs):
+ '''This is a wrap function for a self-defined formula of sk overlap. one can easily modify it into whatever form they want.
+
+ Returns
+ -------
+ The function defined by functype is called to cal sk sij and returned.
+
+ '''
+ assert self.overlap, 'overlap is False, no overlap function is defined.'
+
+ if self.functype == 'NRL0':
+ return self.NRL_OVERLAP0(rij=rij, **kwargs)
+ if self.functype == 'NRL1':
+ return self.NRL_OVERLAP1(rij=rij, **kwargs)
+ elif self.functype == "powerlaw":
+ return self.powerlaw(rij=rij, **kwargs)
+ elif self.functype == "varTang96":
+ return self.varTang96(rij=rij, **kwargs)
+ else:
+ raise ValueError('No such formula')
+
+
+ def varTang96(self, rij: torch.Tensor, paraArray: torch.Tensor, rs:torch.Tensor = torch.tensor(6), w:torch.Tensor = 0.1, **kwargs):
+ """> This function calculates the value of the variational form of Tang et al 1996. without the
+ environment dependent
+
+ $$ h(rij) = \alpha_1 * (rij)^(-\alpha_2) * exp(-\alpha_3 * (rij)^(\alpha_4))$$
+
+ Parameters
+ ----------
+ rij : torch.Tensor([N, 1]/[N])
+ the bond length vector, have the same length of the bond index vector.
+ paraArray : torch.Tensor([N, ..., 4])
+ The parameters for computing varTang96's type hopping integrals, the first dimension should have the
+ same length of the bond index vector, while the last dimenion if 4, which is the number of parameters
+ for each varTang96's type formula.
+ rcut : torch.Tensor, optional
+ cut-off by half at which value, by default torch.tensor(6)
+ w : torch.Tensor, optional
+ the decay factor, the larger the smoother, by default 0.1
+
+ Returns
+ -------
+ _type_
+ _description_
+ """
+
+ rij = rij.reshape(-1)
+ assert paraArray.shape[-1] == 4 and paraArray.shape[0] == len(rij), 'paraArray should be a 2d tensor with the last dimenion if 4, which is the number of parameters for each varTang96\'s type formula.'
+ alpha1, alpha2, alpha3, alpha4 = paraArray[..., 0], paraArray[..., 1].abs(), paraArray[..., 2].abs(), paraArray[..., 3].abs()
+ shape = [-1]+[1] * (len(alpha1.shape)-1)
+ rij = rij.reshape(shape)
+ return alpha1 * rij**(-alpha2) * torch.exp(-alpha3 * rij**alpha4)/(1+torch.exp((rij-rs)/w))
+
+ def powerlaw(self, rij, paraArray, r0:torch.Tensor, rs:torch.Tensor = torch.tensor(6), w:torch.Tensor = 0.1, **kwargs):
+ """> This function calculates the value of the variational form of Tang et al 1996. without the
+ environment dependent
+
+ $$ h(rij) = \alpha_1 * (rij / r_ij0)^(\lambda + \alpha_2)
+ """
+
+ #alpha1, alpha2, alpha3, alpha4 = paraArray[:, 0], paraArray[:, 1]**2, paraArray[:, 2]**2, paraArray[:, 3]**2
+ alpha1, alpha2 = paraArray[..., 0], paraArray[..., 1].abs()
+ #[N, n_op]
+ shape = [-1]+[1] * (len(alpha1.shape)-1)
+ # [-1, 1]
+ rij = rij.reshape(shape)
+ r0 = r0.reshape(shape)
+
+ r0 = r0 / 1.8897259886
+
+ return alpha1 * (r0/rij)**(1 + alpha2) / (1+torch.exp((rij-rs)/w))
+
+ def NRL_HOP(self, rij, paraArray, rc:torch.Tensor = torch.tensor(6), w:torch.Tensor = 0.1, **kwargs):
+ """
+ This function calculates the SK integral value of the form of NRL-TB
+
+ H_{ll'u} = (a + b R + c R^2)exp(-d^2 R) f(R)
+ a,b,c,d are the parameters, R is r_ij
+
+ f(r_ij) = [1+exp((r_ij-rcut+5w)/w)]^-1; (r_ij < rcut)
+ = 0; (r_ij >= rcut)
+
+ """
+ rij = rij.reshape(-1)
+ a, b, c, d = paraArray[..., 0], paraArray[..., 1], paraArray[..., 2], paraArray[..., 3]
+ shape = [-1]+[1] * (len(a.shape)-1)
+ rij = rij.reshape(shape)
+ f_rij = 1/(1+torch.exp((rij-rc+5*w)/w))
+ f_rij[rij>=rc] = 0.0
+
+ return (a + b * rij + c * rij**2) * torch.exp(-d**2 * rij)*f_rij
+
+ def NRL_OVERLAP0(self, rij, paraArray, paraconst, rc:torch.float32 = torch.tensor(6), w:torch.float32 = 0.1, **kwargs):
+ """
+ This function calculates the Overlap value of the form of NRL-TB
+
+ S_{ll'u} = (delta_ll' + a R + b R^2 + c R^3)exp(-d^2 R) f(R)
+ a,b,c,d are the parameters, R is r_ij
+
+ f(r_ij) = [1+exp((r_ij-rcut+5w)/w)]^-1; (r_ij < rcut)
+ = 0; (r_ij >= rcut)
+ # delta
+ """
+
+ assert paraArray.shape[:-1] == paraconst.shape, 'paraArray and paraconst should have the same shape except the last dimenion.'
+ rij = rij.reshape(-1)
+ assert len(rij) == len(paraArray), 'rij and paraArray should have the same length.'
+
+ a, b, c, d = paraArray[..., 0], paraArray[..., 1], paraArray[..., 2], paraArray[..., 3]
+ shape = [-1]+[1] * (len(a.shape)-1)
+ rij = rij.reshape(shape)
+
+ f_rij = 1/(1+torch.exp((rij-rc+5*w)/w))
+ f_rij[rij>=rc] = 0.0
+
+ return (a + b * rij + c * rij**2) * torch.exp(-d**2 * rij)*f_rij
+
+ def NRL_OVERLAP1(self, rij, paraArray, paraconst, rc:torch.float32 = torch.tensor(6), w:torch.float32 = 0.1, **kwargs):
+ """
+ This function calculates the Overlap value of the form of NRL-TB
+
+ S_{ll'u} = (delta_ll' + a R + b R^2 + c R^3)exp(-d^2 R) f(R)
+ a,b,c,d are the parameters, R is r_ij
+
+ f(r_ij) = [1+exp((r_ij-rcut+5w)/w)]^-1; (r_ij < rcut)
+ = 0; (r_ij >= rcut)
+ # delta
+ """
+
+ assert paraArray.shape[:-1] == paraconst.shape, 'paraArray and paraconst should have the same shape except the last dimenion.'
+ rij = rij.reshape(-1)
+ assert len(rij) == len(paraArray), 'rij and paraArray should have the same length.'
+
+ a, b, c, d = paraArray[..., 0], paraArray[..., 1], paraArray[..., 2], paraArray[..., 3]
+ delta_ll = paraconst
+ shape = [-1]+[1] * (len(a.shape)-1)
+ rij = rij.reshape(shape)
+
+ f_rij = 1/(1+torch.exp((rij-rc+5*w)/w))
+ f_rij[rij>=rc] = 0.0
+
+ return (delta_ll + a * rij + b * rij**2 + c * rij**3) * torch.exp(-d**2 * rij)*f_rij
+
+ @classmethod
+ def num_params(cls, funtype):
+ return cls.num_paras_dict[funtype]
+
\ No newline at end of file
diff --git a/dptb/nn/sktb/onsite.py b/dptb/nn/sktb/onsite.py
new file mode 100644
index 00000000..032a4be1
--- /dev/null
+++ b/dptb/nn/sktb/onsite.py
@@ -0,0 +1,154 @@
+# define the integrals formula.
+import torch as th
+import torch
+from typing import List, Union
+from abc import ABC, abstractmethod
+from torch_runstats.scatter import scatter
+from dptb.nn.sktb.onsiteDB import onsite_energy_database
+from dptb.data.transforms import OrbitalMapper
+
+
+class BaseOnsite(ABC):
+ def __init__(self) -> None:
+ pass
+
+ @abstractmethod
+ def get_skEs(self, **kwargs):
+ '''This is a wrap function for a self-defined formula of onsite energies. one can easily modify it into whatever form they want.
+
+ Returns
+ -------
+ The function defined by type is called to cal onsite energies and returned.
+
+ '''
+ pass
+
+
+class OnsiteFormula(BaseOnsite):
+ num_paras_dict = {
+ 'uniform': 1,
+ 'none': 0,
+ 'strain': 0,
+ "NRL": 4,
+ "custom": None,
+ }
+
+ def __init__(
+ self,
+ idp: Union[OrbitalMapper, None]=None,
+ functype='none',
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu")) -> None:
+ super().__init__()
+ if functype in ['none', 'strain']:
+ pass
+ elif functype == 'uniform':
+ assert hasattr(self, 'uniform')
+
+ elif functype == 'NRL':
+ assert hasattr(self, 'NRL')
+
+ elif functype == 'custom':
+ assert hasattr(self, 'custom')
+ else:
+ raise ValueError('No such formula')
+
+ self.functype = functype
+ self.num_paras = self.num_paras_dict[functype]
+
+ self.idp = idp
+ if self.functype in ["uniform", "none", "strain"]:
+ self.E_base = torch.zeros(self.idp.num_types, self.idp.n_onsite_Es, dtype=dtype, device=device)
+ for asym, idx in self.idp.chemical_symbol_to_type.items():
+ self.E_base[idx] = torch.zeros(self.idp.n_onsite_Es)
+ for ot in self.idp.basis[asym]:
+ fot = self.idp.basis_to_full_basis[asym][ot]
+ self.E_base[idx][self.idp.skonsite_maps[fot]] = onsite_energy_database[asym][ot]
+
+ def get_skEs(self, **kwargs):
+ if self.functype == 'uniform':
+ return self.uniform(**kwargs)
+ if self.functype == 'NRL':
+ return self.NRL(**kwargs)
+ if self.functype in ['none', 'strain']:
+ return self.none(**kwargs)
+
+ def none(self, atomic_numbers: torch.Tensor, **kwargs):
+ """The none onsite function, the energy output is directly loaded from the onsite Database.
+ Parameters
+ ----------
+ atomic_numbers : torch.Tensor(N)
+ The atomic number list.
+
+ Returns
+ -------
+ torch.Tensor(N, n_orb)
+ the onsite energies by composing results from nn and ones from database.
+ """
+ atomic_numbers = atomic_numbers.reshape(-1)
+
+ idx = self.idp.transform_atom(atomic_numbers)
+
+ return self.E_base[idx]
+
+ def uniform(self, atomic_numbers: torch.Tensor, nn_onsite_paras: torch.Tensor, **kwargs):
+ """The uniform onsite function, that have the same onsite energies for one specific orbital of a atom type.
+
+ Parameters
+ ----------
+ atomic_numbers : torch.Tensor(N) or torch.Tensor(N,1)
+ The atomic number list.
+ nn_onsite_paras : torch.Tensor(N_atom_type, n_orb)
+ The nn fitted parameters for onsite energies.
+
+ Returns
+ -------
+ torch.Tensor(N, n_orb)
+ the onsite energies by composing results from nn and ones from database.
+ """
+ atomic_numbers = atomic_numbers.reshape(-1)
+ if nn_onsite_paras.shape[-1] == 1:
+ nn_onsite_paras = nn_onsite_paras.squeeze(-1)
+
+ assert len(nn_onsite_paras) == self.E_base.shape[0]
+
+ idx = self.idp.transform_atom(atomic_numbers)
+
+ return nn_onsite_paras[idx] + self.none(atomic_numbers=atomic_numbers)
+
+
+ def NRL(self, atomic_numbers, onsitenv_index, onsitenv_length, nn_onsite_paras, rc:th.float32 = th.tensor(6), w:th.float32 = 0.1, lda=1.0, **kwargs):
+ """ This is NRL-TB formula for onsite energies.
+
+ rho_i = \sum_j exp(- lda**2 r_ij) f(r_ij)
+
+ E_il = a_l + b_l rho_i^(2/3) + c_l rho_i^(4/3) + d_l rho_i^2
+
+ f(r_ij) = [1+exp((r_ij-rcut+5w)/w)]^-1; (r_ij < rcut)
+ = 0; (r_ij >= rcut)
+ Parameters
+ ----------
+ onsitenv_index: torch.LongTensor
+ env index shaped as [2, N]
+ onsitenv_length: torch.Tensor
+ env index shaped as [N] or [N,1]
+ nn_onsite_paras: torch.Tensor
+ [N, n_orb, 4]
+ rcut: float
+ the cutoff radius for onsite energies.
+ w: float
+ the decay for the cutoff smoth function.
+ lda: float
+ the decay for the calculateing rho.
+ """
+ atomic_numbers = atomic_numbers.reshape(-1)
+ idx = self.idp.transform_atom(atomic_numbers)
+ nn_onsite_paras = nn_onsite_paras[idx]
+ r_ijs = onsitenv_length.view(-1) # [N]
+ exp_rij = th.exp(-lda**2 * r_ijs)
+ f_rij = 1/(1+th.exp((r_ijs-rc+5*w)/w))
+ f_rij[r_ijs>=rc] = 0.0
+ rho_i = scatter(src=exp_rij * f_rij, index=onsitenv_index[0], dim=0, reduce="sum").unsqueeze(1) # [N_atom, 1]
+ a_l, b_l, c_l, d_l = nn_onsite_paras[:,:,0], nn_onsite_paras[:,:,1], nn_onsite_paras[:,:,2], nn_onsite_paras[:,:,3]
+ E_il = a_l + b_l * rho_i**(2/3) + c_l * rho_i**(4/3) + d_l * rho_i**2 # [N_atom, n_orb]
+ return E_il # [N_atom, n_orb]
\ No newline at end of file
diff --git a/dptb/nnsktb/onsiteDB_eV.py b/dptb/nn/sktb/onsiteDB.py
similarity index 100%
rename from dptb/nnsktb/onsiteDB_eV.py
rename to dptb/nn/sktb/onsiteDB.py
diff --git a/dptb/nn/type_encode/__init__.py b/dptb/nn/type_encode/__init__.py
new file mode 100644
index 00000000..d3d7c199
--- /dev/null
+++ b/dptb/nn/type_encode/__init__.py
@@ -0,0 +1,5 @@
+from .one_hot import OneHotAtomEncoding
+
+__all__ = [
+ OneHotAtomEncoding,
+]
diff --git a/dptb/nn/type_encode/one_hot.py b/dptb/nn/type_encode/one_hot.py
new file mode 100644
index 00000000..c0bd5eda
--- /dev/null
+++ b/dptb/nn/type_encode/one_hot.py
@@ -0,0 +1,32 @@
+import torch
+import torch.nn.functional
+from dptb.data import AtomicDataDict
+
+class OneHotAtomEncoding(torch.nn.Module):
+ """Copmute a one-hot floating point encoding of atoms' discrete atom types.
+
+ Args:
+ set_features: If ``True`` (default), ``node_features`` will be set in addition to ``node_attrs``.
+ """
+
+ num_types: int
+ set_features: bool
+
+ def __init__(
+ self,
+ num_types: int,
+ set_features: bool = True
+ ):
+ super().__init__()
+ self.num_types = num_types
+ self.set_features = set_features
+
+ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+ type_numbers = data[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1)
+ one_hot = torch.nn.functional.one_hot(
+ type_numbers, num_classes=self.num_types
+ ).to(device=type_numbers.device, dtype=data[AtomicDataDict.POSITIONS_KEY].dtype)
+ data[AtomicDataDict.NODE_ATTRS_KEY] = one_hot
+ if self.set_features:
+ data[AtomicDataDict.NODE_FEATURES_KEY] = one_hot
+ return data
diff --git a/dptb/nn/type_encode/type_embedding.py b/dptb/nn/type_encode/type_embedding.py
new file mode 100644
index 00000000..2c0ad1f8
--- /dev/null
+++ b/dptb/nn/type_encode/type_embedding.py
@@ -0,0 +1,3 @@
+"""write the node and edge embedding for descriptors
+"""
+
diff --git a/dptb/nnops/base_tester.py b/dptb/nnops/base_tester.py
index 5799e491..e09efb77 100644
--- a/dptb/nnops/base_tester.py
+++ b/dptb/nnops/base_tester.py
@@ -5,6 +5,7 @@
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from future.utils import with_metaclass
+from typing import Union
from dptb.utils.constants import dtype_dict
from dptb.plugins.base_plugin import PluginUser
@@ -65,6 +66,56 @@ def calc(self, **data):
def test(self) -> None:
pass
+class BaseTester(with_metaclass(ABCMeta, PluginUser)):
+
+ def __init__(
+ self,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ) -> None:
+ super(BaseTester, self).__init__()
+
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+ self.dtype = dtype
+ self.device = device
+
+ ''' Here is for plugins.
+ plugins:
+ - iteration: events after every batch training iteration.
+ - update: the updates of model paras including networks and optimiser, such as leaning rate, etc. after the batch training.
+ - batch: events before batch training.
+ - epoch: events after epoch batch training
+ The difference b/w iteration and update the parameters, iteration takes in the batch output, loss etc., while update takes in model itself.
+ '''
+ self.iter = 1
+ self.ep = 1
+
+ def run(self):
+ for q in self.plugin_queues.values():
+ '''对四个事件调用序列进行最小堆排序。'''
+ heapq.heapify(q)
+
+ self.epoch()
+ # run plugins of epoch events.
+ self.call_plugins(queue_name='epoch', time=i)
+ self.lr_scheduler.step() # modify the lr at each epoch (should we add it to pluggins so we could record the lr scheduler process?)
+ self.ep += 1
+
+
+ @abstractmethod
+ def iteration(self, **data):
+ '''
+ conduct one step forward computation, used in train, test and validation.
+ '''
+ pass
+
+ @abstractmethod
+ def epoch(self) -> None:
+ """define a training iteration process
+ """
+ pass
+
if __name__ == '__main__':
a = [1, 2, 3]
diff --git a/dptb/nnops/base_trainer.py b/dptb/nnops/base_trainer.py
index 1bf9e363..70b0ee46 100644
--- a/dptb/nnops/base_trainer.py
+++ b/dptb/nnops/base_trainer.py
@@ -3,7 +3,7 @@
import logging
from dptb.utils.tools import get_lr_scheduler, j_must_have, get_optimizer
from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from future.utils import with_metaclass
from dptb.utils.constants import dtype_dict
from dptb.plugins.base_plugin import PluginUser
@@ -11,13 +11,20 @@
log = logging.getLogger(__name__)
+class BaseTrainer(with_metaclass(ABCMeta, PluginUser)):
-class Trainer(with_metaclass(ABCMeta, PluginUser)):
+ def __init__(
+ self,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ ) -> None:
+ super(BaseTrainer, self).__init__()
+
+ if isinstance(dtype, str):
+ dtype = dtype_dict[dtype]
+ self.dtype = dtype
+ self.device = device
- def __init__(self, jdata) -> None:
- super(Trainer, self).__init__()
- self.dtype = jdata["common_options"]["dtype"]
- self.device = jdata["common_options"]["device"]
''' Here is for plugins.
plugins:
- iteration: events after every batch training iteration.
@@ -26,21 +33,13 @@ def __init__(self, jdata) -> None:
- epoch: events after epoch batch training
The difference b/w iteration and update the parameters, iteration takes in the batch output, loss etc., while update takes in model itself.
'''
- self.iteration = 1
- self.epoch = 1
-
-
+ self.iter = 1
+ self.ep = 1
@abstractmethod
- def _init_param(self, jdata):
-
- pass
-
- @abstractmethod
- def build(self):
- '''
- init the model
- '''
+ def restart(self, checkpoint):
+ """init trainer from disk
+ """
pass
def run(self, epochs=1):
@@ -48,24 +47,30 @@ def run(self, epochs=1):
'''对四个事件调用序列进行最小堆排序。'''
heapq.heapify(q)
- for i in range(self.epoch, epochs + 1):
- self.train()
+ for i in range(self.ep, epochs + 1):
+ self.epoch()
# run plugins of epoch events.
self.call_plugins(queue_name='epoch', time=i)
- self.lr_scheduler.step() # modify the lr at each epoch (should we add it to pluggins so we could record the lr scheduler process?)
+
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ self.lr_scheduler.step(self.stats["train_loss"]["epoch_mean"])
+ else:
+ self.lr_scheduler.step() # modify the lr at each epoch (should we add it to pluggins so we could record the lr scheduler process?)
self.update()
- self.epoch += 1
+ self.ep += 1
@abstractmethod
- def calc(self, **data):
+ def iteration(self, **data):
'''
conduct one step forward computation, used in train, test and validation.
'''
pass
@abstractmethod
- def train(self) -> None:
+ def epoch(self) -> None:
+ """define a training iteration process
+ """
pass
@abstractmethod
diff --git a/dptb/nnops/loss.py b/dptb/nnops/loss.py
index 0f6cdbe1..bb894cc5 100644
--- a/dptb/nnops/loss.py
+++ b/dptb/nnops/loss.py
@@ -1,176 +1,403 @@
-import numpy as np
-import torch as th
-#import torchsort
-
-def loss_type1(criterion, eig_pred, eig_label,num_el,num_kp, band_min=0, band_max=None, spin_deg=2):
- norbs = eig_pred.shape[-1]
- nbanddft = eig_label.shape[-1]
- up_nband = min(norbs,nbanddft)
- num_val_band = int(num_el//spin_deg)
- num_k_val_band = int(num_kp * num_el // spin_deg)
- assert num_val_band <= up_nband
- if band_max is None:
- band_max = up_nband
- else:
- assert band_max <= up_nband
-
- band_min = int(band_min)
- band_max = int(band_max)
+import torch.nn as nn
+import torch
+from torch.nn.functional import mse_loss
+from dptb.utils.register import Register
+from dptb.nn.energy import Eigenvalues
+from dptb.nn.hamiltonian import E3Hamiltonian
+from typing import Any, Union, Dict
+from dptb.data import AtomicDataDict, AtomicData
+from dptb.data.transforms import OrbitalMapper
+from dptb.utils.torch_geometric import Batch
- assert band_min < band_max
- # shape of eigs [batch_size, num_kp, num_bands]
- assert len(eig_pred.shape) == 3 and len(eig_label.shape) == 3
+"""this is the register class for descriptors
- # 对齐eig_pred和eig_label
- eig_pred_cut = eig_pred[:,:,band_min:band_max]
- eig_label_cut = eig_label[:,:,band_min:band_max]
-
- batch_size, num_kp, num_bands = eig_pred_cut.shape
-
- eig_pred_cut -= eig_pred_cut.reshape(batch_size,-1).min(dim=1)[0].reshape(batch_size,1,1)
- eig_label_cut -= eig_label_cut.reshape(batch_size,-1).min(dim=1)[0].reshape(batch_size,1,1)
-
- loss = criterion(eig_pred_cut,eig_label_cut)
-
- return loss
-
-def loss_soft_sort(criterion, eig_pred, eig_label,num_el,num_kp, sort_strength=0.5, kmax=None, kmin=0, band_min=0, band_max=None, spin_deg=2, gap_penalty=False, fermi_band=0, eta=1e-2, **kwarg):
- norbs = eig_pred.shape[-1]
- nbanddft = eig_label.shape[-1]
- up_nband = min(norbs,nbanddft)
- num_val_band = int(num_el//spin_deg)
- num_k_val_band = int(num_kp * num_el // spin_deg)
- assert num_val_band <= up_nband
- if band_max is None:
- band_max = up_nband
- else:
- assert band_max <= up_nband
+all descriptors inplemendeted should be a instance of nn.Module class, and provide a forward function that
+takes AtomicData class as input, and give AtomicData class as output.
+
+"""
+class Loss:
+ _register = Register()
+
+ def register(target):
+ return Loss._register.register(target)
- if kmax is None:
- kmax = num_kp
- else:
- assert kmax <= num_kp
+ def __new__(cls, method: str, **kwargs):
+ if method in Loss._register.keys():
+ return Loss._register[method](**kwargs)
+ else:
+ raise Exception(f"Loss method: {method} is not registered!")
+
+@Loss.register("eigvals")
+class EigLoss(nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ overlap: bool=False,
+ diff_on: bool=False,
+ eout_weight: float=0.01,
+ diff_weight: float=0.01,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ):
+ super(EigLoss, self).__init__()
+ self.loss = nn.MSELoss()
+ self.device = device
+ self.diff_on = diff_on
+ self.eout_weight = eout_weight
+ self.diff_weight = diff_weight
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb", device=self.device)
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ if not overlap:
+ self.eigenvalue = Eigenvalues(
+ idp=self.idp,
+ h_edge_field = AtomicDataDict.EDGE_FEATURES_KEY,
+ h_node_field = AtomicDataDict.NODE_FEATURES_KEY,
+ h_out_field = AtomicDataDict.HAMILTONIAN_KEY,
+ out_field = AtomicDataDict.ENERGY_EIGENVALUE_KEY,
+ s_edge_field = None,
+ s_node_field = None,
+ s_out_field = None,
+ dtype=dtype,
+ device=device,
+ )
+ else:
+ self.eigenvalue = Eigenvalues(
+ idp=self.idp,
+ h_edge_field = AtomicDataDict.EDGE_FEATURES_KEY,
+ h_node_field = AtomicDataDict.NODE_FEATURES_KEY,
+ h_out_field = AtomicDataDict.HAMILTONIAN_KEY,
+ out_field = AtomicDataDict.ENERGY_EIGENVALUE_KEY,
+ s_edge_field = AtomicDataDict.EDGE_OVERLAP_KEY,
+ s_node_field = AtomicDataDict.NODE_OVERLAP_KEY,
+ s_out_field = AtomicDataDict.OVERLAP_KEY,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.overlap = overlap
- band_min = int(band_min)
- band_max = int(band_max)
+ def forward(
+ self,
+ data: AtomicDataDict,
+ ref_data: AtomicDataDict,
+ ):
+
+ total_loss = 0.
- assert band_min < band_max
- # shape of eigs [batch_size, num_kp, num_bands]
- assert len(eig_pred.shape) == 3 and len(eig_label.shape) == 3
+ data = Batch.from_dict(data)
+ ref_data = Batch.from_dict(ref_data)
- eig_pred_cut = eig_pred[:,kmin:kmax,band_min:band_max]
- eig_label_cut = eig_label[:,kmin:kmax,band_min:band_max]
- batch_size, num_kp, num_bands = eig_pred_cut.shape
+ datalist = data.to_data_list()
+ ref_datalist = ref_data.to_data_list()
- eig_pred_cut -= eig_pred_cut.reshape(batch_size,-1).min(dim=1)[0].reshape(batch_size,1,1)
- eig_label_cut -= eig_label_cut.reshape(batch_size,-1).min(dim=1)[0].reshape(batch_size,1,1)
+ for data, ref_data in zip(datalist, ref_datalist):
+ data = self.eigenvalue(AtomicData.to_AtomicDataDict(data))
+ ref_data = AtomicData.to_AtomicDataDict(ref_data)
+ if ref_data.get(AtomicDataDict.ENERGY_EIGENVALUE_KEY) is None:
+ ref_data = self.eigenvalue(ref_data)
+
+ emin, emax = ref_data.get(AtomicDataDict.ENERGY_WINDOWS_KEY, (None, None))
+ band_min, band_max = ref_data.get(AtomicDataDict.BAND_WINDOW_KEY, (0, None))
+ eig_pred = data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] # (n_kpt, n_band)
+ eig_label = ref_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] # (n_kpt, n_band_dft/n_band)
- eig_pred_cut = th.reshape(eig_pred_cut, [-1,band_max-band_min])
- eig_label_cut = th.reshape(eig_label_cut, [-1,band_max-band_min])
+ norbs = eig_pred.shape[-1]
+ nbanddft = eig_label.shape[-1]
+ num_kp = eig_label.shape[-2]
- eig_pred_soft = torchsort.soft_sort(eig_pred_cut,regularization_strength=sort_strength)
- eig_label_soft = torchsort.soft_sort(eig_label_cut,regularization_strength=sort_strength)
-
-
- eig_pred_soft = th.reshape(eig_pred_soft, [batch_size, num_kp, num_bands])
- eig_label_soft = th.reshape(eig_label_soft, [batch_size, num_kp, num_bands])
-
- loss = criterion(eig_pred_soft,eig_label_soft)
-
- if gap_penalty:
- gap1 = eig_pred_soft[:,:,fermi_band+1] - eig_pred_soft[:,:,fermi_band]
- gap2 = eig_label_soft[:,:,fermi_band+1] - eig_label_soft[:,:,fermi_band]
- loss_gap = criterion(1.0/(gap1+eta), 1.0/(gap2+eta))
-
- if num_kp > 1:
- # randon choose nk_diff kps' eigenvalues to gen Delta eig.
- # nk_diff = max(nkps//4,1)
- nk_diff = num_kp
- k_diff_i = np.random.choice(num_kp,nk_diff,replace=False)
- k_diff_j = np.random.choice(num_kp,nk_diff,replace=False)
- while (k_diff_i==k_diff_j).all():
- k_diff_j = np.random.choice(num_kp, nk_diff, replace=False)
- eig_diff_lbl = eig_label_soft[:,k_diff_i,:] - eig_label_soft[:,k_diff_j,:]
- eig_ddiff_pred = eig_pred_soft[:,k_diff_i,:] - eig_pred_soft[:,k_diff_j,:]
- loss_diff = criterion(eig_diff_lbl, eig_ddiff_pred)
+ assert num_kp == eig_pred.shape[-2]
+ up_nband = min(norbs, nbanddft)
+
+ if band_max == None:
+ band_max = up_nband
+ else:
+ assert band_max <= up_nband
+
+ band_min = int(band_min)
+ band_max = int(band_max)
+
+ assert band_min < band_max
+ assert len(eig_pred.shape) == 2 and len(eig_label.shape) == 2
+
+ # 对齐eig_pred和eig_label
+ eig_pred_cut = eig_pred[:,band_min:band_max]
+ eig_label_cut = eig_label[:,band_min:band_max]
+
+
+ num_kp, num_bands = eig_pred_cut.shape
+
+ eig_pred_cut = eig_pred_cut - eig_pred_cut.reshape(-1).min()
+ eig_label_cut = eig_label_cut - eig_label_cut.reshape(-1).min()
+
+
+ if emax != None and emin != None:
+ mask_in = eig_label_cut.lt(emax) * eig_label_cut.gt(emin)
+ mask_out = eig_label_cut.gt(emax) + eig_label_cut.lt(emin)
+ elif emax != None:
+ mask_in = eig_label_cut.lt(emax)
+ mask_out = eig_label_cut.gt(emax)
+ elif emin != None:
+ mask_in = eig_label_cut.gt(emin)
+ mask_out = eig_label_cut.lt(emin)
+ else:
+ mask_in = None
+ mask_out = None
+
+ if mask_in is not None:
+ if torch.any(mask_in).item():
+ loss = mse_loss(eig_pred_cut.masked_select(mask_in), eig_label_cut.masked_select(mask_in))
+ if torch.any(mask_out).item():
+ loss = loss + self.eout_weight * mse_loss(eig_pred_cut.masked_select(mask_out), eig_label_cut.masked_select(mask_out))
+ else:
+ loss = mse_loss(eig_pred_cut, eig_label_cut)
+
+ if self.diff_on:
+ assert num_kp >= 1
+ # randon choose nk_diff kps' eigenvalues to gen Delta eig.
+ # nk_diff = max(nkps//4,1)
+ nk_diff = num_kp
+ k_diff_i = torch.randint(0, num_kp, (nk_diff,), device=self.device)
+ k_diff_j = torch.randint(0, num_kp, (nk_diff,), device=self.device)
+ while (k_diff_i==k_diff_j).all():
+ k_diff_j = torch.randint(0, num_kp, (nk_diff,), device=self.device)
+ if mask_in is not None:
+ eig_diff_lbl = eig_label_cut.masked_fill(mask_in, 0.)[:, k_diff_i,:] - eig_label_cut.masked_fill(mask_in, 0.)[:,k_diff_j,:]
+ eig_ddiff_pred = eig_pred_cut.masked_fill(mask_in, 0.)[:,k_diff_i,:] - eig_pred_cut.masked_fill(mask_in, 0.)[:,k_diff_j,:]
+ else:
+ eig_diff_lbl = eig_label_cut[:,k_diff_i,:] - eig_label_cut[:,k_diff_j,:]
+ eig_ddiff_pred = eig_pred_cut[:,k_diff_i,:] - eig_pred_cut[:,k_diff_j,:]
+ loss_diff = mse_loss(eig_diff_lbl, eig_ddiff_pred)
+
+ loss = loss + self.diff_weight * loss_diff
+
+ total_loss += loss
+
+ return total_loss / len(datalist)
+
+@Loss.register("hamil")
+class HamilLoss(nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ overlap: bool=False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ):
+
+ super(HamilLoss, self).__init__()
+ self.loss1 = nn.L1Loss()
+ self.loss2 = nn.MSELoss()
+ self.overlap = overlap
+ self.device = device
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb", device=self.device)
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
+ # mask the data
+
+ # data[AtomicDataDict.NODE_FEATURES_KEY].masked_fill(~self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY]], 0.)
+ # data[AtomicDataDict.EDGE_FEATURES_KEY].masked_fill(~self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY]], 0.)
+
+ node_mean = ref_data[AtomicDataDict.NODE_FEATURES_KEY].mean(dim=-1, keepdim=True)
+ edge_mean = ref_data[AtomicDataDict.EDGE_FEATURES_KEY].mean(dim=-1, keepdim=True)
+ node_weight = 1/((ref_data[AtomicDataDict.NODE_FEATURES_KEY]-node_mean).norm(dim=-1, keepdim=True)+1e-5)
+ edge_weight = 1/((ref_data[AtomicDataDict.EDGE_FEATURES_KEY]-edge_mean).norm(dim=-1, keepdim=True)+1e-5)
- loss = (1*loss + 1*loss_diff)/2
-
- if gap_penalty:
- loss = loss + 0.1*loss_gap
-
- return loss
-
-
-
-def loss_spectral(criterion, eig_pred, eig_label, emin, emax, num_omega=None, sigma=0.1, **kwargs):
- ''' use eigenvalues to calculate electronic spectral functions and the use the prediced and label spectral
- function to calcualted loss .
- '''
- # calculate spectral fucntion A(k,w):
- assert len(eig_pred.shape) == 3 and len(eig_label.shape) == 3
- if num_omega is None:
- num_omega = int((emax - emin)/sigma)
- omega = th.linspace(emin,emax,num_omega)
- min1 = th.min(eig_label)
- min2 = th.min(eig_pred)
- min1.detach()
- eig_label = eig_label-min1.detach()
- eig_pred = eig_pred - min2.detach()
- spectral_lbl = cal_spectral_func(eigenvalues= eig_label, omega=omega, sigma=sigma)
- spectral_pred = cal_spectral_func(eigenvalues= eig_pred, omega=omega, sigma=sigma)
- loss = criterion(spectral_lbl, spectral_pred)
+ pre = (node_weight*(data[AtomicDataDict.NODE_FEATURES_KEY]-node_mean))[self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+ tgt = (node_weight*(ref_data[AtomicDataDict.NODE_FEATURES_KEY]-node_mean))[self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+ onsite_loss = self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt))
+
+ pre = (edge_weight*(data[AtomicDataDict.EDGE_FEATURES_KEY]-edge_mean))[self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
+ tgt = (edge_weight*(ref_data[AtomicDataDict.EDGE_FEATURES_KEY]-edge_mean))[self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
+ hopping_loss = self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt))
+
+ if self.overlap:
+ over_mean = ref_data[AtomicDataDict.EDGE_OVERLAP_KEY].mean(dim=-1, keepdim=True)
+ over_weight = 1/((ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]-over_mean).norm(dim=-1, keepdim=True)+1e-5)
+ pre = (over_weight*(data[AtomicDataDict.EDGE_OVERLAP_KEY]-over_mean))[self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
+ tgt = (over_weight*(ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]-over_mean))[self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
+ hopping_loss += self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt))
+
+ return hopping_loss + onsite_loss
- return loss
-def gauss(x,sig,mu=0):
- ## gaussion fucntion
- #return th.exp(-(x-mu)**2/(2*sig**2)) * (1/((2*th.pi)**0.5*sig))
- return th.exp(-(x-mu)**2/(2*sig**2))
+@Loss.register("hamil_abs")
+class HamilLossAbs(nn.Module):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ overlap: bool=False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ):
+ super(HamilLossAbs, self).__init__()
+ self.loss1 = nn.L1Loss()
+ self.loss2 = nn.MSELoss()
+ self.overlap = overlap
+ self.device = device
-def cal_spectral_func(eigenvalues,omega,sigma=0.1):
- nsnap, nkp, nband = eigenvalues.shape
- eigs_rsp = th.reshape(eigenvalues,[nsnap * nkp * nband,1])
- omega = th.reshape(omega,[1,-1])
- nomega = omega.shape[1]
- diffmax = omega - eigs_rsp
- gaussian_weight= gauss(diffmax,sigma)
- gaussian_weight_fmt = th.reshape(gaussian_weight,[nsnap, nkp, nband, nomega])
- # eigenvalues_fmt = np.reshape(eigenvalues,[nsnap, nkp, nband, 1])
- spectral_func = th.sum(gaussian_weight_fmt,dim=2)
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb", device=self.device)
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
- return spectral_func
+ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
+ # mask the data
+ # data[AtomicDataDict.NODE_FEATURES_KEY].masked_fill(~self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY]], 0.)
+ # data[AtomicDataDict.EDGE_FEATURES_KEY].masked_fill(~self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY]], 0.)
+
+ pre = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+ tgt = ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+ onsite_loss = 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))
+ pre = data[AtomicDataDict.EDGE_FEATURES_KEY][self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
+ tgt = ref_data[AtomicDataDict.EDGE_FEATURES_KEY][self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
+ hopping_loss = 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))
+
+ if self.overlap:
+ pre = data[AtomicDataDict.EDGE_OVERLAP_KEY][self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
+ tgt = ref_data[AtomicDataDict.EDGE_OVERLAP_KEY][self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
+ overlap_loss = 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))
-def loss_proj_env(criterion, eig_pred, eig_label, ev_pred, proj_label, band_min=0, band_max=None):
- # eig_pred [nsnap, nkp, n_band_tb], eig_label [nsnap, nkp, n_band_dft]
- # ev_pred [nsnap, nkp, n_band_tb, norb_tb], ev_label [nsnap, nkp, n_band_dft, nprojorb_dft]
- # orbmap_pred [{atomtype-orbtype:index}*nsnap], orbmap_label [{atomtype-orbtype:index}*nsnap]
- # fit_band ["N-0s","B-0s"] like this
-
- norbs = eig_pred.shape[-1]
- nbanddft = eig_label.shape[-1]
- up_nband = min(norbs,nbanddft)
- if band_max is None:
- band_max = up_nband
- else:
- assert band_max <= up_nband
-
- band_min = int(band_min)
- band_max = int(band_max)
-
- nsnap, nkp, n_band_tb = eig_pred.shape
- wei = np.abs(ev_pred)**2
- wei_shp = wei[:,:,band_min:band_max,[0,3,1,2,5,8,6,7]]
- eig_pred_reshap = th.reshape(eig_pred[:,:,band_min:band_max], [nsnap,nkp, band_max - band_min,1])
- encoding_band_pred = th.sum(eig_pred_reshap * wei_shp,axis=2)
-
- eig_label_reshap = th.reshape(eig_label[:,:,band_min:band_max], [nsnap,nkp,band_max - band_min,1])
- wei_lbl_shp = proj_label[:,:,band_min:band_max]
- encoding_band_label = th.sum(eig_label_reshap * wei_lbl_shp,axis=2)
+ return (1/3) * (hopping_loss + onsite_loss + overlap_loss)
+ else:
+ return 0.5 * (onsite_loss + hopping_loss)
+
+
+class HamilLossAnalysis(object):
+ def __init__(
+ self,
+ basis: Dict[str, Union[str, list]]=None,
+ idp: Union[OrbitalMapper, None]=None,
+ overlap: bool=False,
+ dtype: Union[str, torch.dtype] = torch.float32,
+ decompose: bool = False,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ **kwargs,
+ ):
+
+ super(HamilLossAnalysis, self).__init__()
+ self.overlap = overlap
+ self.device = device
+ self.decompose = decompose
+
+ if basis is not None:
+ self.idp = OrbitalMapper(basis, method="e3tb", device=self.device)
+ if idp is not None:
+ assert idp == self.idp, "The basis of idp and basis should be the same."
+ else:
+ assert idp is not None, "Either basis or idp should be provided."
+ self.idp = idp
+
+ if decompose:
+ self.e3h = E3Hamiltonian(idp=idp, decompose=decompose, overlap=False, device=device, dtype=dtype)
+ self.e3s = E3Hamiltonian(idp=idp, decompose=decompose, overlap=True, device=device, dtype=dtype)
- loss = criterion(encoding_band_pred, encoding_band_label)
+ def __call__(self, data: AtomicDataDict, ref_data: AtomicDataDict):
+ if self.decompose:
+ data = self.e3h(data)
+ ref_data = self.e3h(ref_data)
+ if self.overlap:
+ data = self.e3s(data)
+ ref_data = self.e3s(ref_data)
+
+
+ with torch.no_grad():
+ out = {}
+ err = data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]
+ amp = ref_data[AtomicDataDict.NODE_FEATURES_KEY].abs()
+ mask = self.idp.mask_to_nrme[data["atom_types"].flatten()]
+ onsite = out.setdefault("onsite", {})
+ for at, tp in self.idp.chemical_symbol_to_type.items():
+ onsite_mask = mask[data["atom_types"].flatten().eq(tp)]
+ onsite_err = err[data["atom_types"].flatten().eq(tp)]
+ onsite_amp = amp[data["atom_types"].flatten().eq(tp)]
+ onsite_err = torch.stack([vec[ma] for vec, ma in zip(onsite_err, onsite_mask)])
+ onsite_amp = torch.stack([vec[ma] for vec, ma in zip(onsite_amp, onsite_mask)])
+ rmserr = (onsite_err**2).mean(dim=0).sqrt()
+ maerr = onsite_err.abs().mean(dim=0)
+ l1amp = onsite_amp.abs().mean(dim=0)
+ l2amp = (onsite_amp**2).mean(dim=0).sqrt()
+ onsite[at] = {
+ "rmse":(rmserr**2).mean().sqrt(),
+ "mae":maerr.mean(),
+ "rmse_per_block_element":rmserr,
+ "mae_per_block_element":maerr,
+ "l1amp":l1amp,
+ "l2amp":l2amp,
+ }
+
+ err = data[AtomicDataDict.EDGE_FEATURES_KEY] - ref_data[AtomicDataDict.EDGE_FEATURES_KEY]
+ amp = ref_data[AtomicDataDict.EDGE_FEATURES_KEY].abs()
+ mask = self.idp.mask_to_erme[data["edge_type"].flatten()]
+ hopping = out.setdefault("hopping", {})
+ for bt, tp in self.idp.bond_to_type.items():
+ hopping_mask = mask[data["edge_type"].flatten().eq(tp)]
+ hopping_err = err[data["edge_type"].flatten().eq(tp)]
+ hopping_amp = amp[data["edge_type"].flatten().eq(tp)]
+ hopping_err = torch.stack([vec[ma] for vec, ma in zip(hopping_err, hopping_mask)])
+ hopping_amp = torch.stack([vec[ma] for vec, ma in zip(hopping_amp, hopping_mask)])
+ rmserr = (hopping_err**2).mean(dim=0).sqrt()
+ maerr = hopping_err.abs().mean(dim=0)
+ l1amp = hopping_amp.abs().mean(dim=0)
+ l2amp = (hopping_amp**2).mean(dim=0).sqrt()
+ hopping[bt] = {
+ "rmse":(rmserr**2).mean().sqrt(),
+ "mae":maerr.mean(),
+ "rmse_per_block_element":rmserr,
+ "mae_per_block_element":maerr,
+ "l1amp":l1amp,
+ "l2amp":l2amp,
+ }
+
+ if self.overlap:
+ err = data[AtomicDataDict.EDGE_OVERLAP_KEY] - ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
+ amp = ref_data[AtomicDataDict.EDGE_OVERLAP_KEY].abs()
+ mask = self.idp.mask_to_erme[data["edge_type"].flatten()]
+ overlap = out.setdefault("overlap", {})
+
+ for bt, tp in self.idp.bond_to_type.items():
+ hopping_mask = mask[data["edge_type"].flatten().eq(tp)]
+ hopping_err = err[data["edge_type"].flatten().eq(tp)]
+ hopping_amp = amp[data["edge_type"].flatten().eq(tp)]
+ hopping_err = torch.stack([vec[ma] for vec, ma in zip(hopping_err, hopping_mask)])
+ hopping_amp = torch.stack([vec[ma] for vec, ma in zip(hopping_amp, hopping_mask)])
+ rmserr = (hopping_err**2).mean(dim=0).sqrt()
+ maerr = hopping_err.abs().mean(dim=0)
+ l1amp = hopping_amp.abs().mean(dim=0)
+ l2amp = (hopping_amp**2).mean(dim=0).sqrt()
+
+ overlap[bt] = {
+ "rmse":(rmserr**2).mean().sqrt(),
+ "mae":maerr.mean(),
+ "rmse_per_block_element":rmserr,
+ "mae_per_block_element":maerr,
+ "l1amp":l1amp,
+ "l2amp":l2amp,
+ }
- return loss
+ return out
\ No newline at end of file
diff --git a/dptb/nnops/tester.py b/dptb/nnops/tester.py
new file mode 100644
index 00000000..56da1eb3
--- /dev/null
+++ b/dptb/nnops/tester.py
@@ -0,0 +1,77 @@
+import torch
+import logging
+from dptb.utils.tools import get_lr_scheduler, \
+get_optimizer, j_must_have
+from dptb.nnops.base_tester import BaseTester
+from typing import Union, Optional
+from dptb.data import AtomicDataset, DataLoader, AtomicData
+from dptb.nn import build_model
+from dptb.nnops.loss import Loss
+
+log = logging.getLogger(__name__)
+#TODO: complete the log output for initilizing the trainer
+
+class Tester(BaseTester):
+
+ def __init__(
+ self,
+ test_options: dict,
+ common_options: dict,
+ model: torch.nn.Module,
+ test_datasets: AtomicDataset,
+ ) -> None:
+ super(Tester, self).__init__(dtype=common_options["dtype"], device=common_options["device"])
+
+ # init the object
+ self.model = model.to(self.device)
+ self.common_options = common_options
+ self.test_options = test_options
+
+ self.test_datasets = test_datasets
+
+ self.test_loader = DataLoader(dataset=self.train_datasets, batch_size=test_options["batch_size"], shuffle=False)
+
+ # loss function
+ self.test_lossfunc = Loss(**test_options["loss_options"]["test"], **common_options, idp=self.model.hamiltonian.idp)
+
+ def iteration(self, batch):
+ '''
+ conduct one step forward computation, used in train, test and validation.
+ '''
+ self.model.eval()
+ batch = batch.to(self.device)
+
+ # record the batch_info to help reconstructing sub-graph from the batch
+ batch_info = {
+ "__slices__": batch.__slices__,
+ "__cumsum__": batch.__cumsum__,
+ "__cat_dims__": batch.__cat_dims__,
+ "__num_nodes_list__": batch.__num_nodes_list__,
+ "__data_class__": batch.__data_class__,
+ }
+
+ batch = AtomicData.to_AtomicDataDict(batch)
+
+ batch_for_loss = batch.copy() # make a shallow copy in case the model change the batch data
+ #TODO: the rescale/normalization can be added here
+ batch = self.model(batch)
+
+ #TODO: this could make the loss function unjitable since t he batchinfo in batch and batch_for_loss does not necessarily
+ # match the torch.Tensor requiresment, should be improved further
+
+ batch.update(batch_info)
+ batch_for_loss.update(batch_info)
+
+ loss = self.train_lossfunc(batch, batch_for_loss)
+
+ state = {'field':'iteration', "test_loss": loss.detach()}
+ self.call_plugins(queue_name='iteration', time=self.iter, **state)
+ self.iter += 1
+
+ return loss.detach()
+
+ def epoch(self) -> None:
+
+ for ibatch in self.test_loader:
+ # iter with different structure
+ self.iteration(ibatch)
\ No newline at end of file
diff --git a/dptb/nnops/trainer.py b/dptb/nnops/trainer.py
new file mode 100644
index 00000000..1f009968
--- /dev/null
+++ b/dptb/nnops/trainer.py
@@ -0,0 +1,218 @@
+import torch
+import logging
+from dptb.utils.tools import get_lr_scheduler, \
+get_optimizer, j_must_have
+from dptb.nnops.base_trainer import BaseTrainer
+from typing import Union, Optional
+from dptb.data import AtomicDataset, DataLoader, AtomicData
+from dptb.nn import build_model
+from dptb.nnops.loss import Loss
+
+log = logging.getLogger(__name__)
+#TODO: complete the log output for initilizing the trainer
+
+class Trainer(BaseTrainer):
+
+ object_keys = ["lr_scheduler", "optimizer"]
+
+ def __init__(
+ self,
+ train_options: dict,
+ common_options: dict,
+ model: torch.nn.Module,
+ train_datasets: AtomicDataset,
+ reference_datasets: Union[AtomicDataset, None]=None,
+ validation_datasets: Union[AtomicDataset, None]=None,
+ ) -> None:
+ super(Trainer, self).__init__(dtype=common_options["dtype"], device=common_options["device"])
+
+ # init the object
+ self.model = model.to(self.device)
+ self.optimizer = get_optimizer(model_param=self.model.parameters(), **train_options["optimizer"])
+ self.lr_scheduler = get_lr_scheduler(optimizer=self.optimizer, **train_options["lr_scheduler"]) # add optmizer
+ self.common_options = common_options
+ self.train_options = train_options
+
+ self.train_datasets = train_datasets
+ self.use_reference = False
+ if reference_datasets is not None:
+ self.reference_datesets = reference_datasets
+ self.use_reference = True
+
+ if validation_datasets is not None:
+ self.validation_datasets = validation_datasets
+ self.use_validation = True
+ else:
+ self.use_validation = False
+
+ self.train_loader = DataLoader(dataset=self.train_datasets, batch_size=train_options["batch_size"], shuffle=True)
+
+ if self.use_reference:
+ self.reference_loader = DataLoader(dataset=self.reference_datesets, batch_size=train_options["ref_batch_size"], shuffle=True)
+
+ if self.use_validation:
+ self.validation_loader = DataLoader(dataset=self.validation_datasets, batch_size=train_options["val_batch_size"], shuffle=False)
+
+ # loss function
+ self.train_lossfunc = Loss(**train_options["loss_options"]["train"], **common_options, idp=self.model.hamiltonian.idp)
+ if self.use_validation:
+ self.validation_lossfunc = Loss(**train_options["loss_options"]["validation"], **common_options, idp=self.model.hamiltonian.idp)
+ if self.use_reference:
+ self.reference_lossfunc = Loss(**train_options["loss_options"]["reference"], **common_options, idp=self.model.hamiltonian.idp)
+
+ def iteration(self, batch, ref_batch=None):
+ '''
+ conduct one step forward computation, used in train, test and validation.
+ '''
+ self.model.train()
+ self.optimizer.zero_grad(set_to_none=True)
+ batch = batch.to(self.device)
+
+ # record the batch_info to help reconstructing sub-graph from the batch
+ batch_info = {
+ "__slices__": batch.__slices__,
+ "__cumsum__": batch.__cumsum__,
+ "__cat_dims__": batch.__cat_dims__,
+ "__num_nodes_list__": batch.__num_nodes_list__,
+ "__data_class__": batch.__data_class__,
+ }
+
+ batch = AtomicData.to_AtomicDataDict(batch)
+
+ batch_for_loss = batch.copy() # make a shallow copy in case the model change the batch data
+
+ batch = self.model(batch)
+
+ #TODO: this could make the loss function unjitable since t he batchinfo in batch and batch_for_loss does not necessarily
+ # match the torch.Tensor requiresment, should be improved further
+
+ batch.update(batch_info)
+ batch_for_loss.update(batch_info)
+
+ loss = self.train_lossfunc(batch, batch_for_loss)
+
+ if ref_batch is not None:
+ ref_batch = ref_batch.to(self.device) # AtomicData Type
+ batch_info = {
+ "__slices__": batch.__slices__,
+ "__cumsum__": batch.__cumsum__,
+ "__cat_dims__": batch.__cat_dims__,
+ "__num_nodes_list__": batch.__num_nodes_list__,
+ "__data_class__": batch.__data_class__,
+ }
+
+ ref_batch = AtomicData.to_AtomicDataDict(ref_batch) # AtomicDataDict Type
+ ref_batch_for_loss = ref_batch.copy()
+ ref_batch = self.model(ref_batch)
+
+ ref_batch.update(batch_info)
+ ref_batch_for_loss.update(batch_info)
+
+ loss += self.train_lossfunc(ref_batch, ref_batch_for_loss)
+
+ self.optimizer.zero_grad(set_to_none=True)
+ loss.backward()
+ #TODO: add clip large gradient
+ self.optimizer.step()
+
+ state = {'field':'iteration', "train_loss": loss.detach(), "lr": self.optimizer.state_dict()["param_groups"][0]['lr']}
+ self.call_plugins(queue_name='iteration', time=self.iter, **state)
+ self.iter += 1
+
+ #TODO: add EMA
+
+ return loss.detach()
+
+ @classmethod
+ def restart(
+ cls,
+ checkpoint: str,
+ train_datasets: AtomicDataset,
+ train_options: dict={},
+ common_options: dict={},
+ reference_datasets: Optional[AtomicDataset]=None,
+ validation_datasets: Optional[AtomicDataset]=None,
+ ):
+ """restart the training from a checkpoint, it does not support model options change."""
+
+ ckpt = torch.load(checkpoint)
+
+ run_opt = {
+ "restart": checkpoint,
+ }
+
+ model = build_model(run_opt, ckpt["config"]["model_options"], ckpt["config"]["common_options"])
+ if len(train_options) == 0:
+ train_options = ckpt["config"]["train_options"]
+ if len(common_options) == 0:
+ common_options = ckpt["config"]["common_options"]
+
+ # init trainer and load the trainer's states
+ trainer = cls(
+ model=model,
+ train_datasets=train_datasets,
+ reference_datasets=reference_datasets,
+ validation_datasets=validation_datasets,
+ train_options=train_options,
+ common_options=common_options,
+ )
+
+ trainer.ep = ckpt["epoch"]
+ trainer.iter = ckpt["iteration"]
+ trainer.stats = ckpt["stats"]
+
+ queues_name = list(trainer.plugin_queues.keys())
+ for unit in queues_name:
+ for plugin in trainer.plugin_queues[unit]:
+ plugin = (getattr(trainer, unit) + plugin[0], plugin[1], plugin[2])
+
+ for key in Trainer.object_keys:
+ item = getattr(trainer, key, None)
+ if item is not None:
+ item.load_state_dict(ckpt[key+"_state_dict"])
+
+ return trainer
+#
+
+ def epoch(self) -> None:
+
+ for ibatch in self.train_loader:
+ # iter with different structure
+ if self.use_reference:
+ self.iteration(ibatch, next(self.reference_loader))
+ else:
+ self.iteration(ibatch)
+
+ def update(self, **kwargs):
+ pass
+
+ def validation(self, fast=True):
+ with torch.no_grad():
+ loss = torch.scalar_tensor(0., dtype=self.dtype, device=self.device)
+ self.model.eval()
+
+ for batch in self.validation_loader:
+ batch = batch.to(self.device)
+
+ batch_info = {
+ "__slices__": batch.__slices__,
+ "__cumsum__": batch.__cumsum__,
+ "__cat_dims__": batch.__cat_dims__,
+ "__num_nodes_list__": batch.__num_nodes_list__,
+ "__data_class__": batch.__data_class__,
+ }
+
+ batch = AtomicData.to_AtomicDataDict(batch)
+
+ batch_for_loss = batch.copy()
+ batch = self.model(batch)
+
+ batch.update(batch_info)
+ batch_for_loss.update(batch_info)
+
+ loss += self.validation_lossfunc(batch, batch_for_loss)
+
+ if fast:
+ break
+
+ return loss
diff --git a/dptb/nnops/use_e3baseline.ipynb b/dptb/nnops/use_e3baseline.ipynb
new file mode 100644
index 00000000..7f33eb95
--- /dev/null
+++ b/dptb/nnops/use_e3baseline.ipynb
@@ -0,0 +1,3630 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/miniconda/envs/deeptb/lib/python3.8/site-packages/torch/jit/_check.py:181: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n",
+ " warnings.warn(\"The TorchScript type system doesn't support \"\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "DPTB(\n",
+ " (embedding): E3BaseLineModelLocal(\n",
+ " (sh): SphericalHarmonics()\n",
+ " (onehot): OneHotAtomEncoding()\n",
+ " (init_layer): InitLayer(\n",
+ " (two_body_latent): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " (_env_weighter): Linear(1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e | 7 weights)\n",
+ " (env_embed_mlp): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " (bessel): BesselBasis()\n",
+ " )\n",
+ " (layers): ModuleList(\n",
+ " (0): Layer(\n",
+ " (_env_weighter): Linear(1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e | 7 weights)\n",
+ " (env_linears): Identity()\n",
+ " (lin_pre): Linear(1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e | 7 weights)\n",
+ " (activation): Gate (142x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e)\n",
+ " (tp): SeparateWeightTensorProduct(\n",
+ " (tp): TensorProduct(1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e x 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 142x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 2178 paths | 2178 weights)\n",
+ " (weights1): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " )\n",
+ " (weights2): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " )\n",
+ " )\n",
+ " (lin_post): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (bn): BatchNorm (64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e, eps=1e-05, momentum=0.1)\n",
+ " (linear_res): Linear(1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 142 weights)\n",
+ " (latents): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " (env_embed_mlps): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " )\n",
+ " (1): Layer(\n",
+ " (_env_weighter): Linear(1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e | 7 weights)\n",
+ " (env_linears): Identity()\n",
+ " (lin_pre): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (activation): Gate (142x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e)\n",
+ " (tp): SeparateWeightTensorProduct(\n",
+ " (tp): TensorProduct(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e x 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 142x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 39396 paths | 39396 weights)\n",
+ " (weights1): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 64x142 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 64x32 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 64x16 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 64x16 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 64x8 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 64x4 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 64x2 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 32x32 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 32x142 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 32x32 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 32x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 32x4 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 32x8 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 32x2 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 32x4 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 16x142 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 16x142 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 8x32 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 8x142 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 8x32 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 4x32 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 4x142 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 4x32 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 2x32 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 2x142 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " )\n",
+ " (weights2): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " )\n",
+ " )\n",
+ " (lin_post): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (bn): BatchNorm (64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e, eps=1e-05, momentum=0.1)\n",
+ " (linear_res): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (latents): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " (env_embed_mlps): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " )\n",
+ " (2): Layer(\n",
+ " (_env_weighter): Linear(1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e | 7 weights)\n",
+ " (env_linears): Identity()\n",
+ " (lin_pre): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (activation): Gate (142x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e)\n",
+ " (tp): SeparateWeightTensorProduct(\n",
+ " (tp): TensorProduct(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e x 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 142x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 39396 paths | 39396 weights)\n",
+ " (weights1): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 64x142 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 64x32 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 64x16 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 64x16 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 64x8 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 64x4 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 64x2 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 32x32 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 32x142 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 32x32 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 32x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 32x4 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 32x8 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 32x2 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 32x4 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 16x142 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 16x142 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 8x32 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 8x142 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 8x32 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 4x32 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 4x142 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 4x32 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 2x32 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 2x142 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " )\n",
+ " (weights2): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " )\n",
+ " )\n",
+ " (lin_post): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (bn): BatchNorm (64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e, eps=1e-05, momentum=0.1)\n",
+ " (linear_res): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (latents): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " (env_embed_mlps): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " )\n",
+ " (3): Layer(\n",
+ " (_env_weighter): Linear(1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e | 7 weights)\n",
+ " (env_linears): Identity()\n",
+ " (lin_pre): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (activation): Gate (142x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e)\n",
+ " (tp): SeparateWeightTensorProduct(\n",
+ " (tp): TensorProduct(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e x 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 142x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 39396 paths | 39396 weights)\n",
+ " (weights1): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 64x142 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 64x32 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 64x16 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 64x16 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 64x8 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 64x4 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 64x2 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 32x32 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 32x142 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 32x32 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 32x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 32x4 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 32x8 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 32x2 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 32x4 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 16x142 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 16x142 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 8x32 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 8x142 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 8x32 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 4x32 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 4x142 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 4x32 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 2x32 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 2x142 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " )\n",
+ " (weights2): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " )\n",
+ " )\n",
+ " (lin_post): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (bn): BatchNorm (64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e, eps=1e-05, momentum=0.1)\n",
+ " (linear_res): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (latents): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " (env_embed_mlps): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " )\n",
+ " (4): Layer(\n",
+ " (_env_weighter): Linear(1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e | 7 weights)\n",
+ " (env_linears): Identity()\n",
+ " (lin_pre): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (activation): Gate (142x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e)\n",
+ " (tp): SeparateWeightTensorProduct(\n",
+ " (tp): TensorProduct(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e x 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 142x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 39396 paths | 39396 weights)\n",
+ " (weights1): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 64x142 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 64x32 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 64x16 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 64x16 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 64x8 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 64x4 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 64x2 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 32x32 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 32x142 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 32x32 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 32x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 32x16 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 32x4 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 32x8 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 32x2 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 32x4 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 16x142 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 16x142 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 16x32 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 16x16 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 16x4 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 8x32 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 8x142 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 8x32 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 8x4 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 8x16 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 4x32 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 4x142 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 4x32 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 4x16 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 4x4 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 2x32 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 2x4 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 2x142 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 2x16 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " )\n",
+ " (weights2): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 1x32 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 1x4 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 1x142 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 1x16 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " )\n",
+ " )\n",
+ " (lin_post): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (bn): BatchNorm (64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e, eps=1e-05, momentum=0.1)\n",
+ " (linear_res): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (latents): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " (env_embed_mlps): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " )\n",
+ " (5): Layer(\n",
+ " (_env_weighter): Linear(1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e | 7 weights)\n",
+ " (_node_weighter): E3ElementLinear()\n",
+ " (_edge_weighter): E3ElementLinear()\n",
+ " (env_linears): Identity()\n",
+ " (lin_pre): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e | 5716 weights)\n",
+ " (activation): Gate (72x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e -> 10x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e)\n",
+ " (tp): SeparateWeightTensorProduct(\n",
+ " (tp): TensorProduct(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e x 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 72x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e | 22988 paths | 22988 weights)\n",
+ " (weights1): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 64x72 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 64x10 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 64x13 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 64x8 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 64x6 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 64x2 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 64x1 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 32x10 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 32x72 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 32x7 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 32x13 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 32x10 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 32x6 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 32x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 32x13 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 32x6 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 32x6 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 32x8 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 32x2 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 32x2 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 32x6 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 32x1 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 32x1 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 32x2 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 16x13 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 16x10 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 16x72 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 16x7 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 16x13 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 16x10 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 16x13 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 16x1 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 16x1 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 16x1 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 16x1 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 16x13 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 16x10 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 16x72 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 16x7 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 16x13 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 16x1 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 16x1 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 16x10 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 16x13 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 16x6 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 16x1 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 16x1 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 16x8 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 16x2 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 8x13 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 8x1 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 8x1 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 8x10 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 8x72 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 8x7 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 8x13 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 8x1 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 8x1 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 8x10 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 8x13 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (106): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (107): Parameter containing: [torch.float32 of size 8x1 (GPU 0)]\n",
+ " (108): Parameter containing: [torch.float32 of size 8x1 (GPU 0)]\n",
+ " (109): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (110): Parameter containing: [torch.float32 of size 4x6 (GPU 0)]\n",
+ " (111): Parameter containing: [torch.float32 of size 4x1 (GPU 0)]\n",
+ " (112): Parameter containing: [torch.float32 of size 4x1 (GPU 0)]\n",
+ " (113): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (114): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (115): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (116): Parameter containing: [torch.float32 of size 4x13 (GPU 0)]\n",
+ " (117): Parameter containing: [torch.float32 of size 4x6 (GPU 0)]\n",
+ " (118): Parameter containing: [torch.float32 of size 4x6 (GPU 0)]\n",
+ " (119): Parameter containing: [torch.float32 of size 4x1 (GPU 0)]\n",
+ " (120): Parameter containing: [torch.float32 of size 4x1 (GPU 0)]\n",
+ " (121): Parameter containing: [torch.float32 of size 4x10 (GPU 0)]\n",
+ " (122): Parameter containing: [torch.float32 of size 4x6 (GPU 0)]\n",
+ " (123): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (124): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (125): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (126): Parameter containing: [torch.float32 of size 4x72 (GPU 0)]\n",
+ " (127): Parameter containing: [torch.float32 of size 4x7 (GPU 0)]\n",
+ " (128): Parameter containing: [torch.float32 of size 4x13 (GPU 0)]\n",
+ " (129): Parameter containing: [torch.float32 of size 4x6 (GPU 0)]\n",
+ " (130): Parameter containing: [torch.float32 of size 4x6 (GPU 0)]\n",
+ " (131): Parameter containing: [torch.float32 of size 4x1 (GPU 0)]\n",
+ " (132): Parameter containing: [torch.float32 of size 4x1 (GPU 0)]\n",
+ " (133): Parameter containing: [torch.float32 of size 4x10 (GPU 0)]\n",
+ " (134): Parameter containing: [torch.float32 of size 4x6 (GPU 0)]\n",
+ " (135): Parameter containing: [torch.float32 of size 4x8 (GPU 0)]\n",
+ " (136): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (137): Parameter containing: [torch.float32 of size 4x2 (GPU 0)]\n",
+ " (138): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (139): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (140): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (141): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (142): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (143): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (144): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (145): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (146): Parameter containing: [torch.float32 of size 2x13 (GPU 0)]\n",
+ " (147): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (148): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (149): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (150): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (151): Parameter containing: [torch.float32 of size 2x10 (GPU 0)]\n",
+ " (152): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (153): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (154): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (155): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (156): Parameter containing: [torch.float32 of size 2x72 (GPU 0)]\n",
+ " (157): Parameter containing: [torch.float32 of size 2x7 (GPU 0)]\n",
+ " (158): Parameter containing: [torch.float32 of size 2x13 (GPU 0)]\n",
+ " (159): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (160): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (161): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (162): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " )\n",
+ " (weights2): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 1x72 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 1x72 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 1x72 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 1x72 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 1x72 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (106): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (107): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (108): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (109): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (110): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (111): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (112): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (113): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (114): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (115): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (116): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (117): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (118): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (119): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (120): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (121): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (122): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (123): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (124): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (125): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (126): Parameter containing: [torch.float32 of size 1x72 (GPU 0)]\n",
+ " (127): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (128): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (129): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (130): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (131): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (132): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (133): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (134): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (135): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (136): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (137): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (138): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (139): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (140): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (141): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (142): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (143): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (144): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (145): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (146): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (147): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (148): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (149): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (150): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (151): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (152): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (153): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (154): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (155): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (156): Parameter containing: [torch.float32 of size 1x72 (GPU 0)]\n",
+ " (157): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (158): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (159): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (160): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (161): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (162): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " )\n",
+ " )\n",
+ " (tp_out): SeparateWeightTensorProduct(\n",
+ " (tp): TensorProduct(10x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e+1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e+1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e x 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o+1x6e -> 10x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e | 10971 paths | 10971 weights)\n",
+ " (weights1): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 10x10 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 10x10 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 10x13 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 10x8 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 10x6 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 10x2 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 10x1 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 10x10 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 10x10 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 10x7 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 10x13 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 10x10 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 10x6 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 10x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 10x13 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 10x6 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 10x6 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 10x8 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 10x2 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 10x2 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 10x6 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 10x1 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 10x1 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 10x2 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 7x7 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 7x10 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 7x6 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 7x7 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 7x13 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 7x6 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 7x6 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 7x8 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 7x2 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 7x6 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 7x6 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 7x1 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 7x2 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 7x2 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 7x1 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 7x1 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 6x7 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 6x13 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 6x10 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 6x8 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 6x7 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 6x13 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 6x8 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 13x13 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 13x10 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 13x6 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 13x8 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 13x10 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 13x7 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 13x13 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 13x6 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 13x6 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 13x10 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 13x6 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 13x8 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 13x2 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 13x2 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 13x13 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 13x6 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 13x6 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 13x1 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 13x1 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 13x8 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 13x2 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 13x2 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 13x6 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 13x1 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 13x1 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 8x13 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 8x10 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 8x10 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 8x7 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 8x13 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 8x1 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 8x1 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 8x10 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (106): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (107): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (108): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (109): Parameter containing: [torch.float32 of size 8x13 (GPU 0)]\n",
+ " (110): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (111): Parameter containing: [torch.float32 of size 8x6 (GPU 0)]\n",
+ " (112): Parameter containing: [torch.float32 of size 8x1 (GPU 0)]\n",
+ " (113): Parameter containing: [torch.float32 of size 8x1 (GPU 0)]\n",
+ " (114): Parameter containing: [torch.float32 of size 8x8 (GPU 0)]\n",
+ " (115): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (116): Parameter containing: [torch.float32 of size 8x2 (GPU 0)]\n",
+ " (117): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (118): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (119): Parameter containing: [torch.float32 of size 6x8 (GPU 0)]\n",
+ " (120): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (121): Parameter containing: [torch.float32 of size 6x7 (GPU 0)]\n",
+ " (122): Parameter containing: [torch.float32 of size 6x13 (GPU 0)]\n",
+ " (123): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (124): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (125): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (126): Parameter containing: [torch.float32 of size 6x10 (GPU 0)]\n",
+ " (127): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (128): Parameter containing: [torch.float32 of size 6x8 (GPU 0)]\n",
+ " (129): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (130): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (131): Parameter containing: [torch.float32 of size 6x7 (GPU 0)]\n",
+ " (132): Parameter containing: [torch.float32 of size 6x13 (GPU 0)]\n",
+ " (133): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (134): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (135): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (136): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (137): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (138): Parameter containing: [torch.float32 of size 6x8 (GPU 0)]\n",
+ " (139): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (140): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (141): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (142): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (143): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (144): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (145): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (146): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (147): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (148): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (149): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (150): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (151): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (152): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (153): Parameter containing: [torch.float32 of size 2x7 (GPU 0)]\n",
+ " (154): Parameter containing: [torch.float32 of size 2x13 (GPU 0)]\n",
+ " (155): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (156): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (157): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (158): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (159): Parameter containing: [torch.float32 of size 2x10 (GPU 0)]\n",
+ " (160): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (161): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (162): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (163): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (164): Parameter containing: [torch.float32 of size 2x7 (GPU 0)]\n",
+ " (165): Parameter containing: [torch.float32 of size 2x13 (GPU 0)]\n",
+ " (166): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (167): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (168): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (169): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (170): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (171): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (172): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (173): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (174): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (175): Parameter containing: [torch.float32 of size 6x8 (GPU 0)]\n",
+ " (176): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (177): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (178): Parameter containing: [torch.float32 of size 6x13 (GPU 0)]\n",
+ " (179): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (180): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (181): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (182): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (183): Parameter containing: [torch.float32 of size 6x10 (GPU 0)]\n",
+ " (184): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (185): Parameter containing: [torch.float32 of size 6x8 (GPU 0)]\n",
+ " (186): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (187): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (188): Parameter containing: [torch.float32 of size 6x10 (GPU 0)]\n",
+ " (189): Parameter containing: [torch.float32 of size 6x7 (GPU 0)]\n",
+ " (190): Parameter containing: [torch.float32 of size 6x13 (GPU 0)]\n",
+ " (191): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (192): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (193): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (194): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (195): Parameter containing: [torch.float32 of size 6x10 (GPU 0)]\n",
+ " (196): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (197): Parameter containing: [torch.float32 of size 6x8 (GPU 0)]\n",
+ " (198): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (199): Parameter containing: [torch.float32 of size 6x2 (GPU 0)]\n",
+ " (200): Parameter containing: [torch.float32 of size 6x13 (GPU 0)]\n",
+ " (201): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (202): Parameter containing: [torch.float32 of size 6x6 (GPU 0)]\n",
+ " (203): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (204): Parameter containing: [torch.float32 of size 6x1 (GPU 0)]\n",
+ " (205): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (206): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (207): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (208): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (209): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (210): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (211): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (212): Parameter containing: [torch.float32 of size 2x13 (GPU 0)]\n",
+ " (213): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (214): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (215): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (216): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (217): Parameter containing: [torch.float32 of size 2x10 (GPU 0)]\n",
+ " (218): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (219): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (220): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (221): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (222): Parameter containing: [torch.float32 of size 2x10 (GPU 0)]\n",
+ " (223): Parameter containing: [torch.float32 of size 2x7 (GPU 0)]\n",
+ " (224): Parameter containing: [torch.float32 of size 2x13 (GPU 0)]\n",
+ " (225): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (226): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (227): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (228): Parameter containing: [torch.float32 of size 2x1 (GPU 0)]\n",
+ " (229): Parameter containing: [torch.float32 of size 2x10 (GPU 0)]\n",
+ " (230): Parameter containing: [torch.float32 of size 2x6 (GPU 0)]\n",
+ " (231): Parameter containing: [torch.float32 of size 2x8 (GPU 0)]\n",
+ " (232): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (233): Parameter containing: [torch.float32 of size 2x2 (GPU 0)]\n",
+ " (234): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (235): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (236): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (237): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (238): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (239): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (240): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (241): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (242): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (243): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (244): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (245): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (246): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (247): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (248): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (249): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (250): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (251): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (252): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (253): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (254): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (255): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (256): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (257): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (258): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (259): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (260): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (261): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (262): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (263): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (264): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (265): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (266): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (267): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (268): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (269): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (270): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (271): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (272): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (273): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (274): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (275): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (276): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (277): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (278): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (279): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (280): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (281): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (282): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (283): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (284): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (285): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (286): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (287): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (288): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (289): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (290): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (291): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (292): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (293): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (294): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (295): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (296): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (297): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (298): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (299): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (300): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (301): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (302): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (303): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (304): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (305): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (306): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (307): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (308): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (309): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (310): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (311): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (312): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (313): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (314): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (315): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (316): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (317): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (318): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (319): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (320): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (321): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (322): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (323): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (324): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (325): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (326): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (327): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (328): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (329): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (330): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (331): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (332): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (333): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (334): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (335): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (336): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (337): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (338): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (339): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (340): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (341): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (342): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (343): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (344): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (345): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (346): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (347): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (348): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (349): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (350): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (351): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (352): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (353): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (354): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (355): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (356): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (357): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (358): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (359): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (360): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (361): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (362): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (363): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (364): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (365): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (366): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (367): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (368): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (369): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (370): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (371): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (372): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (373): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (374): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (375): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (376): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (377): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (378): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (379): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (380): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (381): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (382): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (383): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (384): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (385): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (386): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (387): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (388): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (389): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (390): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (391): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (392): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (393): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (394): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (395): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (396): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (397): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (398): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (399): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (400): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (401): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (402): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (403): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (404): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (405): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (406): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (407): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (408): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (409): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (410): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (411): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (412): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (413): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (414): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (415): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (416): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (417): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (418): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (419): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (420): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (421): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (422): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (423): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (424): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (425): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (426): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (427): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (428): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (429): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (430): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (431): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (432): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (433): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (434): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (435): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (436): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (437): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (438): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (439): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (440): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (441): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (442): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (443): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (444): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (445): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (446): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (447): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (448): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (449): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (450): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (451): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (452): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (453): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (454): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (455): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (456): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (457): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (458): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (459): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (460): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (461): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (462): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (463): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (464): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (465): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (466): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (467): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (468): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (469): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (470): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (471): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (472): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (473): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (474): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (475): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (476): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (477): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (478): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (479): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (480): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (481): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (482): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (483): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (484): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (485): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (486): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (487): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (488): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (489): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (490): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (491): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (492): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (493): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (494): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (495): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (496): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (497): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (498): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (499): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (500): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (501): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (502): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (503): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (504): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (505): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (506): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (507): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (508): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (509): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (510): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (511): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (512): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (513): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (514): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (515): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (516): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (517): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (518): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (519): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (520): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (521): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (522): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (523): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (524): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (525): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (526): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (527): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (528): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (529): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (530): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (531): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (532): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (533): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (534): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (535): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (536): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (537): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (538): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (539): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (540): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (541): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (542): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (543): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (544): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (545): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (546): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (547): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (548): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (549): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (550): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (551): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (552): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (553): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (554): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (555): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (556): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (557): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (558): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (559): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (560): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (561): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (562): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (563): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (564): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (565): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (566): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (567): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (568): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (569): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (570): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (571): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (572): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (573): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (574): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (575): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (576): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (577): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (578): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (579): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (580): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (581): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (582): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (583): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (584): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (585): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (586): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (587): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (588): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (589): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (590): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (591): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (592): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (593): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (594): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (595): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (596): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (597): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (598): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (599): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (600): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (601): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (602): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (603): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (604): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (605): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (606): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (607): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (608): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (609): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (610): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (611): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (612): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " )\n",
+ " (weights2): ParameterList(\n",
+ " (0): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (1): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (2): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (3): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (4): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (5): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (6): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (7): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (8): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (9): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (10): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (11): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (12): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (13): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (14): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (15): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (16): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (17): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (18): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (19): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (20): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (21): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (22): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (23): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (24): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (25): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (26): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (27): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (28): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (29): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (30): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (31): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (32): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (33): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (34): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (35): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (36): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (37): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (38): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (39): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (40): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (41): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (42): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (43): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (44): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (45): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (46): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (47): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (48): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (49): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (50): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (51): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (52): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (53): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (54): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (55): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (56): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (57): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (58): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (59): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (60): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (61): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (62): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (63): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (64): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (65): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (66): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (67): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (68): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (69): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (70): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (71): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (72): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (73): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (74): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (75): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (76): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (77): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (78): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (79): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (80): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (81): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (82): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (83): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (84): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (85): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (86): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (87): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (88): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (89): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (90): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (91): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (92): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (93): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (94): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (95): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (96): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (97): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (98): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (99): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (100): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (101): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (102): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (103): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (104): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (105): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (106): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (107): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (108): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (109): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (110): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (111): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (112): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (113): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (114): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (115): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (116): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (117): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (118): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (119): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (120): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (121): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (122): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (123): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (124): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (125): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (126): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (127): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (128): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (129): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (130): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (131): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (132): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (133): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (134): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (135): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (136): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (137): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (138): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (139): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (140): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (141): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (142): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (143): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (144): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (145): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (146): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (147): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (148): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (149): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (150): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (151): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (152): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (153): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (154): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (155): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (156): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (157): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (158): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (159): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (160): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (161): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (162): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (163): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (164): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (165): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (166): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (167): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (168): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (169): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (170): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (171): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (172): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (173): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (174): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (175): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (176): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (177): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (178): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (179): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (180): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (181): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (182): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (183): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (184): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (185): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (186): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (187): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (188): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (189): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (190): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (191): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (192): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (193): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (194): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (195): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (196): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (197): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (198): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (199): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (200): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (201): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (202): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (203): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (204): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (205): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (206): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (207): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (208): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (209): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (210): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (211): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (212): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (213): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (214): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (215): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (216): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (217): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (218): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (219): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (220): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (221): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (222): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (223): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (224): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (225): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (226): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (227): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (228): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (229): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (230): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (231): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (232): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (233): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (234): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (235): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (236): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (237): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (238): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (239): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (240): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (241): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (242): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (243): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (244): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (245): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (246): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (247): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (248): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (249): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (250): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (251): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (252): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (253): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (254): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (255): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (256): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (257): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (258): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (259): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (260): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (261): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (262): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (263): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (264): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (265): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (266): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (267): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (268): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (269): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (270): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (271): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (272): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (273): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (274): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (275): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (276): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (277): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (278): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (279): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (280): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (281): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (282): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (283): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (284): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (285): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (286): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (287): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (288): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (289): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (290): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (291): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (292): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (293): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (294): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (295): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (296): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (297): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (298): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (299): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (300): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (301): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (302): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (303): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (304): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (305): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (306): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (307): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (308): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (309): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (310): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (311): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (312): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (313): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (314): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (315): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (316): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (317): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (318): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (319): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (320): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (321): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (322): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (323): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (324): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (325): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (326): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (327): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (328): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (329): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (330): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (331): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (332): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (333): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (334): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (335): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (336): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (337): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (338): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (339): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (340): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (341): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (342): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (343): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (344): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (345): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (346): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (347): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (348): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (349): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (350): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (351): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (352): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (353): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (354): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (355): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (356): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (357): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (358): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (359): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (360): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (361): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (362): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (363): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (364): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (365): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (366): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (367): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (368): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (369): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (370): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (371): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (372): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (373): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (374): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (375): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (376): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (377): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (378): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (379): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (380): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (381): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (382): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (383): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (384): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (385): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (386): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (387): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (388): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (389): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (390): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (391): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (392): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (393): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (394): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (395): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (396): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (397): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (398): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (399): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (400): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (401): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (402): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (403): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (404): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (405): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (406): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (407): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (408): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (409): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (410): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (411): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (412): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (413): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (414): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (415): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (416): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (417): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (418): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (419): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (420): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (421): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (422): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (423): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (424): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (425): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (426): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (427): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (428): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (429): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (430): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (431): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (432): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (433): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (434): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (435): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (436): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (437): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (438): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (439): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (440): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (441): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (442): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (443): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (444): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (445): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (446): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (447): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (448): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (449): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (450): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (451): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (452): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (453): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (454): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (455): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (456): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (457): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (458): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (459): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (460): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (461): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (462): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (463): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (464): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (465): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (466): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (467): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (468): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (469): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (470): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (471): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (472): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (473): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (474): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (475): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (476): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (477): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (478): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (479): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (480): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (481): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (482): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (483): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (484): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (485): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (486): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (487): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (488): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (489): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (490): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (491): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (492): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (493): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (494): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (495): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (496): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (497): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (498): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (499): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (500): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (501): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (502): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (503): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (504): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (505): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (506): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (507): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (508): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (509): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (510): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (511): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (512): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (513): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (514): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (515): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (516): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (517): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (518): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (519): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (520): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (521): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (522): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (523): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (524): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (525): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (526): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (527): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (528): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (529): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (530): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (531): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (532): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (533): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (534): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (535): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (536): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (537): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (538): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (539): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (540): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (541): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (542): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (543): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (544): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (545): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (546): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (547): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (548): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (549): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (550): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (551): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (552): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (553): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (554): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (555): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (556): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (557): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (558): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (559): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (560): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (561): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (562): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (563): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (564): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (565): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (566): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (567): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (568): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (569): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (570): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (571): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (572): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (573): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (574): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (575): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (576): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (577): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (578): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (579): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (580): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (581): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (582): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (583): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (584): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (585): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (586): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (587): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (588): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (589): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (590): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (591): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (592): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (593): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (594): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (595): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (596): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (597): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (598): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (599): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (600): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (601): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (602): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (603): Parameter containing: [torch.float32 of size 1x8 (GPU 0)]\n",
+ " (604): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (605): Parameter containing: [torch.float32 of size 1x2 (GPU 0)]\n",
+ " (606): Parameter containing: [torch.float32 of size 1x10 (GPU 0)]\n",
+ " (607): Parameter containing: [torch.float32 of size 1x7 (GPU 0)]\n",
+ " (608): Parameter containing: [torch.float32 of size 1x13 (GPU 0)]\n",
+ " (609): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (610): Parameter containing: [torch.float32 of size 1x6 (GPU 0)]\n",
+ " (611): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " (612): Parameter containing: [torch.float32 of size 1x1 (GPU 0)]\n",
+ " )\n",
+ " )\n",
+ " (lin_post): Linear(10x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e -> 10x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e | 600 weights)\n",
+ " (bn): BatchNorm (10x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e, eps=1e-05, momentum=0.1)\n",
+ " (linear_res): Linear(64x0e+32x1o+16x2e+16x3o+8x4e+4x5o+2x6e -> 10x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e | 1354 weights)\n",
+ " (latents): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " (env_embed_mlps): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " (node_embed_mlps): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " (edge_embed_mlps): ScalarMLPFunction(\n",
+ " (_forward): RecursiveScriptModule(original_name=GraphModule)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (out_edge): Linear(10x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e -> 1x0e+1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x2e+1x2e+1x2e+1x2e+1x3o+1x3o+1x0e+1x1e+1x2e+1x0e+1x1e+1x2e+1x0e+1x1e+1x2e+1x1o+1x2o+1x3o+1x1o+1x2o+1x3o+1x1o+1x2o+1x3o+1x1o+1x2o+1x3o+1x2e+1x3e+1x4e+1x2e+1x3e+1x4e+1x0e+1x1e+1x2e+1x3e+1x4e+1x0e+1x1e+1x2e+1x3e+1x4e+1x0e+1x1e+1x2e+1x3e+1x4e+1x1o+1x2o+1x3o+1x4o+1x5o+1x1o+1x2o+1x3o+1x4o+1x5o+1x0e+1x1e+1x2e+1x3e+1x4e+1x5e+1x6e | 600 weights)\n",
+ " (out_node): Linear(10x0e+10x1o+7x1e+6x2o+13x2e+8x3o+6x3e+2x4o+6x4e+2x5o+1x5e+1x6e -> 1x0e+1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x2e+1x2e+1x2e+1x2e+1x3o+1x3o+1x0e+1x1e+1x2e+1x0e+1x1e+1x2e+1x0e+1x1e+1x2e+1x1o+1x2o+1x3o+1x1o+1x2o+1x3o+1x1o+1x2o+1x3o+1x1o+1x2o+1x3o+1x2e+1x3e+1x4e+1x2e+1x3e+1x4e+1x0e+1x1e+1x2e+1x3e+1x4e+1x0e+1x1e+1x2e+1x3e+1x4e+1x0e+1x1e+1x2e+1x3e+1x4e+1x1o+1x2o+1x3o+1x4o+1x5o+1x1o+1x2o+1x3o+1x4o+1x5o+1x0e+1x1e+1x2e+1x3e+1x4e+1x5e+1x6e | 600 weights)\n",
+ " )\n",
+ " (node_prediction_h): E3PerSpeciesScaleShift()\n",
+ " (edge_prediction_h): E3PerEdgeSpeciesScaleShift()\n",
+ " (hamiltonian): E3Hamiltonian()\n",
+ ")"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from dptb.nnops.trainer import Trainer\n",
+ "from dptb.data import ABACUSInMemoryDataset\n",
+ "from dptb.data.transforms import OrbitalMapper\n",
+ "from dptb.nn import build_model\n",
+ "\n",
+ "from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor\n",
+ "from dptb.plugins.train_logger import Logger\n",
+ "from dptb.plugins.plugins import Saver\n",
+ "import heapq\n",
+ "import logging\n",
+ "from dptb.utils.loggers import set_log_handles\n",
+ "\n",
+ "common_options = {\n",
+ " \"basis\": {\n",
+ " \"Ga\": \"2s2p2d1f\",\n",
+ " \"N\": \"2s2p1d\"\n",
+ " },\n",
+ " # \"basis\":{\"Mo\":\"3s2p2d\", \"S\":\"2s2p1d\"},\n",
+ " \"device\": \"cuda:0\",\n",
+ " \"dtype\": \"float32\",\n",
+ " \"overlap\": False,\n",
+ "}\n",
+ "\n",
+ "root = \"/share/semicond/lmp_abacus/abacus_hse_data/GaN/prod-gan/GaN/sys-000/processed_GaN_pbe\"\n",
+ "train_dataset = ABACUSInMemoryDataset(\n",
+ " root=root,\n",
+ " preprocess_dir=\"/share/semicond/lmp_abacus/abacus_hse_data/GaN/prod-gan/GaN/sys-000/processed_GaN_pbe\",\n",
+ " AtomicData_options={\n",
+ " \"r_max\": 8.0,\n",
+ " \"er_max\": None,\n",
+ " \"oer_max\": None,\n",
+ " \"pbc\": True,\n",
+ " },\n",
+ " type_mapper=OrbitalMapper(basis=common_options[\"basis\"]),\n",
+ ")\n",
+ "\n",
+ "train_options = {\n",
+ " \"seed\": 12070,\n",
+ " \"num_epoch\": 4000,\n",
+ " \"batch_size\": 1,\n",
+ " \"optimizer\": {\n",
+ " \"lr\": 0.01,\n",
+ " \"type\": \"Adam\",\n",
+ " },\n",
+ " \"lr_scheduler\": {\n",
+ " \"type\": \"exp\",\n",
+ " \"gamma\": 0.9995\n",
+ " },\n",
+ " \"loss_options\":{\n",
+ " \"train\":{\"method\": \"eigvals\"}\n",
+ " },\n",
+ " \"save_freq\": 10,\n",
+ " \"validation_freq\": 10,\n",
+ " \"display_freq\": 1\n",
+ "}\n",
+ "\n",
+ "run_opt = {\n",
+ " \"init_model\": \"/root/e3/local/refine_2_6lmax_6l/checkpoint/dptb.iter1001.pth\",\n",
+ " \"restart\": None,\n",
+ " \"freeze\": False,\n",
+ " \"train_soc\": False,\n",
+ " \"log_path\": None,\n",
+ " \"log_level\": None\n",
+ " }\n",
+ "\n",
+ "model_option = {\n",
+ " \"embedding\": {\n",
+ " \"method\": \"e3baseline_local\",\n",
+ " \"r_max\": {\"Ga\":8.1, \"N\":7.1},\n",
+ " \"irreps_hidden\": \"32x0e+32x1o+16x2e+16x3o+16x4e+16x5o+8x6e\",\n",
+ " \"lmax\": 4,\n",
+ " \"n_layers\": 4,\n",
+ " \"n_radial_basis\": 18,\n",
+ " \"env_embed_multiplicity\":1,\n",
+ " \"avg_num_neighbors\": 63,\n",
+ " \"latent_kwargs\": {\n",
+ " \"mlp_latent_dimensions\": [128, 128, 256],\n",
+ " \"mlp_nonlinearity\": \"silu\",\n",
+ " \"mlp_initialization\": \"uniform\"\n",
+ " }\n",
+ " },\n",
+ " \"prediction\":{\n",
+ " \"method\": \"e3tb\",\n",
+ " \"scales_trainable\":True,\n",
+ " \"shifts_trainable\":True\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "model = build_model(run_opt, {}, common_options)\n",
+ "model.to(common_options[\"device\"])\n",
+ "\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "993467\n"
+ ]
+ }
+ ],
+ "source": [
+ "np = 0\n",
+ "for p in model.parameters():\n",
+ " np += p.view(-1).shape[0]\n",
+ "\n",
+ "print(np)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dptb.data import AtomicData\n",
+ "from dptb.data.dataloader import DataLoader\n",
+ "import torch\n",
+ "\n",
+ "loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)\n",
+ "\n",
+ "for data in loader:\n",
+ " ref_data = AtomicData.to_AtomicDataDict(data.to(\"cuda:0\"))\n",
+ " break\n",
+ "# ref_data = AtomicData.to_AtomicDataDict(train_dataset[dN].to(\"cuda:0\"))\n",
+ "with torch.no_grad():\n",
+ " data = model(ref_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dptb.nnops.loss import HamilLossAnalysis\n",
+ "\n",
+ "ana = HamilLossAnalysis(idp=model.idp, device=model.device, decompose=True)\n",
+ "\n",
+ "ana_result = ana(data, ref_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "rmse err for bond N-N: 0.021915080025792122 \t mae err for bond N-N: 0.008099161088466644\n",
+ "rmse err for bond N-Ga: 0.04044164717197418 \t mae err for bond N-Ga: 0.017163095995783806\n",
+ "rmse err for bond Ga-N: 0.03134715557098389 \t mae err for bond Ga-N: 0.009912804700434208\n",
+ "rmse err for bond Ga-Ga: 0.076223224401474 \t mae err for bond Ga-Ga: 0.03573114797472954\n",
+ "rmse err for atom N: 0.13477067649364471 \t mae err for bond N: 0.03793104737997055\n",
+ "rmse err for atom Ga: 0.19352763891220093 \t mae err for bond Ga: 0.052813541144132614\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAABkoAAAEpCAYAAADccn5yAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABRBUlEQVR4nO3de1xU1f7/8fcAcvECpgiIoZBhaiooJqKWVpzAzKIMlVQUPfbtQqmUJR4VzRK1Y6FikuWtjh49lnJKy+JQ2kXybmWpqUlWCnhJSUxQmN8f/pycGJQZucm8no/HftSs/Vl7Pns7bJj5zFrLYDQajQIAAAAAAAAAALBDDtWdAAAAAAAAAAAAQHWhUAIAAAAAAAAAAOwWhRIAAAAAAAAAAGC3KJQAAAAAAAAAAAC7RaEEAAAAAAAAAADYLQolAAAAAAAAAADAblEoAQAAAAAAAAAAdotCCQAAAAAAAAAAsFsUSgAAAAAAAAAAgN2iUAIAAABch5YsWSKDwaDs7Gyz9pdfflk33XSTHB0dFRwcLEny9/fXsGHDqjzH61Fubq4efvhhNW7cWAaDQSkpKRbjsrOzZTAY9M9//rNqE7wKg8GgyZMnV3caAAAAwHWFQgkAAABQS3z88cd67rnn1L17dy1evFjTpk2r7pSuO2PGjNFHH32kxMREvf3224qMjKzulCrckSNHNHnyZO3atau6U7GZwWCQwWDQrFmzSu27VETctm3bVY8zbNgwGQwGdejQQUaj0eLzxMfHV0jOAAAAqLmcqjsBAAAAANYbMmSIBg4cKBcXF1PbJ598IgcHBy1cuFDOzs6m9n379snBge9Ilccnn3yiBx54QM8++2x1p1Jpjhw5oilTpsjf39806uh69fLLL+vxxx9X3bp1r+k43377rVavXq1+/fpVUGYAAAC4nvBuCQAAADVSQUFBdadQozk6OsrV1VUGg8HUlpeXJzc3N7MiiSS5uLioTp06VZ1iKWfPnrXYfuHCBRUVFV3TsSvq9ZKXl6eGDRtWyLFQuYKDg5Wbm6u0tLRrOo6bm5tatWqlF154weKoEgAAANR+FEoAAABQ7SZPniyDwaDvv/9ejzzyiG644Qb16NFD0sX1Ne677z5t2LBBnTt3lpubm9q3b68NGzZIklavXq327dvL1dVVISEh2rlzZ6nj7927Vw8//LAaNWokV1dXde7cWe+99165cluxYoVCQkLUoEEDubu7q3379po9e7Zp/6Vpfj777DP93//9nxo3bix3d3fFxsbqt99+K3W8Dz/8ULfffrvq1aunBg0aqE+fPvruu+8s5ty/f381adJEbm5uuuWWW/SPf/yj1PNeWqPEYDBo8eLFKigoME1LtGTJEtM1/OsaJadOndKYMWPk7+8vFxcX3XjjjYqNjdXx48evek3+9a9/KSQkRG5ubmrUqJEGDhyon3/+2SymV69eateunbZv36477rhDdevW1fjx483W9khJSVHLli3l4uKi77//XtLFER2Xrk/Dhg31wAMPaM+ePWbHvtLrpSw//vijoqOj1ahRI9WtW1ddu3bVunXrSl1Po9GoefPmma5hebz66qtq0aKF3Nzc1LNnT+3evbtUjDXndeDAAQ0bNkwNGzaUh4eH4uLiShWZCgsLNWbMGDVp0kQNGjTQ/fffr19++eWquW7YsEG33XabJCkuLs7stZKUlKQ6dero2LFjpfo9+uijatiwoc6dOyfpz5/Ljz/+WMHBwXJ1dVXbtm21evXqUn1PnTql0aNHy8/PTy4uLrr55ps1Y8YMlZSUmMUdPXpUe/fu1fnz5696HpLUvXt33XXXXZo5c6b++OOPcvWxxMHBQRMmTNA333yjNWvW2HwcAAAAXL8olAAAAKDGiI6O1tmzZzVt2jSNHDnS1H7gwAE98sgj6tu3r5KTk/Xbb7+pb9++WrZsmcaMGaPBgwdrypQpOnjwoPr372/2Aex3332nrl27as+ePRo3bpxmzZqlevXqKSoq6qofimZkZCgmJkY33HCDZsyYoenTp6tXr1768ssvS8XGx8drz549mjx5smJjY7Vs2TJFRUWZfUP97bffVp8+fVS/fn3NmDFDEydO1Pfff68ePXqYLcr+zTffKDQ0VJ988olGjhyp2bNnKyoqSu+//36Zub799tu6/fbb5eLiorfffltvv/227rjjDouxZ86c0e233665c+fqnnvu0ezZs/XYY49p7969V/2w/aWXXlJsbKwCAwP1yiuvaPTo0crMzNQdd9yhU6dOmcWeOHFCvXv3VnBwsFJSUnTnnXea9i1evFhz587Vo48+qlmzZqlRo0b63//+p4iICOXl5Wny5MlKSEjQpk2b1L1791KL1ktlv17+Kjc3V926ddNHH32kJ554Qi+99JLOnTun+++/3/QauOOOO/T2229Lkv72t7+ZruHVvPXWW5ozZ46efPJJJSYmavfu3brrrruUm5trirH2vPr376/ff/9dycnJ6t+/v5YsWaIpU6aYxfz9739XSkqK7rnnHk2fPl116tRRnz59rppvmzZt9MILL0i6WPy4/LUyZMgQXbhwQStXrjTrU1RUpHfeeUf9+vWTq6urqX3//v0aMGCAevfureTkZDk5OSk6OloZGRmmmLNnz6pnz57617/+pdjYWM2ZM0fdu3dXYmKiEhISzJ4nMTFRbdq00a+//nrV87hk8uTJys3N1fz588vdx5JHHnlEgYGBjCoBAACwV0YAAACgmiUlJRklGWNiYkrta9GihVGScdOmTaa2jz76yCjJ6ObmZvzpp59M7a+//rpRkvHTTz81td19993G9u3bG8+dO2dqKykpMXbr1s0YGBh4xbxGjRpldHd3N164cKHMmMWLFxslGUNCQoxFRUWm9pkzZxolGf/73/8ajUaj8ffffzc2bNjQOHLkSLP+OTk5Rg8PD7P2O+64w9igQQOzc7uU91+f99ChQ6a2oUOHGuvVq1cqxxYtWhiHDh1qejxp0iSjJOPq1atLxV7+HH+VnZ1tdHR0NL700ktm7d9++63RycnJrL1nz55GSca0tDSz2EOHDhklGd3d3Y15eXlm+4KDg41eXl7GEydOmNq+/vpro4ODgzE2NtbUdqXXiyWjR482SjJ+/vnnprbff//dGBAQYPT39zcWFxeb2iUZn3zyyase89J5uLm5GX/55RdT++bNm42SjGPGjLH5vIYPH272XA8++KCxcePGpse7du0ySjI+8cQTZnGPPPKIUZIxKSnpirlv3brVKMm4ePHiUvvCwsKMoaGhZm2rV68u9XN16efy3XffNbWdPn3a2LRpU2PHjh1NbVOnTjXWq1fP+MMPP5gdc9y4cUZHR0fj4cOHTW1Dhw4t9Zouy+X/TnfeeafRx8fHePbsWaPR+OfPxtatW696nMt/ZpYuXVrq56K8rwcAAABc3xhRAgAAgBrjscces9jetm1bhYWFmR6HhoZKku666y41b968VPuPP/4oSTp58qQ++eQT0zf0jx8/ruPHj+vEiROKiIjQ/v37r/jt9YYNG6qgoMDsG/JlefTRR83WAXn88cfl5OSkDz74QNLF0SmnTp1STEyMKY/jx4/L0dFRoaGh+vTTTyVJx44d02effabhw4ebnZukck8FdTXvvvuugoKC9OCDD5bad6XnWL16tUpKStS/f3+zc/Dx8VFgYKDpHC5xcXFRXFycxWP169dPTZo0MT0+evSodu3apWHDhqlRo0am9g4dOuhvf/ub6TperqzXy1998MEH6tKli9n0XPXr19ejjz6q7Oxs07RftoiKilKzZs1Mj7t06aLQ0FBTvhVxXrfffrtOnDih/Px80/lI0tNPP20WN3r0aJvP45LY2Fht3rxZBw8eNLUtW7ZMfn5+6tmzp1msr6+v2Wvo0pRzO3fuVE5OjiRp1apVuv3223XDDTeYvWbCw8NVXFyszz77zNR/yZIlMhqN8vf3tyrnyZMnKycn55rXKhk0aBCjSgAAAOwUhRIAAADUGAEBARbb/1ow8PDwkCT5+flZbL+0NsiBAwdkNBo1ceJENWnSxGxLSkqSdHHx7rI88cQTatWqlXr37q0bb7xRw4cP1/r16y3GBgYGmj2uX7++mjZtappaaf/+/ZIuFnf+msvHH39syuNSkaddu3Zl5nWtDh48aNPx9+/fL6PRqMDAwFLnsGfPnlLXslmzZqUWlr/kr//WP/30kyTplltuKRXbpk0bHT9+vNSC7WW9Xv7qp59+KvO4lz+3Lf767y5JrVq1Mv2723Jef32933DDDZL+fF3/9NNPcnBwUMuWLc3iLD2HtQYMGCAXFxctW7ZMknT69GmtXbtWgwYNKlVEu/nmm0u1tWrVSpLMXvfr168v9XoJDw+XdOWfv/K64447dOedd5a5Vskff/yhnJwcs80SR0dHTZgwQbt27VJ6evo15wUAAIDrh1N1JwAAAABc4ubmZrHd0dHRqvZL3wa/tFbJs88+q4iICIuxN998c5n5eHl5adeuXfroo4/04Ycf6sMPP9TixYsVGxurpUuXltnPkku5vP322/Lx8Sm138mp5v9pXlJSIoPBoA8//NDita9fv77Z47L+Pa+2r7wq4hg10dVe15Xphhtu0H333adly5Zp0qRJeuedd1RYWKjBgwfbdLySkhL97W9/03PPPWdx/6XCyrVKSkpSr1699Prrr6thw4Zm+1auXFlqZFNZ13LQoEGaOnWqXnjhBUVFRVVIbgAAAKj5av67MQAAAMBGN910kySpTp06pm+wW8vZ2Vl9+/ZV3759VVJSoieeeEKvv/66Jk6caFZk2b9/v9li5WfOnNHRo0d17733SpLp2/9eXl5XzOVSzrt377Yp3/Jo2bKlTcdv2bKljEajAgICKuwD7ktatGghSdq3b1+pfXv37pWnp6fq1atn87HLOu7lz22LSyOFLvfDDz+Ypo+qjPNq0aKFSkpKdPDgQbNRJJaew5KrTeEWGxurBx54QFu3btWyZcvUsWNH3XrrraXiLo3Yuvx4P/zwgySZzr9ly5Y6c+aMzT9/5dWzZ0/16tVLM2bM0KRJk8z2RURElGv6POnPUSXDhg3Tf//738pIFQAAADUQU28BAACg1vLy8jJ9y/zo0aOl9h87duyK/U+cOGH22MHBQR06dJAkFRYWmu1bsGCBzp8/b3o8f/58XbhwQb1795Z08cNad3d3TZs2zSzur7k0adJEd9xxhxYtWqTDhw+bxVTUiIJ+/frp66+/1po1a0rtu9JzPPTQQ3J0dNSUKVNKxRmNxlLXyxpNmzZVcHCwli5dqlOnTpnad+/erY8//thUcLLFvffeqy1btigrK8vUVlBQoAULFsjf319t27a1+djp6elm69xs2bJFmzdvNv27V8Z5XTr2nDlzzNpTUlLK1f9SYebyfP56fE9PT82YMUMbN24sczTJkSNHzF5D+fn5euuttxQcHGwaNdW/f39lZWXpo48+KtX/1KlTunDhgunx0aNHtXfvXos/H+Vxaa2SBQsWmLU3bdpU4eHhZtuVDB48WDfffLOmTJliUx4AAAC4/jCiBAAAALXavHnz1KNHD7Vv314jR47UTTfdpNzcXGVlZemXX37R119/XWbfv//97zp58qTuuusu3Xjjjfrpp580d+5cBQcHm9a3uKSoqEh33323+vfvr3379um1115Tjx49dP/990u6uND1/PnzNWTIEHXq1EkDBw5UkyZNdPjwYa1bt07du3dXamqqpIsfgPfo0UOdOnXSo48+qoCAAGVnZ2vdunXatWvXNV+TsWPH6p133lF0dLSGDx+ukJAQnTx5Uu+9957S0tIUFBRksV/Lli314osvKjExUdnZ2YqKilKDBg106NAhrVmzRo8++qieffZZm/N6+eWX1bt3b4WFhWnEiBH6448/NHfuXHl4eGjy5Mk2H3fcuHH697//rd69e+vpp59Wo0aNtHTpUh06dEjvvvuuHBxs//7YzTffrB49eujxxx9XYWGhUlJS1LhxY7Oppir6vIKDgxUTE6PXXntNp0+fVrdu3ZSZmakDBw6Uq3/Lli3VsGFDpaWlqUGDBqpXr55CQ0NNa77UqVNHAwcOVGpqqhwdHRUTE2PxOK1atdKIESO0detWeXt7a9GiRcrNzdXixYtNMWPHjtV7772n++67T8OGDVNISIgKCgr07bff6p133lF2drY8PT0lSYmJiaZ/F2sXdJcujirp2bOnNm7caHXfyzk6Ouof//hHqem6AAAAUHtRKAEAAECt1rZtW23btk1TpkzRkiVLdOLECXl5ealjx46lpuj5q8GDB2vBggV67bXXdOrUKfn4+GjAgAGaPHlyqQ/XU1NTTes6nD9/XjExMZozZ47ZtESPPPKIfH19NX36dL388ssqLCxUs2bNdPvtt5t9KBsUFKSvvvpKEydO1Pz583Xu3Dm1aNFC/fv3r5BrUr9+fX3++edKSkrSmjVrtHTpUnl5eenuu+/WjTfeeMW+48aNU6tWrfTqq6+avnHv5+ene+65x1QUslV4eLjWr1+vpKQkTZo0SXXq1FHPnj01Y8aMci/cbom3t7c2bdqk559/XnPnztW5c+fUoUMHvf/+++rTp8815RwbGysHBwelpKQoLy9PXbp0UWpqqpo2bVqp57Vo0SI1adJEy5YtU3p6uu666y6tW7dOfn5+V+1bp04dLV26VImJiXrsscd04cIFLV682CyX2NhYpaam6u677zY7l8sFBgZq7ty5Gjt2rPbt26eAgACtXLnSbD2gunXrauPGjZo2bZpWrVqlt956S+7u7mrVqpWmTJkiDw8Pm86/LJMnTzabAs9WgwcP1osvvqiDBw9WQFYAAACo6QzGqlgREAAAAKillixZori4OG3dulWdO3eu7nSACvH1118rODhYb731loYMGVJqv7+/v9q1a6e1a9dWQ3YAAABAxWKNEgAAAACAmTfeeEP169fXQw89VN2pAAAAAJWOqbcAAAAAAJKk999/X99//70WLFig+Ph408LvAAAAQG1GoQQAAAAAIEl66qmnlJubq3vvvde0Bg0AAABQ29k09da8efPk7+8vV1dXhYaGasuWLVeMX7VqlVq3bi1XV1e1b99eH3zwgdl+g8FgcXv55ZdtSQ8AAACoMsOGDZPRaGR9EtQK2dnZ+uOPP5Senq4GDRpcMY71SQAAAFBbWF0oWblypRISEpSUlKQdO3YoKChIERERysvLsxi/adMmxcTEaMSIEdq5c6eioqIUFRWl3bt3m2KOHj1qti1atEgGg0H9+vWz/cwAAAAAAAAAAACuwmA0Go3WdAgNDdVtt92m1NRUSVJJSYn8/Pz01FNPady4caXiBwwYoIKCArNvG3Xt2lXBwcFKS0uz+BxRUVH6/ffflZmZaU1qAAAAAAAAAAAAVrFqjZKioiJt375diYmJpjYHBweFh4crKyvLYp+srCwlJCSYtUVERCg9Pd1ifG5urtatW6elS5eWmUdhYaEKCwtNj0tKSnTy5Ek1btxYBoPBijMCAAAAAAAAAAC1jdFo1O+//y5fX185OFx5ci2rCiXHjx9XcXGxvL29zdq9vb21d+9ei31ycnIsxufk5FiMX7p0qRo0aKCHHnqozDySk5NZWBAAAAAAAAAAAFzRzz//rBtvvPGKMVYVSqrCokWLNGjQILm6upYZk5iYaDZK5fTp02revLl+/vlnubu7V0WaAAAAAAAAAACghsrPz5efn58aNGhw1VirCiWenp5ydHRUbm6uWXtubq58fHws9vHx8Sl3/Oeff659+/Zp5cqVV8zDxcVFLi4updrd3d0plAAAAAAAAAAAAEkq13IdV56Y6y+cnZ0VEhJitsh6SUmJMjMzFRYWZrFPWFhYqUXZMzIyLMYvXLhQISEhCgoKsiYtAAAAAAAAAAAAm1g99VZCQoKGDh2qzp07q0uXLkpJSVFBQYHi4uIkSbGxsWrWrJmSk5MlSaNGjVLPnj01a9Ys9enTRytWrNC2bdu0YMECs+Pm5+dr1apVmjVrVgWcFgAAAAAAAAAAwNVZXSgZMGCAjh07pkmTJiknJ0fBwcFav369acH2w4cPm60g361bNy1fvlwTJkzQ+PHjFRgYqPT0dLVr187suCtWrJDRaFRMTMw1nhIAAAAAAAAAAED5GIxGo7G6k7hW+fn58vDw0OnTp1mjBAAAAAAAAACuA8XFxTp//nx1p4HrWJ06deTo6GhxnzV1A6tHlAAAAAAAAAAAYCuj0aicnBydOnWqulNBLdCwYUP5+PiUa9H2slAoAQAAAAAAAABUmUtFEi8vL9WtW/eaPuCG/TIajTp79qzy8vIkSU2bNrX5WBRKAAAAAAAAAABVori42FQkady4cXWng+ucm5ubJCkvL09eXl5lTsN1NQ5XDwEAAAAAAAAA4NpdWpOkbt261ZwJaotLr6VrWe+GESWoMP7j1pU7Nnt6n0rMBAAAAAAAAEBNxnRbqCgV8VpiRAkAAAAAAAAAALBbFEoAAAAAAAAAAIDdYuotAAAAAAAAAEC1s2Zq/2tl7dIAw4YN09KlSyVJTk5OuvHGGxUdHa0XXnhBrq6uprhL00BlZWWpa9eupvbCwkL5+vrq5MmT+vTTT9WrVy9J0saNGzVlyhTt2rVL586dU7NmzdStWze98cYbcnZ21oYNG3TnnXdazOno0aPy8fGx6jxgGSNKAAAAAAAAAAC4isjISB09elQ//vijXn31Vb3++utKSkoqFefn56fFixebta1Zs0b169c3a/v+++8VGRmpzp0767PPPtO3336ruXPnytnZWcXFxWax+/bt09GjR802Ly+vij/JKygqKqrQuJqEQgkAAAAAAAAAAFfh4uIiHx8f+fn5KSoqSuHh4crIyCgVN3ToUK1YsUJ//PGHqW3RokUaOnSoWdzHH38sHx8fzZw5U+3atVPLli0VGRmpN954Q25ubmaxXl5e8vHxMdscHCx/vL9hwwYZDAatW7dOHTp0kKurq7p27ardu3ebxX3xxRe6/fbb5ebmJj8/Pz399NMqKCgw7ff399fUqVMVGxsrd3d3Pfrooxafr1evXoqPj9fo0aPl6empiIgIUw4fffSROnbsKDc3N911113Ky8vThx9+qDZt2sjd3V2PPPKIzp49azpWSUmJkpOTFRAQIDc3NwUFBemdd94p41+k4lAoAQAAAAAAAADACrt379amTZvk7Oxcal9ISIj8/f317rvvSpIOHz6szz77TEOGDDGL8/Hx0dGjR/XZZ59VSo5jx47VrFmztHXrVjVp0kR9+/bV+fPnJUkHDx5UZGSk+vXrp2+++UYrV67UF198ofj4eLNj/POf/1RQUJB27typiRMnlvlcS5culbOzs7788kulpaWZ2idPnqzU1FRt2rRJP//8s/r376+UlBQtX75c69at08cff6y5c+ea4pOTk/XWW28pLS1N3333ncaMGaPBgwdr48aNFXx1zLFGCQAAAAAAAAAAV7F27VrVr19fFy5cUGFhoRwcHJSammoxdvjw4Vq0aJEGDx6sJUuW6N5771WTJk3MYqKjo/XRRx+pZ8+e8vHxUdeuXXX33XebRnBc7sYbbzR73KJFC3333XdXzDcpKUl/+9vfJF0sZNx4441as2aN+vfvr+TkZA0aNEijR4+WJAUGBmrOnDnq2bOn5s+fb1p35a677tIzzzxz1WsTGBiomTNnmh4fPXpUkvTiiy+qe/fukqQRI0YoMTFRBw8e1E033SRJevjhh/Xpp5/q+eefV2FhoaZNm6b//e9/CgsLkyTddNNN+uKLL/T666+rZ8+eV83DVhRKAAAAAAAAAAC4ijvvvFPz589XQUGBXn31VTk5Oalfv34WYwcPHqxx48bpxx9/1JIlSzRnzpxSMY6Ojlq8eLFefPFFffLJJ9q8ebOmTZumGTNmaMuWLWratKkp9vPPP1eDBg1Mj+vUqXPVfC8VGySpUaNGuuWWW7Rnzx5J0tdff61vvvlGy5YtM8UYjUaVlJTo0KFDatOmjSSpc+fOV30e6eIoGks6dOhg+n9vb2/VrVvXVCS51LZlyxZJ0oEDB3T27FlTceeSoqIidezYsVx52IpCCQAAAAAAAAAAV1GvXj3dfPPNki6uORIUFKSFCxdqxIgRpWIbN26s++67TyNGjNC5c+fUu3dv/f777xaP26xZMw0ZMkRDhgzR1KlT1apVK6WlpWnKlCmmmICAADVs2LDCzuXMmTP6v//7Pz399NOl9jVv3tz0//Xq1SvX8cqKu7ygYzAYShV4DAaDSkpKTDlJ0rp169SsWTOzOBcXl3LlYSsKJQAAAAAAAAAAWMHBwUHjx49XQkKCHnnkkVKLr0sXp9+699579fzzz8vR0bFcx73hhhvUtGlTs0XVbfXVV1+Zih6//fabfvjhB9NIkU6dOun77783FX5qgrZt28rFxUWHDx+u1Gm2LKFQAgAAAAAAAACAlaKjozV27FjNmzdPzz77bKn9kZGROnbsWKn1Ri55/fXXtWvXLj344INq2bKlzp07p7feekvfffed2QLnkpSXl6dz586ZtTVu3PiKU3C98MILaty4sby9vfWPf/xDnp6eioqKkiQ9//zz6tq1q+Lj4/X3v/9d9erV0/fff6+MjIwy112pbA0aNNCzzz6rMWPGqKSkRD169NDp06f15Zdfyt3dXUOHDq2056ZQAgAAAAAAAACAlZycnBQfH6+ZM2fq8ccfLzX9lMFgkKenZ5n9u3Tpoi+++EKPPfaYjhw5ovr16+vWW29Venp6qREVt9xyS6n+WVlZ6tq1a5nHnz59ukaNGqX9+/crODhY77//vpydnSVdXDtk48aN+sc//qHbb79dRqNRLVu21IABA6y5BBVu6tSpatKkiZKTk/Xjjz+qYcOG6tSpk8aPH1+pz2swGo3GSn2GKpCfny8PDw+dPn26zOocKp//uHXljs2e3qcSMwEAAAAAAABQE507d06HDh1SQECAXF1dqzudWmnDhg2688479dtvv1XouiY1VVmvKWvqBg6VnSQAAAAAAAAAAEBNRaEEAAAAAAAAAADYLdYoAQAAAAAAAACglujVq5dqwYobVYoRJQAAAAAAAAAAwG7ZVCiZN2+e/P395erqqtDQUG3ZsuWK8atWrVLr1q3l6uqq9u3b64MPPigVs2fPHt1///3y8PBQvXr1dNttt+nw4cO2pAcAAAAAAAAAAFAuVhdKVq5cqYSEBCUlJWnHjh0KCgpSRESE8vLyLMZv2rRJMTExGjFihHbu3KmoqChFRUVp9+7dppiDBw+qR48eat26tTZs2KBvvvlGEydONFuhHgAAAAAAAAAAoKIZjFZOVhYaGqrbbrtNqampkqSSkhL5+fnpqaee0rhx40rFDxgwQAUFBVq7dq2prWvXrgoODlZaWpokaeDAgapTp47efvttm04iPz9fHh4eOn36tNzd3W06Bq6d/7h15Y7Nnt6nEjMBAAAAAAAAUBOdO3dOhw4dUkBAAF+UR4Uo6zVlTd3AqhElRUVF2r59u8LDw/88gIODwsPDlZWVZbFPVlaWWbwkRUREmOJLSkq0bt06tWrVShEREfLy8lJoaKjS09PLzKOwsFD5+flmGwAAAAAAAAAAgLWsKpQcP35cxcXF8vb2Nmv39vZWTk6OxT45OTlXjM/Ly9OZM2c0ffp0RUZG6uOPP9aDDz6ohx56SBs3brR4zOTkZHl4eJg2Pz8/a04DAAAAAAAAAABAko2LuVekkpISSdIDDzygMWPGKDg4WOPGjdN9991nmprrrxITE3X69GnT9vPPP1dlygAAAAAAAAAAoJZwsibY09NTjo6Oys3NNWvPzc2Vj4+PxT4+Pj5XjPf09JSTk5Patm1rFtOmTRt98cUXFo/p4uIiFxcXa1IHAAAAAAAAANRk74+quufqO9uq8GHDhmnp0qX6v//7v1Jf8H/yySf12muvaejQoVqyZInZvqysLPXo0UORkZFat858jefs7GwFBARYfL6srCx17drVqhxhO6tGlDg7OyskJESZmZmmtpKSEmVmZiosLMxin7CwMLN4ScrIyDDFOzs767bbbtO+ffvMYn744Qe1aNHCmvQAAAAAAAAAAKgUfn5+WrFihf744w9T27lz57R8+XI1b97cYp+FCxfqqaee0meffaYjR45YjPnf//6no0ePmm0hISGVcg5lKSoqqtC4643VU28lJCTojTfe0NKlS7Vnzx49/vjjKigoUFxcnCQpNjZWiYmJpvhRo0Zp/fr1mjVrlvbu3avJkydr27Ztio+PN8WMHTtWK1eu1BtvvKEDBw4oNTVV77//vp544okKOEUAAAAAAAAAAK5Np06d5Ofnp9WrV5vaVq9erebNm6tjx46l4s+cOaOVK1fq8ccfV58+fUqNNrmkcePG8vHxMdvq1KljMTY7O1sGg0ErVqxQt27d5Orqqnbt2pVa73v37t3q3bu36tevL29vbw0ZMkTHjx837e/Vq5fi4+M1evRoeXp6KiIiwuLzDRs2TFFRUXrppZfk6+urW265xZTDf/7zH91+++1yc3PTbbfdph9++EFbt25V586dVb9+ffXu3VvHjh0zO96bb76pNm3ayNXVVa1bt9Zrr71m8XmrmtWFkgEDBuif//ynJk2apODgYO3atUvr1683Ldh++PBhHT161BTfrVs3LV++XAsWLFBQUJDeeecdpaenq127dqaYBx98UGlpaZo5c6bat2+vN998U++++6569OhRAacIAAAAAAAAAMC1Gz58uBYvXmx6vGjRItMggr/6z3/+o9atW+uWW27R4MGDtWjRIhmNxgrJY+zYsXrmmWe0c+dOhYWFqW/fvjpx4oQk6dSpU7rrrrvUsWNHbdu2TevXr1dubq769+9vdoylS5fK2dlZX375ZZnrhUtSZmam9u3bp4yMDK1du9bUnpSUpAkTJmjHjh1ycnLSI488oueee06zZ8/W559/rgMHDmjSpEmm+GXLlmnSpEl66aWXtGfPHk2bNk0TJ07U0qVLK+SaXAur1ii5JD4+3mxEyOU2bNhQqi06OlrR0dFXPObw4cM1fPhwW9IBAAAAAAAAAKDSDR48WImJifrpp58kSV9++aVWrFhh8XPxhQsXavDgwZKkyMhInT59Whs3blSvXr3M4rp16yYHB/MxDWfOnLliHvHx8erXr58kaf78+Vq/fr0WLlyo5557TqmpqerYsaOmTZtmil+0aJH8/Pz0ww8/qFWrVpKkwMBAzZw586rnXK9ePb355ptydnaWdHFUiyQ9++yzppEoo0aNUkxMjDIzM9W9e3dJ0ogRI8xG0SQlJWnWrFl66KGHJEkBAQH6/vvv9frrr2vo0KFXzaMy2VQoAQAAAAAAAADA3jRp0sQ0jZbRaFSfPn3k6elZKm7fvn3asmWL1qxZI0lycnLSgAEDtHDhwlKFkpUrV6pNmzZW5XH5muFOTk7q3Lmz9uzZI0n6+uuv9emnn6p+/fql+h08eNBUKCnvOijt27c3FUku16FDB9P/X5pxqn379mZteXl5kqSCggIdPHhQI0aM0MiRI00xFy5ckIeHR7nyqEwUSgAAAAAAAAAAKKfhw4ebZlyaN2+exZiFCxfqwoUL8vX1NbUZjUa5uLgoNTXVrDjg5+enm2++ucLyO3PmjPr27asZM2aU2te0aVPT/9erV69cxysr7vJ1VAwGg8W2kpISU06S9MYbbyg0NNTsOI6OjuXKozJRKAEAAAAAAAAAoJwiIyNVVFQkg8FgcRH0Cxcu6K233tKsWbN0zz33mO2LiorSv//9bz322GPXlMNXX32lO+64w/R827dvNxVvOnXqpHfffVf+/v5ycqoZJQBvb2/5+vrqxx9/1KBBg6o7nVJqxlUCAAAAAAAAAOA64OjoaJrmytJoiLVr1+q3337TiBEjSk0r1a9fPy1cuNCsUHLixAnl5OSYxTVs2FCurq5l5jBv3jwFBgaqTZs2evXVV/Xbb7+Z1gB/8skn9cYbbygmJkbPPfecGjVqpAMHDmjFihV68803q20Ex5QpU/T000/Lw8NDkZGRKiws1LZt2/Tbb78pISGhWnK6xOHqIQAAAAAAAAAA4BJ3d3e5u7tb3Ldw4UKFh4dbXHujX79+2rZtm7755htTW3h4uJo2bWq2paenX/H5p0+frunTpysoKEhffPGF3nvvPdNaKb6+vvryyy9VXFyse+65R+3bt9fo0aPVsGHDUovGV6W///3vevPNN7V48WK1b99ePXv21JIlSxQQEFBtOV1iMBqNxupO4lrl5+fLw8NDp0+fLvPFicrnP25duWOzp/epxEwAAAAAAAAA1ETnzp3ToUOHFBAQcMURE7AsOztbAQEB2rlzp4KDg6s7nRqhrNeUNXUDRpQAAAAAAAAAAAC7RaEEAAAAAAAAAADYLRZzBwAAAAAAAADgOuDv769asJpGjcOIEgAAAAAAAAAAYLcolAAAAAAAAAAAALtFoQQAAAAAAAAAUKVKSkqqOwXUEhXxWmKNEgAAAAAAAABAlXB2dpaDg4OOHDmiJk2ayNnZWQaDobrTwnXIaDSqqKhIx44dk4ODg5ydnW0+FoUSAAAAAAAAAECVcHBwUEBAgI4ePaojR45UdzqoBerWravmzZvLwcH2CbQolAAAAAAAAAAAqoyzs7OaN2+uCxcuqLi4uLrTwXXM0dFRTk5O1zwqiUIJAAAAAAAAAKBKGQwG1alTR3Xq1KnuVAAWcwcAAAAAAAAAAPaLQgkAAAAAAAAAALBbFEoAAAAAAAAAAIDdolACAAAAAAAAAADsFoUSAAAAAAAAAABgtyiUAAAAAAAAAAAAu+VkS6d58+bp5ZdfVk5OjoKCgjR37lx16dKlzPhVq1Zp4sSJys7OVmBgoGbMmKF7773XtH/YsGFaunSpWZ+IiAitX7/elvQAAAAAVCD/cevKFZc9vU8lZwIAAAAAFc/qESUrV65UQkKCkpKStGPHDgUFBSkiIkJ5eXkW4zdt2qSYmBiNGDFCO3fuVFRUlKKiorR7926zuMjISB09etS0/fvf/7btjAAAAAAAAAAAAMrJ6kLJK6+8opEjRyouLk5t27ZVWlqa6tatq0WLFlmMnz17tiIjIzV27Fi1adNGU6dOVadOnZSammoW5+LiIh8fH9N2ww032HZGAAAAAAAAAAAA5WRVoaSoqEjbt29XeHj4nwdwcFB4eLiysrIs9snKyjKLly5Oq/XX+A0bNsjLy0u33HKLHn/8cZ04caLMPAoLC5Wfn2+2AQAAAAAAAAAAWMuqQsnx48dVXFwsb29vs3Zvb2/l5ORY7JOTk3PV+MjISL311lvKzMzUjBkztHHjRvXu3VvFxcUWj5mcnCwPDw/T5ufnZ81pAAAAAAAAAAAASLJxMfeKNnDgQNP/t2/fXh06dFDLli21YcMG3X333aXiExMTlZCQYHqcn59PsQQAAAAAAAAAAFjNqhElnp6ecnR0VG5urll7bm6ufHx8LPbx8fGxKl6SbrrpJnl6eurAgQMW97u4uMjd3d1sAwAAAAAAAAAAsJZVhRJnZ2eFhIQoMzPT1FZSUqLMzEyFhYVZ7BMWFmYWL0kZGRllxkvSL7/8ohMnTqhp06bWpAcAAAAAAAAAAGAVqwolkpSQkKA33nhDS5cu1Z49e/T444+roKBAcXFxkqTY2FglJiaa4keNGqX169dr1qxZ2rt3ryZPnqxt27YpPj5eknTmzBmNHTtWX331lbKzs5WZmakHHnhAN998syIiIiroNAEAAAAAAAAAAEqzeo2SAQMG6NixY5o0aZJycnIUHBys9evXmxZsP3z4sBwc/qy/dOvWTcuXL9eECRM0fvx4BQYGKj09Xe3atZMkOTo66ptvvtHSpUt16tQp+fr66p577tHUqVPl4uJSQacJAAAAAAAAAABQmsFoNBqrO4lrlZ+fLw8PD50+fZr1SqqR/7h15Y7Nnt6nEjMBAABARSrv33n8jQcAAACgprCmbmD11FsAAAAAAAAAAAC1hdVTbwEAAACoWIzMBQAAAIDqw4gSAAAAAAAAAABgtyiUAAAAAAAAAAAAu0WhBAAAAAAAAAAA2C0KJQAAAAAAAAAAwG5RKAEAAAAAAAAAAHaLQgkAAAAAAAAAALBbFEoAAAAAAAAAAIDdolACAAAAAAAAAADsFoUSAAAAAAAAAABgtyiUAAAAAAAAAAAAu+VU3QkAAAAA1vAft67csdnT+1RiJgAAAACA2oARJQAAAAAAAAAAwG5RKAEAAAAAAAAAAHaLQgkAAAAAAAAAALBbFEoAAAAAAAAAAIDdolACAAAAAAAAAADsFoUSAAAAAAAAAABgtyiUAAAAAAAAAAAAu0WhBAAAAAAAAAAA2C0KJQAAAAAAAAAAwG7ZVCiZN2+e/P395erqqtDQUG3ZsuWK8atWrVLr1q3l6uqq9u3b64MPPigz9rHHHpPBYFBKSootqQEAAAAAAAAAAJSb1YWSlStXKiEhQUlJSdqxY4eCgoIUERGhvLw8i/GbNm1STEyMRowYoZ07dyoqKkpRUVHavXt3qdg1a9boq6++kq+vr/VnAgAAAAAAAAAAYCWrCyWvvPKKRo4cqbi4OLVt21ZpaWmqW7euFi1aZDF+9uzZioyM1NixY9WmTRtNnTpVnTp1Umpqqlncr7/+qqeeekrLli1TnTp1bDsbAAAAAAAAAAAAKzhZE1xUVKTt27crMTHR1Obg4KDw8HBlZWVZ7JOVlaWEhASztoiICKWnp5sel5SUaMiQIRo7dqxuvfXWq+ZRWFiowsJC0+P8/HxrTgNALeM/bl254rKn96nkTAAAAAAAAABcb6waUXL8+HEVFxfL29vbrN3b21s5OTkW++Tk5Fw1fsaMGXJyctLTTz9drjySk5Pl4eFh2vz8/Kw5DQAAAAAAAAAAAElWjiipDNu3b9fs2bO1Y8cOGQyGcvVJTEw0G6WSn59PsQQAAAB2pbwjKiVGVQIAAADAlVhVKPH09JSjo6Nyc3PN2nNzc+Xj42Oxj4+PzxXjP//8c+Xl5al58+am/cXFxXrmmWeUkpKi7OzsUsd0cXGRi4uLNakDqGJ8eAMAAAAAAADgemBVocTZ2VkhISHKzMxUVFSUpIvri2RmZio+Pt5in7CwMGVmZmr06NGmtoyMDIWFhUmShgwZovDwcLM+ERERGjJkiOLi4qxJDxWID7lRHXjdAQAAAAAAAKhqVk+9lZCQoKFDh6pz587q0qWLUlJSVFBQYCpqxMbGqlmzZkpOTpYkjRo1Sj179tSsWbPUp08frVixQtu2bdOCBQskSY0bN1bjxo3NnqNOnTry8fHRLbfccq3nBwAAAAAAAAAAUCarCyUDBgzQsWPHNGnSJOXk5Cg4OFjr1683Ldh++PBhOTj8uUZ8t27dtHz5ck2YMEHjx49XYGCg0tPT1a5du4o7CwAAAAAAAAAAABvYtJh7fHx8mVNtbdiwoVRbdHS0oqOjy318S+uSAAAAAAAAAAAAVDSHq4cAAAAAAAAAAADUThRKAAAAAAAAAACA3aJQAgAAAAAAAAAA7BaFEgAAAAAAAAAAYLcolAAAAAAAAAAAALtFoQQAAAAAAAAAANgtCiUAAAAAAAAAAMBuOVV3Aqh8/uPWlSsue3qfSs4EAAAA1am8fxdK1/63IX+DAgAAALheMKIEAAAAAAAAAADYLQolAAAAAAAAAADAblEoAQAAAAAAAAAAdos1SoAaoCrnCwcAAAAAAAAA/IlCCVDBKHoAAAAAAAAAwPWDqbcAAAAAAAAAAIDdolACAAAAAAAAAADsFoUSAAAAAAAAAABgt1ijBLgC1hsBAAAAAAAAgNqNESUAAAAAAAAAAMBuUSgBAAAAAAAAAAB2i6m3AAAAAACwc0w7DAAA7BkjSgAAAAAAAAAAgN1iRAkqzDSnN62I5htIAAAAAAAAAIDqZ9OIknnz5snf31+urq4KDQ3Vli1brhi/atUqtW7dWq6urmrfvr0++OADs/2TJ09W69atVa9ePd1www0KDw/X5s2bbUkNAAAAAAAAAACg3KwulKxcuVIJCQlKSkrSjh07FBQUpIiICOXl5VmM37Rpk2JiYjRixAjt3LlTUVFRioqK0u7du00xrVq1Umpqqr799lt98cUX8vf31z333KNjx47ZfmYAAAAAAAAAAABXYXWh5JVXXtHIkSMVFxentm3bKi0tTXXr1tWiRYssxs+ePVuRkZEaO3as2rRpo6lTp6pTp05KTU01xTzyyCMKDw/XTTfdpFtvvVWvvPKK8vPz9c0339h+ZgAAAAAAAAAAAFdhVaGkqKhI27dvV3h4+J8HcHBQeHi4srKyLPbJysoyi5ekiIiIMuOLioq0YMECeXh4KCgoyGJMYWGh8vPzzTYAAAAAAAAAAABrWVUoOX78uIqLi+Xt7W3W7u3trZycHIt9cnJyyhW/du1a1a9fX66urnr11VeVkZEhT09Pi8dMTk6Wh4eHafPz87PmNAAAAAAAAAAAACTZuJh7Zbjzzju1a9cubdq0SZGRkerfv3+Z654kJibq9OnTpu3nn3+u4mwBAAAAAAAAAEBtYFWhxNPTU46OjsrNzTVrz83NlY+Pj8U+Pj4+5YqvV6+ebr75ZnXt2lULFy6Uk5OTFi5caPGYLi4ucnd3N9sAAAAAAAAAAACs5WRNsLOzs0JCQpSZmamoqChJUklJiTIzMxUfH2+xT1hYmDIzMzV69GhTW0ZGhsLCwq74XCUlJSosLLQmPQAAAAAAANRy/uPWlSsue3qfSs4EAFBbWFUokaSEhAQNHTpUnTt3VpcuXZSSkqKCggLFxcVJkmJjY9WsWTMlJydLkkaNGqWePXtq1qxZ6tOnj1asWKFt27ZpwYIFkqSCggK99NJLuv/++9W0aVMdP35c8+bN06+//qro6OgKPFUAAAAAAAAAAABzVhdKBgwYoGPHjmnSpEnKyclRcHCw1q9fb1qw/fDhw3Jw+HNGr27dumn58uWaMGGCxo8fr8DAQKWnp6tdu3aSJEdHR+3du1dLly7V8ePH1bhxY9122236/PPPdeutt1bQaQIAAAAAYB/K+217iW/cAwAASDYUSiQpPj6+zKm2NmzYUKotOjq6zNEhrq6uWr16tS1pAAAAAAAAAAAAXBOrFnMHAAAAAAAAAACoTSiUAAAAAAAAAAAAu0WhBAAAAAAAAAAA2C0KJQAAAAAAAAAAwG7ZtJg7AACVzX/cunLHZk/vU4mZAAAAVJ/y/k3E30MAAAC2Y0QJAAAAAAAAAACwW4woAQAAAAAAwDVhRDgA4HrGiBIAAAAAAAAAAGC3KJQAAAAAAAAAAAC7RaEEAAAAAAAAAADYLdYoAQA7Vt55hJlDGAAAAAAAALUVhRIAV8WifAAAAAAAAABqKwolAAAAuK5Mc3rTimgK+AAAAACAK6NQAlzHmDYJAIDKxe9aAAAAAKj9KJQAAAAAAADAhOmXcTleDwDsgUN1JwAAAAAAAAAAAFBdKJQAAAAAAAAAAAC7xdRbgJ1hrnVcK4ZdAwAAAAAAoDZhRAkAAAAAAAAAALBbjCgBAAAAAACoYoz2B4DrF7Nt1D4USgAAAFDr8UYGAAAAAFAWCiUAAAAAAABANWKEEQBUL5sKJfPmzdPLL7+snJwcBQUFae7cuerSpUuZ8atWrdLEiROVnZ2twMBAzZgxQ/fee68k6fz585owYYI++OAD/fjjj/Lw8FB4eLimT58uX19f284KAAAAAAAANR4FAgBATWB1oWTlypVKSEhQWlqaQkNDlZKSooiICO3bt09eXl6l4jdt2qSYmBglJyfrvvvu0/LlyxUVFaUdO3aoXbt2Onv2rHbs2KGJEycqKChIv/32m0aNGqX7779f27Ztq5CTBADYD95oAahu3IcAAAAA4PpidaHklVde0ciRIxUXFydJSktL07p167Ro0SKNGzeuVPzs2bMVGRmpsWPHSpKmTp2qjIwMpaamKi0tTR4eHsrIyDDrk5qaqi5duujw4cNq3ry5LecFAAAAAABQ6VgHCwCA659VhZKioiJt375diYmJpjYHBweFh4crKyvLYp+srCwlJCSYtUVERCg9Pb3M5zl9+rQMBoMaNmxocX9hYaEKCwtNj/Pz88t/EgCAKsebRwC4vk1zerOckdzDAQAA7wEBXH+sKpQcP35cxcXF8vb2Nmv39vbW3r17LfbJycmxGJ+Tk2Mx/ty5c3r++ecVExMjd3d3izHJycmaMmWKNakDAAAAAABc15jeEQCAymHTYu6V5fz58+rfv7+MRqPmz59fZlxiYqLZKJX8/Hz5+flVRYoAAABAhSv/iA2JURsAAACAfWB0VtWxqlDi6ekpR0dH5ebmmrXn5ubKx8fHYh8fH59yxV8qkvz000/65JNPyhxNIkkuLi5ycXGxJnUAAIDrXk3/I5lvuQIAAAAArkdWFUqcnZ0VEhKizMxMRUVFSZJKSkqUmZmp+Ph4i33CwsKUmZmp0aNHm9oyMjIUFhZmenypSLJ//359+umnaty4sfVnggrFtxoBAKgdanpxBQAAAACA6mb11FsJCQkaOnSoOnfurC5duiglJUUFBQWKi4uTJMXGxqpZs2ZKTk6WJI0aNUo9e/bUrFmz1KdPH61YsULbtm3TggULJF0skjz88MPasWOH1q5dq+LiYtP6JY0aNZKzs3NFnSuAWooFZgEAAAAA9saW0bx8iQYALLO6UDJgwAAdO3ZMkyZNUk5OjoKDg7V+/XrTgu2HDx+Wg4ODKb5bt25avny5JkyYoPHjxyswMFDp6elq166dJOnXX3/Ve++9J0kKDg42e65PP/1UvXr1svHUAFQnRiUBAFDz8PsZAAAAAEqzaTH3+Pj4Mqfa2rBhQ6m26OhoRUdHW4z39/eX0Wi0JQ0AtQwf3gAAUH783gQAAACAimFToQQAAAAAroSpMQEAAABcLyiUAAAAAAAAAACue7as3QNIksPVQwAAAAAAAAAAAGonRpQAAOwe3zgBqk95f/4kfgYBAAAAAJWDQgkAAABKoYAIAAAAALAXFEoAAAAAAABqKUZvAgBwdRRKAAAAAABAjcZIRwAAUJlYzB0AAAAAAAAAANgtRpQAAAAAAACIaaoAALBXFEpwXWLYNQCgJuH3EgAAAAAA1y8KJQAAoNaigAEAAAAA1YvRehdxHWo2CiV2YJrTm+WMrJ4fwKr6EIsPywAAAAAAAAAAf8Vi7gAAAAAAAAAAwG4xogQAAOAyjEAEri+2TGHAtAcAcH2r6TNnAACuPxRKAAAAANQIFDAAAAAAVAcKJUANUP5vw0h8IwbA9Y4PQgEAAAAAQE1CoQSoYBQ9ANQETB9V81EwAgAAAACgZqBQAlwBRQ8AAAAA9oD3PgAAwJ5RKAEAAAAAAFZjdCQAAKgtKJQAAAAAAAAAAK6KIjlqKwolAAAAAADUUOWfEuvPD6Nq+jRarKV2kS3/tgAAoHJQKAFgl3hzBgAAAMn2b8by9yQAAEDtYVOhZN68eXr55ZeVk5OjoKAgzZ07V126dCkzftWqVZo4caKys7MVGBioGTNm6N577zXtX716tdLS0rR9+3adPHlSO3fuVHBwsC2pAagENf0babUNw1hRXfjABwAA1Cb8XQ0AAMrL6kLJypUrlZCQoLS0NIWGhiolJUURERHat2+fvLy8SsVv2rRJMTExSk5O1n333afly5crKipKO3bsULt27SRJBQUF6tGjh/r376+RI0de+1kBdoKh2lWLN1pAaRRXAMD+cO+3HdcOAACgZrK6UPLKK69o5MiRiouLkySlpaVp3bp1WrRokcaNG1cqfvbs2YqMjNTYsWMlSVOnTlVGRoZSU1OVlpYmSRoyZIgkKTs729bzAFBOFFeqFsUVAAAA21BUAK4vzEQAALieWVUoKSoq0vbt25WYmGhqc3BwUHh4uLKysiz2ycrKUkJCgllbRESE0tPTrc/2/yssLFRhYaHpcX5+vs3HAgAANR9FRwAAAABAZeDLGZCsLJQcP35cxcXF8vb2Nmv39vbW3r17LfbJycmxGJ+Tk2Nlqn9KTk7WlClTbO4PALURv9gBAAAAAAAA69m0mHt1S0xMNBulkp+fLz8/v2rMCAAAAABQ3RiBCFSMqpxGi+mhAQA1gVWFEk9PTzk6Oio3N9esPTc3Vz4+Phb7+Pj4WBVfHi4uLnJxcbG5PwAAAACg6jDyFbUZRQXUdqw/A8AeWFUocXZ2VkhIiDIzMxUVFSVJKikpUWZmpuLj4y32CQsLU2ZmpkaPHm1qy8jIUFhYmM1JAwBqP1v/GOfNIwAAAIC/smXEWVX1AQBUP6un3kpISNDQoUPVuXNndenSRSkpKSooKFBcXJwkKTY2Vs2aNVNycrIkadSoUerZs6dmzZqlPn36aMWKFdq2bZsWLFhgOubJkyd1+PBhHTlyRJK0b98+SRdHo1zLyBMAQM1Q07+BRHEFAAAAtVVN/1scF/GeBACql9WFkgEDBujYsWOaNGmScnJyFBwcrPXr15sWbD98+LAcHBxM8d26ddPy5cs1YcIEjR8/XoGBgUpPT1e7du1MMe+9956p0CJJAwcOlCQlJSVp8uTJtp4bAOAq+GMcAAAAAK5PvJ8DgIpj02Lu8fHxZU61tWHDhlJt0dHRio6OLvN4w4YN07Bhw2xJBQBQxfhGGgAAlYtpWwAAQFWojWuIVVUBkc9Gah+bCiUAAACoerVxsdiqOifWPQIAAKj5f3vZojZ+2F8bz8lafHEEVY1CCQAAVYQ/9KoeH3LXzuJKTcd1wPWC30sAAAA1GyNXqg6FEgAAbFBVHy7xIRYAAAAAXJ94PwdcPyiUAAAAAAAAAKh2VTXlFFNbAfgrCiWodrZMT1FVfQAAAFB1bJlaoKr6AAAA1GQUf4Br41DdCQAAAAAAAAAAAFQXRpQAAAAAAAAAwBWw3ghQu1EoAQAAAAAANVpVTaXM1HwAgMrC75iajUIJAADANWIdLACoeDX9g3Hu/QAAlA8FAlwPKJQAQCXijwEAAOxbbVxYlQKB7bh2AIDrHZ9zoLaiUAIAtQRvvIGKUdO/wQyg+tTkokdNnzede2vtxGgcAABQW1AoAWCXbHlzVlVvvHmDDwAAAFw7/q4GAADlRaEEAADYjG+EAgAAAACuZ7yvhUShBAAAXAf4RigAAACuhS1/T1ZVHwCl8bOEqkahBACAWqgmfyOGP3gBAAAAoPrV5PeNQFWjUAIAgA1q47fL+CMZAAAAAKof782AqkehBACAGqymF1cAAAAA4HpTVYUI3s8B1w8KJQAAAABQg1TVhyp8eAMAqGlsKWAw+uIirgNwbRyqOwEAAAAAAAAAAIDqwogSAAAAAAAAALgCRmICtRsjSgAAAAAAAAAAgN1iRAkAAAAAlENNnjedb7kCAAAAtrNpRMm8efPk7+8vV1dXhYaGasuWLVeMX7VqlVq3bi1XV1e1b99eH3zwgdl+o9GoSZMmqWnTpnJzc1N4eLj2799vS2oAAAAAAAAAAADlZnWhZOXKlUpISFBSUpJ27NihoKAgRUREKC8vz2L8pk2bFBMToxEjRmjnzp2KiopSVFSUdu/ebYqZOXOm5syZo7S0NG3evFn16tVTRESEzp07Z/uZAQAAAAAAAAAAXIXVhZJXXnlFI0eOVFxcnNq2bau0tDTVrVtXixYtshg/e/ZsRUZGauzYsWrTpo2mTp2qTp06KTU1VdLF0SQpKSmaMGGCHnjgAXXo0EFvvfWWjhw5ovT09Gs6OQAAAAAAAAAAgCuxao2SoqIibd++XYmJiaY2BwcHhYeHKysry2KfrKwsJSQkmLVFRESYiiCHDh1STk6OwsPDTfs9PDwUGhqqrKwsDRw4sNQxCwsLVVhYaHp8+vRpSVJ+fr41p2M3zhaeL1fc5devvH0u72dLH2v61eQ+l/ez52t3eT+uA9fh8n5cB64Dv2Ns73N5P3u+dpf34zpwHS7vx3XgOvA7xvY+l/ez52t3eT+uA9fh8n5cB64Dv2Ns73N5P3u+dpf3q8rrgIsuXROj0Xj1YKMVfv31V6Mk46ZNm8zax44da+zSpYvFPnXq1DEuX77crG3evHlGLy8vo9FoNH755ZdGScYjR46YxURHRxv79+9v8ZhJSUlGSWxsbGxsbGxsbGxsbGxsbGxsbGxsbGxsbGVuP//881VrH1aNKKkpEhMTzUaplJSU6OTJk2rcuLEMBkM1ZnZ9yM/Pl5+fn37++We5u7tXdzoAahDuDwDKwv0BQFm4PwAoC/cHAGXh/oCqYDQa9fvvv8vX1/eqsVYVSjw9PeXo6Kjc3Fyz9tzcXPn4+Fjs4+Pjc8X4S//Nzc1V06ZNzWKCg4MtHtPFxUUuLi5mbQ0bNrTmVCDJ3d2dGxEAi7g/ACgL9wcAZeH+AKAs3B8AlIX7Ayqbh4dHueKsWszd2dlZISEhyszMNLWVlJQoMzNTYWFhFvuEhYWZxUtSRkaGKT4gIEA+Pj5mMfn5+dq8eXOZxwQAAAAAAAAAAKgIVk+9lZCQoKFDh6pz587q0qWLUlJSVFBQoLi4OElSbGysmjVrpuTkZEnSqFGj1LNnT82aNUt9+vTRihUrtG3bNi1YsECSZDAYNHr0aL344osKDAxUQECAJk6cKF9fX0VFRVXcmQIAAAAAAAAAAPyF1YWSAQMG6NixY5o0aZJycnIUHBys9evXy9vbW5J0+PBhOTj8OVClW7duWr58uSZMmKDx48crMDBQ6enpateunSnmueeeU0FBgR599FGdOnVKPXr00Pr16+Xq6loBp4i/cnFxUVJSUqnpywCA+wOAsnB/AFAW7g8AysL9AUBZuD+gpjEYjUZjdScBAAAAAAAAAABQHaxaowQAAAAAAAAAAKA2oVACAAAAAAAAAADsFoUSAAAAAAAAAABgtyiUAAAAAAAAAAAAu0WhxA7NmzdP/v7+cnV1VWhoqLZs2VLdKQGoQsnJybrtttvUoEEDeXl5KSoqSvv27TOLOXfunJ588kk1btxY9evXV79+/ZSbm1tNGQOoLtOnT5fBYNDo0aNNbdwfAPv166+/avDgwWrcuLHc3NzUvn17bdu2zbTfaDRq0qRJatq0qdzc3BQeHq79+/dXY8YAqkJxcbEmTpyogIAAubm5qWXLlpo6daqMRqMphvsDYB8+++wz9e3bV76+vjIYDEpPTzfbX557wcmTJzVo0CC5u7urYcOGGjFihM6cOVOFZwF7RaHEzqxcuVIJCQlKSkrSjh07FBQUpIiICOXl5VV3agCqyMaNG/Xkk0/qq6++UkZGhs6fP6977rlHBQUFppgxY8bo/fff16pVq7Rx40YdOXJEDz30UDVmDaCqbd26Va+//ro6dOhg1s79AbBPv/32m7p37646deroww8/1Pfff69Zs2bphhtuMMXMnDlTc+bMUVpamjZv3qx69eopIiJC586dq8bMAVS2GTNmaP78+UpNTdWePXs0Y8YMzZw5U3PnzjXFcH8A7ENBQYGCgoI0b948i/vLcy8YNGiQvvvuO2VkZGjt2rX67LPP9Oijj1bVKcCOGYyXl/hR64WGhuq2225TamqqJKmkpER+fn566qmnNG7cuGrODkB1OHbsmLy8vLRx40bdcccdOn36tJo0aaLly5fr4YcfliTt3btXbdq0UVZWlrp27VrNGQOobGfOnFGnTp302muv6cUXX1RwcLBSUlK4PwB2bNy4cfryyy/1+eefW9xvNBrl6+urZ555Rs8++6wk6fTp0/L29taSJUs0cODAqkwXQBW677775O3trYULF5ra+vXrJzc3N/3rX//i/gDYKYPBoDVr1igqKkpS+f5W2LNnj9q2bautW7eqc+fOkqT169fr3nvv1S+//CJfX9/qOh3YAUaU2JGioiJt375d4eHhpjYHBweFh4crKyurGjMDUJ1Onz4tSWrUqJEkafv27Tp//rzZvaJ169Zq3rw59wrATjz55JPq06eP2X1A4v4A2LP33ntPnTt3VnR0tLy8vNSxY0e98cYbpv2HDh1STk6O2f3Bw8NDoaGh3B+AWq5bt27KzMzUDz/8IEn6+uuv9cUXX6h3796SuD8AuKg894KsrCw1bNjQVCSRpPDwcDk4OGjz5s1VnjPsi1N1J4Cqc/z4cRUXF8vb29us3dvbW3v37q2mrABUp5KSEo0ePVrdu3dXu3btJEk5OTlydnZWw4YNzWK9vb2Vk5NTDVkCqEorVqzQjh07tHXr1lL7uD8A9uvHH3/U/PnzlZCQoPHjx2vr1q16+umn5ezsrKFDh5ruAZbea3B/AGq3cePGKT8/X61bt5ajo6OKi4v10ksvadCgQZLE/QGApPLdC3JycuTl5WW238nJSY0aNeJ+gUpHoQQA7NiTTz6p3bt364svvqjuVADUAD///LNGjRqljIwMubq6Vnc6AGqQkpISde7cWdOmTZMkdezYUbt371ZaWpqGDh1azdkBqE7/+c9/tGzZMi1fvly33nqrdu3apdGjR8vX15f7AwDgusHUW3bE09NTjo6Oys3NNWvPzc2Vj49PNWUFoLrEx8dr7dq1+vTTT3XjjTea2n18fFRUVKRTp06ZxXOvAGq/7du3Ky8vT506dZKTk5OcnJy0ceNGzZkzR05OTvL29ub+ANippk2bqm3btmZtbdq00eHDhyXJdA/gvQZgf8aOHatx48Zp4MCBat++vYYMGaIxY8YoOTlZEvcHABeV517g4+OjvLw8s/0XLlzQyZMnuV+g0lEosSPOzs4KCQlRZmamqa2kpESZmZkKCwurxswAVCWj0aj4+HitWbNGn3zyiQICAsz2h4SEqE6dOmb3in379unw4cPcK4Ba7u6779a3336rXbt2mbbOnTtr0KBBpv/n/gDYp+7du2vfvn1mbT/88INatGghSQoICJCPj4/Z/SE/P1+bN2/m/gDUcmfPnpWDg/nHS46OjiopKZHE/QHAReW5F4SFhenUqVPavn27KeaTTz5RSUmJQkNDqzxn2Bem3rIzCQkJGjp0qDp37qwuXbooJSVFBQUFiouLq+7UAFSRJ598UsuXL9d///tfNWjQwDTPp4eHh9zc3OTh4aERI0YoISFBjRo1kru7u5566imFhYWpa9eu1Zw9gMrUoEED03pFl9SrV0+NGzc2tXN/AOzTmDFj1K1bN02bNk39+/fXli1btGDBAi1YsECSZDAYNHr0aL344osKDAxUQECAJk6cKF9fX0VFRVVv8gAqVd++ffXSSy+pefPmuvXWW7Vz50698sorGj58uCTuD4A9OXPmjA4cOGB6fOjQIe3atUuNGjVS8+bNr3ovaNOmjSIjIzVy5EilpaXp/Pnzio+P18CBA+Xr61tNZwV7YTAajcbqTgJVKzU1VS+//LJycnIUHBysOXPmUJUF7IjBYLDYvnjxYg0bNkySdO7cOT3zzDP697//rcLCQkVEROi1115jqCtgh3r16qXg4GClpKRI4v4A2LO1a9cqMTFR+/fvV0BAgBISEjRy5EjTfqPRqKSkJC1YsECnTp1Sjx499Nprr6lVq1bVmDWAyvb7779r4sSJWrNmjfLy8uTr66uYmBhNmjRJzs7Okrg/APZiw4YNuvPOO0u1Dx06VEuWLCnXveDkyZOKj4/X+++/LwcHB/Xr109z5sxR/fr1q/JUYIcolAAAAAAAAAAAALvFGiUAAAAAAAAAAMBuUSgBAAAAAAAAAAB2i0IJAAAAAAAAAACwWxRKAAAAAAAAAACA3aJQAgAAAAAAAAAA7BaFEgAAAAAAAAAAYLcolAAAAAAAAAAAALtFoQQAAAAAAAAAANgtCiUAAAAAAAAAAMBuUSgBAAAAAAAAAAB2i0IJAAAAAAAAAACwWxRKAAAAAAAAAACA3fp/dsddEtLpgpQAAAAASUVORK5CYII=",
+ "text/plain": [
+ "