From 0004728799fa1048d87d785ea06cce59c609f041 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers <56895592+lubbersnick@users.noreply.github.com> Date: Fri, 6 Sep 2024 18:51:26 -0600 Subject: [PATCH] Add pytorch lightning trainer (#99) * initial attempt of lightning training interface * fix train_step and remove print * fix batch order for end validation epoch * fix types * fix raisebatchsize for lightning * remember to detach tensors * add valid tag to lr scheduler * add loss printing and controller * add dataloader args for additional configuration * refactor slightly and fix type errors * add extra dataloader args to test script * closer to connecting controller to lightning module * guard print statements * prevent double-updating of schedulers * add sanity checking guard * fix printing in sanity check * get batch size changes working with pytorch lightning * make sure custom kernels don't automatically trigger cuda context on device 0 * adding saving of modules (very necessary!) * make lightning trainer not have to serialize constantly * add coalescing custom kernel call for hip-nn-ts (l=2) * add coalescing custom kernel call to hip-nn-ts, l=1 * formating and debug print * update packages in docs * make lightning import optional * make metric tracker only seek better metrics on validation * Make controller and metric tracker see metrics reduced across nodes * update lightning test script * update docs and requirements, remove extraneous code * good old fashioned formatting --------- Co-authored-by: Nicholas Lubbers --- conda_requirements.txt | 1 + docs/source/conf.py | 4 +- docs/source/installation.rst | 17 +- docs/source/user_guide/settings.rst | 2 +- examples/barebones_lightning.py | 102 +++++ hippynn/_settings_setup.py | 30 +- hippynn/custom_kernels/tensor_wrapper.py | 8 +- hippynn/custom_kernels/test_env_numba.py | 7 + hippynn/databases/__init__.py | 9 + hippynn/databases/database.py | 38 +- hippynn/experiment/__init__.py | 9 +- hippynn/experiment/controllers.py | 56 ++- hippynn/experiment/lightning_trainer.py | 369 ++++++++++++++++++ hippynn/experiment/metric_tracker.py | 31 +- hippynn/experiment/routines.py | 5 +- hippynn/experiment/serialization.py | 14 +- hippynn/graphs/gops.py | 3 +- .../interfaces/ase_interface/ase_database.py | 12 +- hippynn/layers/hiplayers.py | 62 ++- hippynn/pretraining.py | 2 +- hippynn/tools.py | 19 +- setup.py | 1 + tests/lightning_QM7_test.py | 219 +++++++++++ 23 files changed, 912 insertions(+), 108 deletions(-) create mode 100644 examples/barebones_lightning.py create mode 100644 hippynn/experiment/lightning_trainer.py create mode 100644 tests/lightning_QM7_test.py diff --git a/conda_requirements.txt b/conda_requirements.txt index 590e5ca3..f8f1f391 100644 --- a/conda_requirements.txt +++ b/conda_requirements.txt @@ -8,3 +8,4 @@ ase h5py tqdm python-graphviz +lightning \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 9a71efed..a47dfe54 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,7 +19,7 @@ project = "hippynn" copyright = "2019, Los Alamos National Laboratory" -author = "Nicholas Lubbers" +author = "Nicholas Lubbers et al" # The full version, including alpha/beta/rc tags import hippynn @@ -47,7 +47,7 @@ } # The following are highly optional, so we mock them for doc purposes. 
-autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba"] +autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba", "triton", "pytorch_lightning"] # -- Options for HTML output ------------------------------------------------- diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 54384e44..4064fea9 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -10,16 +10,18 @@ Requirements: * Python_ >= 3.9 * pytorch_ >= 1.9 * numpy_ + Optional Dependencies: * triton_ (recommended, for improved GPU performance) * numba_ (recommended for improved CPU performance) - * cupy_ (Alternative for accelerating GPU performance) - * ASE_ (for usage with ase) + * cupy_ (alternative for accelerating GPU performance) + * ASE_ (for usage with ase and other misc. features) * matplotlib_ (for plotting) * tqdm_ (for progress bars) - * graphviz_ (for viewing model graphs as figures) + * graphviz_ (for visualizing model graphs) * h5py_ (for loading ani-h5 datasets) * pyanitools_ (for loading ani-h5 datasets) + * pytorch-lightning_ (for distributed training) Interfacing codes: * ASE_ @@ -40,7 +42,7 @@ Interfacing codes: .. _ASE: https://wiki.fysik.dtu.dk/ase/ .. _LAMMPS: https://www.lammps.org/ .. _PYSEQM: https://github.com/lanl/PYSEQM - +.. _pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning Installation Instructions ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -67,9 +69,6 @@ Clone the hippynn_ repository and navigate into it, e.g.:: .. _hippynn: https://github.com/lanl/hippynn/ -.. note:: - If you wish to do a cpu-only install, you may need to comment - out ``cupy`` from the conda_requirements.txt file. Dependencies using conda ........................ @@ -78,6 +77,10 @@ Install dependencies from conda using recommended channels:: $ conda install -c pytorch -c conda-forge --file conda_requirements.txt +.. note:: + If you wish to do a cpu-only install, you may need to comment + out ``cupy`` from the conda_requirements.txt file. + Dependencies using pip ....................... diff --git a/docs/source/user_guide/settings.rst b/docs/source/user_guide/settings.rst index d8657de4..c6764206 100644 --- a/docs/source/user_guide/settings.rst +++ b/docs/source/user_guide/settings.rst @@ -31,7 +31,7 @@ The following settings are available: - Dynamic * - PROGRESS - Progress bars function during training, evaluation, and prediction - - tqdm, none + - tqdm, none, or floating point string specifying default update rate in seconds (default 1). - tqdm - Yes, but assign this to a generator-wrapper such as ``tqdm.tqdm``, or with a python ``None`` to disable. The wrapper must accept ``tqdm`` arguments, although it technically doesn't have to do anything with them. * - DEFAULT_PLOT_FILETYPE diff --git a/examples/barebones_lightning.py b/examples/barebones_lightning.py new file mode 100644 index 00000000..4469d7ed --- /dev/null +++ b/examples/barebones_lightning.py @@ -0,0 +1,102 @@ +''' +To obtain the data files needed for this example, use the script process_QM7_data.py, +also located in this folder. The script contains further instructions for use. +''' + +import torch + +# Setup pytorch things +torch.set_default_dtype(torch.float32) + +import hippynn + +netname = "TEST_BAREBONES_LIGHTNING_SCRIPT" + +# Hyperparameters for the network +# These are set deliberately small so that you can easily run the example on a laptop or similar. 
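+# Note: the distance cutoffs below are in the units of the dataset coordinates (QM7 is in Bohr).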
+network_params = { + "possible_species": [0, 1, 6, 7, 8, 16], # Z values of the elements in QM7 + "n_features": 20, # Number of neurons at each layer + "n_sensitivities": 20, # Number of sensitivity functions in an interaction layer + "dist_soft_min": 1.6, # qm7 is in Bohr! + "dist_soft_max": 10.0, + "dist_hard_max": 12.5, + "n_interaction_layers": 2, # Number of interaction blocks + "n_atom_layers": 3, # Number of atom layers in an interaction block +} + +# Define a model +from hippynn.graphs import inputs, networks, targets, physics + +species = inputs.SpeciesNode(db_name="Z") +positions = inputs.PositionsNode(db_name="R") + +network = networks.Hipnn("hipnn_model", (species, positions), module_kwargs=network_params) +henergy = targets.HEnergyNode("HEnergy", network, db_name="T") +# hierarchicality = henergy.hierarchicality + +# define loss quantities +from hippynn.graphs import loss + +mse_energy = loss.MSELoss.of_node(henergy) +mae_energy = loss.MAELoss.of_node(henergy) +rmse_energy = mse_energy ** (1 / 2) + +# Validation losses are what we check on the data between epochs -- we can only train to +# a single loss, but we can check other metrics too to better understand how the model is training. +# There will also be plots of these things over time when training completes. +validation_losses = { + "RMSE": rmse_energy, + "MAE": mae_energy, + "MSE": mse_energy, +} + +# This piece of code glues the stuff together as a pytorch model, +# dropping things that are irrelevant for the losses defined. +training_modules, db_info = hippynn.experiment.assemble_for_training(mse_energy, validation_losses) + +# Go to a directory for the model. +# hippynn will save training files in the current working directory. +with hippynn.tools.active_directory(netname): + # Log the output of python to `training_log.txt` + with hippynn.tools.log_terminal("training_log.txt", "wt"): + database = hippynn.databases.DirectoryDatabase( + name="data-qm7", # Prefix for arrays in the directory + directory="../../../datasets/qm7_processed", + test_size=0.1, # Fraction or number of samples to test on + valid_size=0.1, # Fraction or number of samples to validate on + seed=2001, # Random seed for splitting data + **db_info, # Adds the inputs and targets db_names from the model as things to load + dataloader_kwargs=dict(persistent_workers=True,multiprocessing_context='fork'), + num_workers=2, + ) + + # Now that we have a database and a model, we can + # Fit the non-interacting energies by examining the database. + # This tends to stabilize training a lot. + from hippynn.pretraining import hierarchical_energy_initialization + + hierarchical_energy_initialization(henergy, database, trainable_after=False) + + # Parameters describing the training procedure. + from hippynn.experiment import setup_and_train + + experiment_params = hippynn.experiment.SetupParams( + stopping_key="MSE", # The name in the validation_losses dictionary. + batch_size=12, + optimizer=torch.optim.Adam, + max_epochs=100, + learning_rate=0.001, + ) + # setup_and_train( + # training_modules=training_modules, + # database=database, + # setup_params=experiment_params, + # ) + from hippynn.experiment import HippynnLightningModule + +# lightning needs to run exactly where the script is located in distributed modes. +lightmod, datamodule = HippynnLightningModule.from_experiment_setup(training_modules, database, experiment_params) +import pytorch_lightning as pl +trainer = pl.Trainer(accelerator='cpu') #'auto' detects MPS which doesn't work. 
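+# The datamodule wraps the hippynn database and builds the train/valid/test dataloaders;
+# during fit(), the hippynn controller monitors the validation metrics and can stop training early.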
+trainer.fit(model=lightmod, datamodule=datamodule) diff --git a/hippynn/_settings_setup.py b/hippynn/_settings_setup.py index 62b872c9..c72eb94f 100644 --- a/hippynn/_settings_setup.py +++ b/hippynn/_settings_setup.py @@ -29,16 +29,22 @@ TQDM_PROGRESS = None if TQDM_PROGRESS is not None: - TQDM_PROGRESS = partial(TQDM_PROGRESS, mininterval=1.0, leave=False) - + DEFAULT_PROGRESS = partial(TQDM_PROGRESS, mininterval=1.0, leave=False) +else: + DEFAULT_PROGRESS = None ### Progress handlers - def progress_handler(prog_str): if prog_str == "tqdm": - return TQDM_PROGRESS - if prog_str.lower() == "none": + return DEFAULT_PROGRESS + elif prog_str.lower() == "none": return None + else: + try: + prog_float = float(prog_str) + return partial(TQDM_PROGRESS, mininterval=prog_float, leave=False) + except: + pass warnings.warn(f"Unrecognized progress setting: '{prog_str}'. Setting to none.") @@ -63,7 +69,7 @@ def kernel_handler(kernel_string): # keys: defaults, types, and handlers default_settings = { - "PROGRESS": (TQDM_PROGRESS, progress_handler), + "PROGRESS": (DEFAULT_PROGRESS, progress_handler), "DEFAULT_PLOT_FILETYPE": (".pdf", str), "TRANSPARENT_PLOT": (False, strtobool), "DEBUG_LOSS_BROADCAST": (False, strtobool), @@ -85,11 +91,16 @@ def kernel_handler(kernel_string): config_sources = {} # Dictionary of configuration variable sources mapping to dictionary of configuration. # We add to this dictionary in order of application +SECTION_NAME = "GLOBALS" + rc_name = os.path.expanduser("~/.hippynnrc") if os.path.exists(rc_name) and os.path.isfile(rc_name): config = configparser.ConfigParser(inline_comment_prefixes="#") config.read(rc_name) - config_sources["~/.hippynnrc"] = config["GLOBALS"] + if SECTION_NAME not in config: + warnings.warn(f"Config file {rc_name} does not contain a {SECTION_NAME} section and will be ignored!") + else: + config_sources["~/.hippynnrc"] = config[SECTION_NAME] SETTING_PREFIX = "HIPPYNN_" hippynn_environment_variables = { @@ -103,7 +114,10 @@ def kernel_handler(kernel_string): if os.path.exists(local_rc_fname) and os.path.isfile(local_rc_fname): local_config = configparser.ConfigParser() local_config.read(local_rc_fname) - config_sources[LOCAL_RC_FILE_KEY] = local_config["GLOBALS"] + if SECTION_NAME not in local_config: + warnings.warn(f"Config file {local_rc_fname} does not contain a {SECTION_NAME} section and will be ignored!") + else: + config_sources[LOCAL_RC_FILE_KEY] = local_config[SECTION_NAME] else: warnings.warn(f"Local configuration file {local_rc_fname} not found.") diff --git a/hippynn/custom_kernels/tensor_wrapper.py b/hippynn/custom_kernels/tensor_wrapper.py index ade8ddcf..6323b61e 100644 --- a/hippynn/custom_kernels/tensor_wrapper.py +++ b/hippynn/custom_kernels/tensor_wrapper.py @@ -38,8 +38,8 @@ def _numba_gpu_not_found(*args, **kwargs): class NumbaCompatibleTensorFunction: def __init__(self): if numba.cuda.is_available(): - self.kernel64 = self.make_kernel(numba.float64) - self.kernel32 = self.make_kernel(numba.float32) + self.kernel64 = None + self.kernel32 = None else: self.kernel64 = _numba_gpu_not_found self.kernel32 = _numba_gpu_not_found @@ -59,8 +59,12 @@ def __call__(self, *args, **kwargs): with numba.cuda.gpus[dev.index]: numba_args = batch_convert_torch_to_numba(*args) if dtype == torch.float64: + if self.kernel64 is None: + self.kernel64 = self.make_kernel(numba.float64) self.kernel64[launch_bounds](*numba_args) elif dtype == torch.float32: + if self.kernel32 is None: + self.kernel32 = self.make_kernel(numba.float32) 
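+                # Kernels are compiled lazily on first use so that importing hippynn does not
+                # create a CUDA context on device 0 (the kernels are set to None in __init__).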
self.kernel32[launch_bounds](*numba_args) else: raise ValueError("Bad dtype: {}".format(dtype)) diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index d9a117c1..616a2eb8 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -122,6 +122,7 @@ def get_simulated_data(n_molecules, n_atoms, atom_prob, n_features, n_nu, printi TEST_LARGE_PARAMS = dict(n_molecules=1000, n_atoms=30, atom_prob=0.7, n_features=80, n_nu=20) TEST_MEGA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=100) TEST_ULTRA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=320) +TEST_GIGA_PARAMS = dict(n_molecules=32, n_atoms=30, atom_prob=0.7, n_features=512, n_nu=320) # reference implementation @@ -434,6 +435,12 @@ def main(env_impl, sense_impl, feat_impl, args=None): if use_verylarge_gpu: if use_ultra: + + print("-" * 80) + print("Giga systems:", TEST_GIGA_PARAMS) + tester.check_speed( + n_repetitions=20, data_size=TEST_GIGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against + ) print("-" * 80) print("Ultra systems:", TEST_ULTRA_PARAMS) tester.check_speed( diff --git a/hippynn/databases/__init__.py b/hippynn/databases/__init__.py index d938fd54..e97ad715 100644 --- a/hippynn/databases/__init__.py +++ b/hippynn/databases/__init__.py @@ -12,16 +12,25 @@ from .database import Database from .ondisk import DirectoryDatabase, NPZDatabase has_ase = False +has_h5 = False + try: import ase has_ase = True + import h5py + has_h5 = True except ImportError: pass if has_ase: from ..interfaces.ase_interface import AseDatabase + if has_h5: + from .h5_pyanitools import PyAniFileDB, PyAniDirectoryDB all_list = ["Database", "DirectoryDatabase", "NPZDatabase"] + if has_ase: all_list += ["AseDatabase"] + if has_h5: + all_list += ["PyAniFileDB", "PyAniDirectoryDB"] __all__ = all_list diff --git a/hippynn/databases/database.py b/hippynn/databases/database.py index 19acbb52..fa503763 100644 --- a/hippynn/databases/database.py +++ b/hippynn/databases/database.py @@ -1,6 +1,8 @@ """ Base database functionality from dictionary of numpy arrays """ + +from typing import Union import warnings import numpy as np import torch @@ -20,17 +22,18 @@ class Database: def __init__( self, - arr_dict, - inputs, - targets, - seed, - test_size=None, - valid_size=None, - num_workers=0, - pin_memory=True, - allow_unfound=False, - auto_split=False, - device=None, + arr_dict: dict[str,torch.Tensor], + inputs: list[str], + targets: list[str], + seed: [int,np.random.RandomState,tuple], + test_size: Union[float,int]=None, + valid_size: Union[float,int]=None, + num_workers: int=0, + pin_memory: bool=True, + allow_unfound:bool =False, + auto_split:bool =False, + device: torch.device=None, + dataloader_kwargs:dict[str,object]=None, quiet=False, ): """ @@ -47,6 +50,9 @@ def __init__( :param allow_unfound: If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None. :param auto_split: If true, look for keys like "split_*" to make initial splits from. See write_npz() method. + :param device: if set, move the dataset to this device after splitting. + :param dataloader_kwargs: dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory. + Refer to pytorch documentation for details. :param quiet: If True, print little or nothing while loading. 
""" @@ -123,6 +129,8 @@ def __init__( else: self.send_to_device(device) + self.dataloader_kwargs = dataloader_kwargs.copy() if dataloader_kwargs else {} + def __len__(self): return arrdict_len(self.arr_dict) @@ -425,6 +433,7 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample shuffle=shuffle, pin_memory=self.pin_memory, num_workers=self.num_workers, + **self.dataloader_kwargs, ) return generator @@ -514,7 +523,7 @@ def write_h5(self, split=None, h5path=None, species_key='species', overwrite=Fal return write_h5_function(self, split=split, file=h5path, species_key=species_key, overwrite=overwrite) - def write_npz(self, file: str, record_split_masks: bool = True, overwrite: bool = False, split_prefix=None, return_only=False): + def write_npz(self, file: str, record_split_masks: bool = True, compressed:bool =True, overwrite: bool = False, split_prefix=None, return_only=False): """ :param file: str, Path, or file object compatible with np.save :param record_split_masks: @@ -561,7 +570,10 @@ def write_npz(self, file: str, record_split_masks: bool = True, overwrite: bool if file.exists() and not overwrite: raise FileExistsError(f"File exists: {file}") - np.savez_compressed(file, **arr_dict) + if compressed: + np.savez_compressed(file, **arr_dict) + else: + np.savez(file, **arr_dict) return arr_dict diff --git a/hippynn/experiment/__init__.py b/hippynn/experiment/__init__.py index e31cb597..3a222e9b 100644 --- a/hippynn/experiment/__init__.py +++ b/hippynn/experiment/__init__.py @@ -13,4 +13,11 @@ from .assembly import assemble_for_training from .routines import setup_and_train, setup_training, train_model, test_model, SetupParams -__all__ = ["assemble_for_training", "setup_and_train", "setup_training", "train_model", "test_model", "SetupParams"] + +__all__ = ["assemble_for_training", "setup_and_train", "setup_training", "train_model", "test_model", "SetupParams",] + +try: + from .lightning_trainer import HippynnLightningModule + __all__ += ["HippynnLightningModule"] +except ImportError: + pass diff --git a/hippynn/experiment/controllers.py b/hippynn/experiment/controllers.py index 125e168a..dc88e77f 100644 --- a/hippynn/experiment/controllers.py +++ b/hippynn/experiment/controllers.py @@ -6,7 +6,6 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau - class Controller: """ Class for controlling the training dynamics. 
@@ -51,12 +50,10 @@ def __init__( fraction_train_eval=0.1, quiet=False, ): + super().__init__() self.optimizer = optimizer - self.scheduler = scheduler - self.stopping_key = stopping_key - self.batch_size = batch_size self.eval_batch_size = eval_batch_size or batch_size if max_epochs is None: @@ -85,7 +82,8 @@ def __init__( def state_dict(self): state_dict = {k: getattr(self, k) for k in self._state_vars} - state_dict["optimizer"] = self.optimizer.state_dict() + if self.optimizer is not None: + state_dict["optimizer"] = self.optimizer.state_dict() state_dict["scheduler"] = [sch.state_dict() for sch in self.scheduler_list] return state_dict @@ -94,7 +92,8 @@ def load_state_dict(self, state_dict): for sch, sdict in zip(self.scheduler_list, state_dict["scheduler"]): sch.load_state_dict(sdict) - self.optimizer.load_state_dict(state_dict["optimizer"]) + if self.optimizer is not None: + self.optimizer.load_state_dict(state_dict["optimizer"]) for k in self._state_vars: setattr(self, k, state_dict[k]) @@ -103,7 +102,7 @@ def load_state_dict(self, state_dict): def max_epochs(self): return self._max_epochs - def push_epoch(self, epoch, better_model, metric): + def push_epoch(self, epoch, better_model, metric, _print=print): self.current_epoch += 1 if better_model: @@ -118,8 +117,9 @@ def push_epoch(self, epoch, better_model, metric): sch.step() if not self.quiet: - print("Epochs since last best:", self.boredom) - print("Current max epochs:", self.max_epochs) + _print("Epochs since last best:", self.boredom) + _print("Current max epochs:", self.max_epochs) + return self.current_epoch < self.max_epochs @@ -139,23 +139,27 @@ def __init__(self, *args, termination_patience, **kwargs): self.patience = termination_patience self.last_best = 0 - def push_epoch(self, epoch, better_model, metric): + def push_epoch(self, epoch, better_model, metric, _print=print): if better_model: if self.boredom > 0 and not self.quiet: - print("Patience for training restored.") + _print("Patience for training restored.") self.boredom = 0 self.last_best = epoch - return super().push_epoch(epoch, better_model, metric) + return super().push_epoch(epoch, better_model, metric, _print=_print) @property def max_epochs(self): - return min(self.last_best + self.patience, self._max_epochs) + return min(self.last_best + self.patience + 1, self._max_epochs) -class RaiseBatchSizeOnPlateau: +# Developer note: The inheritance here is only so that pytorch lightning +# readily identifies this as a scheduler. +class RaiseBatchSizeOnPlateau(ReduceLROnPlateau): """ Learning rate scheduler compatible with pytorch schedulers. + Note: The "VERBOSE" Parameter has been deprecated and no longer does anything. + This roughly implements the scheme outlined in the following paper: .. code-block:: none @@ -182,9 +186,20 @@ def __init__( patience=10, threshold=0.0001, threshold_mode="rel", - verbose=True, + verbose=None, # DEPRECATED controller=None, ): + """ + + :param optimizer: + :param max_batch_size: + :param factor: + :param patience: + :param threshold: + :param threshold_mode: + :param verbose: + :param controller: + """ if threshold_mode not in ("abs", "rel"): raise ValueError("Mode must be 'abs' or 'rel'") @@ -195,13 +210,17 @@ def __init__( factor=factor, threshold=threshold, threshold_mode=threshold_mode, - verbose=verbose, ) self.controller = controller self.max_batch_size = max_batch_size self.best_metric = float("inf") self.boredom = 0 self.last_epoch = 0 + warnings.warn("Parameter verbose no longer supported for schedulers. 
It will be ignored.") + + @property + def optimizer(self): + return self.inner.optimizer def set_controller(self, box): self.controller = box @@ -250,12 +269,9 @@ def step(self, metrics): new_batch_size = min(new_batch_size, self.max_batch_size) self.controller.batch_size = new_batch_size self.boredom = 0 - if self.inner.verbose: - print("Raising batch size to", new_batch_size) + if new_batch_size >= self.max_batch_size: self.inner.last_epoch = self.last_epoch - 1 - if self.inner.verbose: - print("Max batch size reached, Lowering learning rate from here.") return diff --git a/hippynn/experiment/lightning_trainer.py b/hippynn/experiment/lightning_trainer.py new file mode 100644 index 00000000..ead8eb57 --- /dev/null +++ b/hippynn/experiment/lightning_trainer.py @@ -0,0 +1,369 @@ +""" +Pytorch Lightning training interface. + +This module is somewhat experimental. Using pytorch lightning +successfully in a distributed context may require understanding +and adjusting the various settings related to parallelism, e.g. +multiprocessing context, torch ddp backend, and how they interact +with your HPC environment. + +Some features of hippynn experiments may not be implemented yet. + - The plotmaker is currently not supported. + +""" +import warnings +import copy +from pathlib import Path + +import torch + +import pytorch_lightning as pl + +from .routines import TrainingModules +from ..databases import Database +from .routines import SetupParams, setup_training +from ..graphs import GraphModule +from .controllers import Controller +from .metric_tracker import MetricTracker +from .step_functions import get_step_function, StandardStep +from ..tools import print_lr +from . import serialization + + +class HippynnLightningModule(pl.LightningModule): + def __init__( + self, + model: GraphModule, + loss: GraphModule, + eval_loss: GraphModule, + eval_names: list[str], + stopping_key: str, + optimizer_list: list[torch.optim.Optimizer], + scheduler_list: list[torch.optim.lr_scheduler], + controller: Controller, + metric_tracker: MetricTracker, + inputs: list[str], + targets: list[str], + n_outputs: int, + *args, + **kwargs, + ): # forwards args and kwargs to where? + super().__init__() + + self.save_hyperparameters(ignore=["loss", "model", "eval_loss", "controller", "optimizer_list", "scheduler_list"]) + + self.model = model + self.loss = loss + self.eval_loss = eval_loss + self.eval_names = eval_names + self.stopping_key = stopping_key + self.controller = controller + self.metric_tracker = metric_tracker + self.optimizer_list = optimizer_list + self.scheduler_list = scheduler_list + self.inputs = inputs + self.targets = targets + self.n_inputs = len(self.inputs) + self.n_targets = len(self.targets) + self.n_outputs = n_outputs + + self.structure_file = None + + self._last_reload_dlene = None # storage for whether batch size should be changed. + + # Storage for predictions across batches for eval mode. + self.eval_step_outputs = [] + self.controller.optimizer = None + + for optimizer in self.optimizer_list: + if not isinstance(step_fn := get_step_function(optimizer), StandardStep): # := + raise NotImplementedError(f"Optimzers with non-standard steps are not yet supported. 
{optimizer,step_fn}") + + if args or kwargs: + raise NotImplementedError("Generic args and kwargs not supported.") + + @classmethod + def from_experiment_setup(cls, training_modules: TrainingModules, database: Database, setup_params: SetupParams, **kwargs): + training_modules, controller, metric_tracker = setup_training(training_modules, setup_params) + return cls.from_train_setup(training_modules, database, controller, metric_tracker, **kwargs) + + @classmethod + def from_train_setup( + cls, + training_modules: TrainingModules, + database: Database, + controller: Controller, + metric_tracker: MetricTracker, + callbacks=None, + batch_callbacks=None, + **kwargs, + ): + + model, loss, evaluator = training_modules + + warnings.warn("PytorchLightning hippynn trainer is still experimental.") + + if evaluator.plot_maker is not None: + warnings.warn("plot_maker is not currently supported in pytorch lightning. The current plot_maker will be ignored.") + + trainer = cls( + model=model, + loss=loss, + eval_loss=evaluator.loss, + eval_names=evaluator.loss_names, + optimizer_list=[controller.optimizer], + scheduler_list=controller.scheduler_list, + stopping_key=controller.stopping_key, + controller=controller, + metric_tracker=metric_tracker, + inputs=database.inputs, + targets=database.targets, + n_outputs=evaluator.n_outputs, + **kwargs, + ) + + # pytorch lightning is now in charge of stepping the scheduler. + controller.scheduler_list = [] + + if callbacks is not None or batch_callbacks is not None: + return NotImplemented("arbitrary callbacks are not yet supported with pytorch lightning.") + + return trainer, HippynnDataModule(database, controller.batch_size) + + def on_save_checkpoint(self, checkpoint) -> None: + + # Note to future developers: + # trainer.log_dir property needs to be called on all ranks! This is weird but important; + # do not move trainer.log_dir inside of a rank zero operation! + # see https://github.com/Lightning-AI/pytorch-lightning/discussions/8321 + # Thank you to https://github.com/semaphore-egg . + log_dir = self.trainer.log_dir + + if not self.structure_file: + # Perform change on all ranks. 
+ sf = serialization.DEFAULT_STRUCTURE_FNAME + self.structure_file = sf + + if self.global_rank == 0 and not self.structure_file: + self.print("creating structure file.") + structure = dict( + model=self.model, + loss=self.loss, + eval_loss=self.eval_loss, + controller=self.controller, + optimizer_list=self.optimizer_list, + scheduler_list=self.scheduler_list, + ) + path: Path = Path(log_dir).joinpath(sf) + self.print("Saving structure file at", path) + torch.save(obj=structure, f=path) + + checkpoint["controller_state"] = self.controller.state_dict() + return + + @classmethod + def load_from_checkpoint(cls, checkpoint_path, map_location=None, structure_file=None, hparams_file=None, strict=True, **kwargs): + + if structure_file is None: + # Assume checkpoint_path is like /version_/checkpoints/.chkpt + # and that experiment file is stored at /version_/experiment_structure.pt + structure_file = Path(checkpoint_path) + structure_file = structure_file.parent.parent + structure_file = structure_file.joinpath(serialization.DEFAULT_STRUCTURE_FNAME) + + structure_args = torch.load(structure_file) + + return super().load_from_checkpoint( + checkpoint_path, map_location=map_location, hparams_file=hparams_file, strict=strict, **structure_args, **kwargs + ) + + def on_load_checkpoint(self, checkpoint) -> None: + cstate = checkpoint.pop("controller_state") + self.controller.load_state_dict(cstate) + return + + def configure_optimizers(self): + + scheduler_list = [] + for s in self.scheduler_list: + config = { + "scheduler": s, + "interval": "epoch", # can be epoch or step + "frequency": 1, # How many intervals should pass between calls to `scheduler.step()`. + "monitor": "valid_" + self.stopping_key, # Metric to monitor for schedulers like `ReduceLROnPlateau` + "strict": True, + "name": "learning_rate", + } + scheduler_list.append(config) + + optimizer_list = self.optimizer_list.copy() + + return optimizer_list, scheduler_list + + def on_train_epoch_start(self): + for optimizer in self.optimizer_list: + print_lr(optimizer, print_=self.print) + self.print("Batch size:", self.trainer.train_dataloader.batch_size) + + def training_step(self, batch, batch_idx): + + batch_inputs = batch[: self.n_inputs] + batch_targets = batch[-self.n_targets :] + + batch_model_outputs = self.model(*batch_inputs) + batch_train_loss = self.loss(*batch_model_outputs, *batch_targets)[0] + + self.log("train_loss", batch_train_loss) + return batch_train_loss + + def _eval_step(self, batch, batch_idx): + + batch_inputs = batch[: self.n_inputs] + batch_targets = batch[-self.n_targets :] + + # It is very, very common to fit to derivatives, e.g. force, in hippynn. Override lightning default. + with torch.autograd.set_grad_enabled(True): + batch_predictions = self.model(*batch_inputs) + + batch_predictions = [bp.detach() for bp in batch_predictions] + + outputs = (batch_predictions, batch_targets) + self.eval_step_outputs.append(outputs) + return batch_predictions + + def validation_step(self, batch, batch_idx): + return self._eval_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self._eval_step(batch, batch_idx) + + def _eval_epoch_end(self, prefix): + + all_batch_predictions, all_batch_targets = zip(*self.eval_step_outputs) + # now 'shape' (n_batch, n_outputs) -> need to transpose. + all_batch_predictions = [[bpred[i] for bpred in all_batch_predictions] for i in range(self.n_outputs)] + # now 'shape' (n_batch, n_targets) -> need to transpose. 
+ all_batch_targets = [[bpred[i] for bpred in all_batch_targets] for i in range(self.n_targets)] + + # now cat each prediction and target across the batch index. + all_predictions = [torch.cat(x, dim=0) if x[0].shape != () else x[0] for x in all_batch_predictions] + all_targets = [torch.cat(x, dim=0) for x in all_batch_targets] + + all_losses = [x.item() for x in self.eval_loss(*all_predictions, *all_targets)] + self.eval_step_outputs.clear() # free memory + + loss_dict = {name: value for name, value in zip(self.eval_names, all_losses)} + + self.log_dict({prefix + k: v for k, v in loss_dict.items()}, sync_dist=True) + + return + + def on_validation_epoch_end(self): + self._eval_epoch_end(prefix="valid_") + return + + def on_test_epoch_end(self): + self._eval_epoch_end(prefix="test_") + return + + def _eval_end(self, prefix, when=None) -> None: + if when is None: + if self.trainer.sanity_checking: + when = "Sanity Check" + else: + when = self.current_epoch + + # Step 1: get metrics reduced from all ranks. + # Copied pattern from pytorch_lightning. + metrics = copy.deepcopy(self.trainer.callback_metrics) + + pre_len = len(prefix) + loss_dict = {k[pre_len:]: v.item() for k, v in metrics.items() if k.startswith(prefix)} + + loss_dict = {prefix[:-1]: loss_dict} # strip underscore from prefix and wrap. + + if self.trainer.sanity_checking: + self.print("Sanity check metric values:") + self.metric_tracker.evaluation_print(loss_dict, _print=self.print) + return + + # Step 2: register metrics + out_ = self.metric_tracker.register_metrics(loss_dict, when=when) + better_metrics, better_model, stopping_metric = out_ + self.metric_tracker.evaluation_print_better(loss_dict, better_metrics, _print=self.print) + + continue_training = self.controller.push_epoch(self.current_epoch, better_model, stopping_metric, _print=self.print) + + if not continue_training: + self.print("Controller is terminating training.") + self.trainer.should_stop = True + + # Step 3: Logic for changing the batch size without always requiring new dataloaders. + # Step 3a: don't do this when not testing. + if not self.trainer.training: + return + + controller_batch_size = self.controller.batch_size + trainer_batch_size = self.trainer.train_dataloader.batch_size + if controller_batch_size != trainer_batch_size: + # Need to trigger a batch size change. + if self._last_reload_dlene is None: + # save the original value of this variable to the pl module + self._last_reload_dlene = self.trainer.reload_dataloaders_every_n_epochs + + # TODO: Make this run even if there isn't an explicit datamodule? + self.trainer.datamodule.batch_size = controller_batch_size + # Tell PL lightning to reload the dataloaders now. + self.trainer.reload_dataloaders_every_n_epochs = 1 + + elif self._last_reload_dlene is not None: + # Restore the last saved value from the pl module. + self.trainer.reload_dataloaders_every_n_epochs = self._last_reload_dlene + self._last_reload_dlene = None + else: + # Batch sizes match, and there's no variable to restore. + pass + return + + def on_validation_end(self): + self._eval_end(prefix="valid_") + return + + def on_test_end(self): + self._eval_end(prefix="test_", when="test") + return + + +class LightingPrintStagesCallback(pl.Callback): + """ + This callback is for debugging only. + It prints whenever a callback stage is entered in pytorch lightning. 
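+
+    For example, pass an instance in the trainer's callback list::
+
+        callback = LightingPrintStagesCallback()
+        trainer = pl.Trainer(callbacks=[callback])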
+ """ + + for k in dir(pl.Callback): + if k.startswith("on_"): + + def some_method(self, *args, _k=k, **kwargs): + all_args = kwargs.copy() + all_args.update({i: a for i, a in enumerate(args)}) + int_args = {k: v for k, v in all_args.items() if isinstance(v, int)} + print("Callback stage:", _k, "with integer arguments:", int_args) + + exec(f"{k} = some_method") + del some_method + + +class HippynnDataModule(pl.LightningDataModule): + def __init__(self, database: Database, batch_size): + super().__init__() + self.database = database + self.batch_size = batch_size + + def train_dataloader(self): + return self.database.make_generator("train", "train", self.batch_size) + + def val_dataloader(self): + return self.database.make_generator("valid", "eval", self.batch_size) + + def test_dataloader(self): + return self.database.make_generator("test", "eval", self.batch_size) diff --git a/hippynn/experiment/metric_tracker.py b/hippynn/experiment/metric_tracker.py index f43426e6..af28d7fc 100644 --- a/hippynn/experiment/metric_tracker.py +++ b/hippynn/experiment/metric_tracker.py @@ -85,7 +85,6 @@ def register_metrics(self, metric_info, when): except KeyError: if split_type not in self.best_metric_values: # Haven't seen this split before! - print("ADDING ",split_type) self.best_metric_values[split_type] = {} better_metrics[split_type] = {} better = True # old best was not found! @@ -99,7 +98,7 @@ def register_metrics(self, metric_info, when): else: self.other_metric_values[when] = metric_info - if self.stopping_key: + if self.stopping_key and "valid" in metric_info: better_model = better_metrics.get("valid", {}).get(self.stopping_key, False) stopping_key_metric = metric_info["valid"][self.stopping_key] else: @@ -108,21 +107,21 @@ def register_metrics(self, metric_info, when): return better_metrics, better_model, stopping_key_metric - def evaluation_print(self, evaluation_dict, quiet=None): + def evaluation_print(self, evaluation_dict, quiet=None, _print=print): if quiet is None: quiet = self.quiet if quiet: return - table_evaluation_print(evaluation_dict, self.metric_names, self.name_column_width) + table_evaluation_print(evaluation_dict, self.metric_names, self.name_column_width, _print=_print) - def evaluation_print_better(self, evaluation_dict, better_dict, quiet=None): + def evaluation_print_better(self, evaluation_dict, better_dict, quiet=None, _print=print): if quiet is None: quiet = self.quiet if quiet: return - table_evaluation_print_better(evaluation_dict, better_dict, self.metric_names, self.name_column_width) + table_evaluation_print_better(evaluation_dict, better_dict, self.metric_names, self.name_column_width, _print=print) if self.stopping_key: - print( + _print( "Best {} so far: {:>8.5g}".format( self.stopping_key, self.best_metric_values["valid"][self.stopping_key] ) @@ -134,7 +133,7 @@ def plot_over_time(self): # Driver for printing evaluation table results, with * for better entries. # Decoupled from the estate in case we want to more easily change print formatting. -def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_columns): +def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_columns, _print=print): """ Print metric results as a table, add a '*' character for metrics in better_dict. 
@@ -157,16 +156,16 @@ def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_ header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names) rowstring = "{:<" + str(n_columns) + "}: " + " {}{:>10.5g}" * n_types - print(header) - print("-" * len(header)) + _print(header) + _print("-" * len(header)) for n, valsbet in zip(metric_names, transposed_values_better): rowoutput = [k for bv in valsbet for k in bv] - print(rowstring.format(n, *rowoutput)) + _print(rowstring.format(n, *rowoutput)) # Driver for printing evaluation table results. # Decoupled from the estate in case we want to more easily change print formatting. -def table_evaluation_print(evaluation_dict, metric_names, n_columns): +def table_evaluation_print(evaluation_dict, metric_names, n_columns, _print=print): """ Print metric results as a table. @@ -184,8 +183,8 @@ def table_evaluation_print(evaluation_dict, metric_names, n_columns): header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names) rowstring = "{:<" + str(n_columns) + "}: " + " {:>10.5g}" * n_types - print(header) - print("-" * len(header)) + _print(header) + _print("-" * len(header)) for n, vals in zip(metric_names, transposed_values): - print(rowstring.format(n, *vals)) - print("-" * len(header)) + _print(rowstring.format(n, *vals)) + _print("-" * len(header)) diff --git a/hippynn/experiment/routines.py b/hippynn/experiment/routines.py index f6aee191..84faa5c0 100644 --- a/hippynn/experiment/routines.py +++ b/hippynn/experiment/routines.py @@ -306,9 +306,7 @@ def train_model( print("Finishing up...") print("Training phase ended.") - if store_metrics: - with open("training_metrics.pkl", "wb") as pfile: - pickle.dump(metric_tracker, pfile) + torch.save(metric_tracker, "training_metrics.pt") best_model = metric_tracker.best_model if best_model: @@ -448,6 +446,7 @@ def training_loop( qprint("_" * 50) qprint("Epoch {}:".format(epoch)) tools.print_lr(optimizer) + qprint("Batch Size:", controller.batch_size) qprint(flush=True, end="") diff --git a/hippynn/experiment/serialization.py b/hippynn/experiment/serialization.py index c4d73c1a..326812fa 100644 --- a/hippynn/experiment/serialization.py +++ b/hippynn/experiment/serialization.py @@ -1,5 +1,7 @@ """ -checkpoint and state generation +Checkpoint and state generation. + +As a user, in most cases you will only need the `load` functions here. """ from typing import Tuple, Union @@ -12,7 +14,7 @@ from ..graphs import GraphModule from ..tools import device_fallback from .assembly import TrainingModules -from .controllers import PatienceController +from .controllers import Controller from .device import set_devices from .metric_tracker import MetricTracker @@ -21,13 +23,13 @@ def create_state( model: GraphModule, - controller: PatienceController, + controller: Controller, metric_tracker: MetricTracker, ) -> dict: """Create an experiment state dictionary. :param model: current model - :param controller: patience controller + :param controller: controller :param metric_tracker: current metrics :return: dictionary containing experiment state. 
:rtype: dict @@ -43,7 +45,7 @@ def create_state( def create_structure_file( training_modules: TrainingModules, database: Database, - controller: PatienceController, + controller: Controller, fname=DEFAULT_STRUCTURE_FNAME, ) -> None: """ @@ -51,7 +53,7 @@ def create_structure_file( :param training_modules: contains model, controller, and loss :param database: database for training - :param controller: patience controller + :param controller: controller :param fname: filename to save the checkpoint :return: None diff --git a/hippynn/graphs/gops.py b/hippynn/graphs/gops.py index ee98ddc1..01fc6682 100644 --- a/hippynn/graphs/gops.py +++ b/hippynn/graphs/gops.py @@ -50,7 +50,8 @@ def compute_evaluation_order(all_nodes): evaluation_inputs_list = [] evaluation_outputs_list = [] - unsatisfied_nodes = all_nodes.copy() + # need to sort to get stable results between runs/processes. + unsatisfied_nodes = list(sorted(all_nodes, key=lambda node: node.name)) satisfied_nodes = set() n = -1 while len(unsatisfied_nodes) > 0: diff --git a/hippynn/interfaces/ase_interface/ase_database.py b/hippynn/interfaces/ase_interface/ase_database.py index b3c05057..992e16f4 100644 --- a/hippynn/interfaces/ase_interface/ase_database.py +++ b/hippynn/interfaces/ase_interface/ase_database.py @@ -24,14 +24,14 @@ import os import numpy as np -from ase.io import read +from ase.io import read, iread -from ...tools import np_of_torchdefaultdtype +from ...tools import np_of_torchdefaultdtype, progress_bar from ...databases.database import Database from ...databases.restarter import Restartable from typing import Union from typing import List - +import hippynn.tools class AseDatabase(Database, Restartable): """ @@ -84,11 +84,11 @@ def load_arrays(self, directory, filename, inputs, targets, quiet=False, allow_u var_list = inputs + targets try: if isinstance(filename, str): - db = read(directory + filename, index=":") + db = list(progress_bar(iread(directory+filename,index=":"), desc='configs'))#read(directory + filename, index=":") elif isinstance(filename, (list, np.ndarray)): db = [] - for name in filename: - temp_db = read(directory + name, index=":") + for name in progress_bar(filename, desc='files'): + temp_db = list(progress_bar(iread(directory + name, index=":"), desc='configs')) db += temp_db except FileNotFoundError as fee: raise FileNotFoundError( diff --git a/hippynn/layers/hiplayers.py b/hippynn/layers/hiplayers.py index 2f62c07d..b93aae60 100644 --- a/hippynn/layers/hiplayers.py +++ b/hippynn/layers/hiplayers.py @@ -275,16 +275,26 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) n_atoms_real = in_features.shape[0] sense_vals = self.sensitivity(dist_pairs) + # Sensitivity stacking + sense_vec = sense_vals.unsqueeze(1) * (coord_pairs / dist_pairs.unsqueeze(1)).unsqueeze(2) + sense_vec = sense_vec.reshape(-1, self.n_dist * 3) + sense_stacked = torch.concatenate([sense_vals, sense_vec], dim=1) + + # Message passing, stack sensitivities to coalesce custom kernel call. 
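+        # Stacking lets a single envsum call replace the separate scalar and vector calls.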
+ # shape (n_atoms, n_nu + 3*n_nu, n_feat) + env_features_stacked = custom_kernels.envsum(sense_stacked, in_features, pair_first, pair_second) + # shape (n_atoms, 4, n_nu, n_feat) + env_features_stacked = env_features_stacked.reshape(-1, 4, self.n_dist, self.nf_in) + + # separate to tensor components + env_features, env_features_vec = torch.split(env_features_stacked, [1, 3], dim=1) + # Scalar part - env_features = custom_kernels.envsum(sense_vals, in_features, pair_first, pair_second) env_features = torch.reshape(env_features, (n_atoms_real, self.n_dist * self.nf_in)) weights_rs = torch.reshape(self.int_weights.permute(0, 2, 1), (self.n_dist * self.nf_in, self.nf_out)) features_out = torch.mm(env_features, weights_rs) # Vector part - sense_vec = sense_vals.unsqueeze(1) * (coord_pairs / dist_pairs.unsqueeze(1)).unsqueeze(2) - sense_vec = sense_vec.reshape(-1, self.n_dist * 3) - env_features_vec = custom_kernels.envsum(sense_vec, in_features, pair_first, pair_second) env_features_vec = env_features_vec.reshape(n_atoms_real * 3, self.n_dist * self.nf_in) features_out_vec = torch.mm(env_features_vec, weights_rs) features_out_vec = features_out_vec.reshape(n_atoms_real, 3, self.nf_out) @@ -315,19 +325,41 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) n_atoms_real = in_features.shape[0] sense_vals = self.sensitivity(dist_pairs) - # Scalar part - env_features = custom_kernels.envsum(sense_vals, in_features, pair_first, pair_second) + #### + # Sensitivity calculations + # scalar: sense_vals + # vector: sense_vec + # quadrupole: sense_quad + rhats = coord_pairs / dist_pairs.unsqueeze(1) + sense_vec = sense_vals.unsqueeze(1) * rhats.unsqueeze(2) + sense_vec = sense_vec.reshape(-1, self.n_dist * 3) + rhatsquad = rhats.unsqueeze(1) * rhats.unsqueeze(2) + rhatsquad = (rhatsquad + rhatsquad.transpose(1, 2)) / 2 + tr = torch.diagonal(rhatsquad, dim1=1, dim2=2).sum(dim=1) / 3.0 # Add divide by 3 early to save flops + tr = tr.unsqueeze(1).unsqueeze(2) * torch.eye(3, dtype=tr.dtype, device=tr.device).unsqueeze(0) + rhatsquad = rhatsquad - tr + rhatsqflat = rhatsquad.reshape(-1, 9)[:, self.upper_ind] # Upper-diagonal part + sense_quad = sense_vals.unsqueeze(1) * rhatsqflat.unsqueeze(2) + sense_quad = sense_quad.reshape(-1, self.n_dist * 5) + sense_stacked = torch.concatenate([sense_vals, sense_vec, sense_quad], dim=1) + + # Message passing, stack sensitivities to coalesce custom kernel call. + # shape (n_atoms, n_nu + 3*n_nu + 5*n_nu, n_feat) + env_features_stacked = custom_kernels.envsum(sense_stacked, in_features, pair_first, pair_second) + # shape (n_atoms, 9, n_nu, n_feat) + env_features_stacked = env_features_stacked.reshape(-1, 9, self.n_dist, self.nf_in) + + # separate to tensor components + env_features, env_features_vec, env_features_quad = torch.split(env_features_stacked, [1, 3, 5], dim=1) + + # Scalar stuff. 
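+        # env_features now has shape (n_atoms, 1, n_nu, n_feat); flatten it before mixing with the interaction weights.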
env_features = torch.reshape(env_features, (n_atoms_real, self.n_dist * self.nf_in)) weights_rs = torch.reshape(self.int_weights.permute(0, 2, 1), (self.n_dist * self.nf_in, self.nf_out)) features_out = torch.mm(env_features, weights_rs) # Vector part # Sensitivity - rhats = coord_pairs / dist_pairs.unsqueeze(1) - sense_vec = sense_vals.unsqueeze(1) * rhats.unsqueeze(2) - sense_vec = sense_vec.reshape(-1, self.n_dist * 3) # Weights - env_features_vec = custom_kernels.envsum(sense_vec, in_features, pair_first, pair_second) env_features_vec = env_features_vec.reshape(n_atoms_real * 3, self.n_dist * self.nf_in) features_out_vec = torch.mm(env_features_vec, weights_rs) # Norm and scale @@ -338,16 +370,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) # Quadrupole part # Sensitivity - rhatsquad = rhats.unsqueeze(1) * rhats.unsqueeze(2) - rhatsquad = (rhatsquad + rhatsquad.transpose(1, 2)) / 2 - tr = torch.diagonal(rhatsquad, dim1=1, dim2=2).sum(dim=1) / 3.0 # Add divide by 3 early to save flops - tr = tr.unsqueeze(1).unsqueeze(2) * torch.eye(3, dtype=tr.dtype, device=tr.device).unsqueeze(0) - rhatsquad = rhatsquad - tr - rhatsqflat = rhatsquad.reshape(-1, 9)[:, self.upper_ind] # Upper-diagonal part - sense_quad = sense_vals.unsqueeze(1) * rhatsqflat.unsqueeze(2) - sense_quad = sense_quad.reshape(-1, self.n_dist * 5) # Weights - env_features_quad = custom_kernels.envsum(sense_quad, in_features, pair_first, pair_second) env_features_quad = env_features_quad.reshape(n_atoms_real * 5, self.n_dist * self.nf_in) features_out_quad = torch.mm(env_features_quad, weights_rs) ##sum v b features_out_quad = features_out_quad.reshape(n_atoms_real, 5, self.nf_out) @@ -359,6 +382,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) # Scales features_out_quad = features_out_quad * self.quadscales.unsqueeze(0) + # Combine features_out_selfpart = self.selfint(in_features) features_out_total = features_out + features_out_vec + features_out_quad + features_out_selfpart diff --git a/hippynn/pretraining.py b/hippynn/pretraining.py index e039905a..726186fc 100644 --- a/hippynn/pretraining.py +++ b/hippynn/pretraining.py @@ -70,7 +70,7 @@ def hierarchical_energy_initialization( if not eo_layer.weight.data.shape[-1] == eovals.shape[-1]: raise ValueError("The shape of the computed E0 values does not match the shape expected by the model.") - eo_layer.weight.data = eovals.reshape(1,-1) + eo_layer.weight.data = eovals.reshape(1, -1) print("Computed E0 energies:", eovals) eo_layer.weight.data = eovals.expand_as(eo_layer.weight.data) eo_layer.weight.requires_grad_(trainable_after) diff --git a/hippynn/tools.py b/hippynn/tools.py index d2c78133..df1507cb 100644 --- a/hippynn/tools.py +++ b/hippynn/tools.py @@ -133,9 +133,9 @@ def arrdict_len(array_dictionary): return len(next(iter(array_dictionary.values()))) -def print_lr(optimizer): +def print_lr(optimizer, print_=print): for i, param_group in enumerate(optimizer.param_groups): - print("Learning rate:{:>10.5g}".format(param_group["lr"])) + print_("Learning rate:{:>10.5g}".format(param_group["lr"])) def isiterable(obj): @@ -217,3 +217,18 @@ def is_equal_state_dict(d1, d2, raise_where=False): return True +def recursive_param_count(state_dict, n=0): + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + n += v.numel() + elif isinstance(v, dict): + n += recursive_param_count(v) + elif isinstance(v, (list, tuple)): + n += recursive_param_count({i: x for i, x in enumerate(v)}) + elif isinstance(v, 
(float, int)): + n += 1 + elif v is None: + pass + else: + raise TypeError(f'Unknown type {type(v)=}, value={v}') + return n diff --git a/setup.py b/setup.py index 3f0100a4..95d1333e 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ "tqdm", "graphviz", "h5py", + "lightning", ] setuptools.setup( diff --git a/tests/lightning_QM7_test.py b/tests/lightning_QM7_test.py new file mode 100644 index 00000000..ac1d0b0f --- /dev/null +++ b/tests/lightning_QM7_test.py @@ -0,0 +1,219 @@ +""" + +This is a test script based on /examples/QM7_example.py which uses pytorch lightning to train. + +""" + +PERFORM_PLOTTING = True # Make sure you have matplotlib if you want to set this to TRUE + +#### Setup pytorch things +import torch + +torch.set_default_dtype(torch.float32) + +if torch.cuda.is_available(): + torch.cuda.set_device(0) # Don't try this if you want CPU training! + +import hippynn + + +def main(): + hippynn.settings.WARN_LOW_DISTANCES = False + + # Note: these settings may need to be adjusted depending on the platform where + # this code is run. + n_devices = 2 + num_workers = 0 + multiprocessing_context = "fork" + + # Hyperparameters for the network + netname = "TEST_LIGHTNING_MODEL" + network_params = { + "possible_species": [0, 1, 6, 7, 8, 16], # Z values of the elements + "n_features": 20, # Number of neurons at each layer + "n_sensitivities": 20, # Number of sensitivity functions in an interaction layer + "dist_soft_min": 1.6, # + "dist_soft_max": 10.0, + "dist_hard_max": 12.5, + "n_interaction_layers": 2, # Number of interaction blocks + "n_atom_layers": 3, # Number of atom layers in an interaction block + } + + # Define a model + + from hippynn.graphs import inputs, networks, targets, physics + + # model inputs + species = inputs.SpeciesNode(db_name="Z") + positions = inputs.PositionsNode(db_name="R") + + # Model computations + network = networks.HipnnVec("HIPNN", (species, positions), module_kwargs=network_params) + henergy = targets.HEnergyNode("HEnergy", network) + molecule_energy = henergy.mol_energy + molecule_energy.db_name = "T" + hierarchicality = henergy.hierarchicality + + # define loss quantities + from hippynn.graphs import loss + + rmse_energy = loss.MSELoss.of_node(molecule_energy) ** (1 / 2) + mae_energy = loss.MAELoss.of_node(molecule_energy) + rsq_energy = loss.Rsq.of_node(molecule_energy) + + ### More advanced usage of loss graph + + pred_per_atom = physics.PerAtom("PeratomPredicted", (molecule_energy, species)).pred + true_per_atom = physics.PerAtom("PeratomTrue", (molecule_energy.true, species.true)) + mae_per_atom = loss.MAELoss(pred_per_atom, true_per_atom) + + ### End more advanced usage of loss graph + + loss_error = rmse_energy + mae_energy + + rbar = loss.Mean.of_node(hierarchicality) + l2_reg = loss.l2reg(network) + loss_regularization = 1e-6 * l2_reg + rbar # L2 regularization and hierarchicality regularization + + train_loss = loss_error + loss_regularization + + # Validation losses are what we check on the data between epochs -- we can only train to + # a single loss, but we can check other metrics too to better understand how the model is training. + # There will also be plots of these things over time when training completes. 
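+    # Lightning logs each of these on the validation set under a "valid_" prefix (e.g. valid_Loss-Err).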
+ validation_losses = { + "T-RMSE": rmse_energy, + "T-MAE": mae_energy, + "T-RSQ": rsq_energy, + "TperAtom MAE": mae_per_atom, + "T-Hier": rbar, + "L2Reg": l2_reg, + "Loss-Err": loss_error, + "Loss-Reg": loss_regularization, + "Loss": train_loss, + } + early_stopping_key = "Loss-Err" + + if PERFORM_PLOTTING: + + from hippynn import plotting + + plot_maker = plotting.PlotMaker( + # Simple plots which compare the network to the database + plotting.Hist2D.compare(molecule_energy, saved=True), + # Slightly more advanced control of plotting! + plotting.Hist2D( + true_per_atom, + pred_per_atom, + xlabel="True Energy/Atom", + ylabel="Predicted Energy/Atom", + saved="PerAtomEn.pdf", + ), + plotting.HierarchicalityPlot(hierarchicality.pred, molecule_energy.pred - molecule_energy.true, saved="HierPlot.pdf"), + plot_every=10, # How often to make plots -- here, epoch 0, 10, 20... + ) + else: + plot_maker = None + + from hippynn.experiment import assemble_for_training + + # This piece of code glues the stuff together as a pytorch model, + # dropping things that are irrelevant for the losses defined. + training_modules, db_info = assemble_for_training(train_loss, validation_losses, plot_maker=plot_maker) + training_modules[0].print_structure() + + if num_workers > 0: + dataloader_kwargs = dict(multiprocessing_context=multiprocessing_context, persistent_workers=True) + else: + dataloader_kwargs = None + database_params = { + "name": "qm7", # Prefix for arrays in folder + "directory": "../../datasets/qm7_processed", + "quiet": False, + "test_size": 0.1, + "valid_size": 0.1, + "seed": 2001, + # How many samples from the training set to use during evaluation + **db_info, # Adds the inputs and targets names from the model as things to load + "dataloader_kwargs": dataloader_kwargs, + "num_workers": num_workers, + } + + from hippynn.databases import DirectoryDatabase + + database = DirectoryDatabase(**database_params) + + # Now that we have a database and a model, we can + # Fit the non-interacting energies by examining the database. 
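+    # This fits per-species reference energies (E0) and loads them into the model, which tends to stabilize training.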
+ + from hippynn.pretraining import hierarchical_energy_initialization + + hierarchical_energy_initialization(henergy, database, trainable_after=False) + + from hippynn.experiment.controllers import PatienceController + from torch.optim.lr_scheduler import ReduceLROnPlateau + + optimizer = torch.optim.Adam(training_modules.model.parameters(), lr=1e-3) + + scheduler = ReduceLROnPlateau( + optimizer=optimizer, + factor=0.5, + patience=1, + ) + + controller = PatienceController( + optimizer=optimizer, + scheduler=scheduler, + batch_size=16, # start batch size + eval_batch_size=16, + max_epochs=3, + termination_patience=10, + fraction_train_eval=0.1, + stopping_key=early_stopping_key, + ) + + experiment_params = hippynn.experiment.SetupParams( + controller=controller, + ) + + from hippynn.experiment import HippynnLightningModule + + lightmod, datamodule = HippynnLightningModule.from_experiment_setup(training_modules, database, experiment_params) + import pytorch_lightning as pl + from pytorch_lightning.loggers import CSVLogger + + logger = CSVLogger(save_dir=".", name=netname, flush_logs_every_n_steps=100) + from pytorch_lightning.callbacks import ModelCheckpoint + + checkpointer = ModelCheckpoint( + monitor=f"valid_{early_stopping_key}", + save_last=True, + save_top_k=5, + every_n_epochs=1, + every_n_train_steps=None, + ) + + from hippynn.experiment.lightning_trainer import LightingPrintStagesCallback + + cb = LightingPrintStagesCallback() # include this callback if you aren't sure what stage of lightning is broken. + + # The default accelerator, 'auto' detects MPS on mac. hippynn doesn't work on MPS (yet). + # So we set cpu here. + trainer = pl.Trainer( + accelerator="cpu", + logger=logger, + num_nodes=1, + devices=n_devices, + callbacks=[checkpointer], + log_every_n_steps=1, + max_epochs=-1, # This is set this way because the hippynn controller should terminate training. + ) + + trainer.fit( + model=lightmod, + datamodule=datamodule, + ) + trainer.test(datamodule=datamodule, ckpt_path="best") + + +if __name__ == "__main__": + main()