Add pytorch lightning trainer (#99)
* initial attempt of lightning training interface

* fix train_step and remove print

* fix batch order for end validation epoch

* fix types

* fix raisebatchsize for lightning

* remember to detach tensors

* add valid tag to lr scheduler

* add loss printing and controller

* add dataloader args for additional configuration

* refactor slightly and fix type errors

* add extra dataloader args to test script

* closer to connecting controller to lightning module

* guard print statements

* prevent double-updating of schedulers

* add sanity checking guard

* fix printing in sanity check

* get batch size changes working with pytorch lightning

* make sure custom kernels don't automatically trigger cuda context on device 0

* adding saving of modules (very necessary!)

* make lightning trainer not have to serialize constantly

* add coalescing custom kernel call for hip-nn-ts (l=2)

* add coalescing custom kernel call to hip-nn-ts, l=1

* formating and debug print

* update packages in docs

* make lightning import optional

* make metric tracker only seek better metrics on validation

* Make controller and metric tracker see metrics reduced across nodes

* update lightning test script

* update docs and requirements, remove extraneous code

* good old fashioned formatting

---------

Co-authored-by: Nicholas Lubbers <hippynn@lanl.gov>
lubbersnick and Nicholas Lubbers authored Sep 7, 2024
1 parent 4e84c36 commit 0004728
Showing 23 changed files with 912 additions and 108 deletions.
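In brief, this commit adds a PyTorch Lightning training path alongside the existing hippynn trainer. A condensed sketch of the new entry point, based on the examples/barebones_lightning.py script added in this diff (model, loss, and database setup are unchanged from the ordinary hippynn workflow):

    import pytorch_lightning as pl
    from hippynn.experiment import HippynnLightningModule

    # training_modules, database, and experiment_params are built exactly as in the
    # existing non-lightning examples; see the full script further down in this diff.
    lightmod, datamodule = HippynnLightningModule.from_experiment_setup(training_modules, database, experiment_params)
    trainer = pl.Trainer(accelerator="cpu")
    trainer.fit(model=lightmod, datamodule=datamodule)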
1 change: 1 addition & 0 deletions conda_requirements.txt
@@ -8,3 +8,4 @@ ase
h5py
tqdm
python-graphviz
lightning
4 changes: 2 additions & 2 deletions docs/source/conf.py
@@ -19,7 +19,7 @@

project = "hippynn"
copyright = "2019, Los Alamos National Laboratory"
author = "Nicholas Lubbers"
author = "Nicholas Lubbers et al"

# The full version, including alpha/beta/rc tags
import hippynn
@@ -47,7 +47,7 @@
}

# The following are highly optional, so we mock them for doc purposes.
autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba"]
autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba", "triton", "pytorch_lightning"]


# -- Options for HTML output -------------------------------------------------
17 changes: 10 additions & 7 deletions docs/source/installation.rst
@@ -10,16 +10,18 @@ Requirements:
* Python_ >= 3.9
* pytorch_ >= 1.9
* numpy_

Optional Dependencies:
* triton_ (recommended, for improved GPU performance)
* numba_ (recommended for improved CPU performance)
* cupy_ (Alternative for accelerating GPU performance)
* ASE_ (for usage with ase)
* cupy_ (alternative for accelerating GPU performance)
* ASE_ (for usage with ase and other misc. features)
* matplotlib_ (for plotting)
* tqdm_ (for progress bars)
* graphviz_ (for viewing model graphs as figures)
* graphviz_ (for visualizing model graphs)
* h5py_ (for loading ani-h5 datasets)
* pyanitools_ (for loading ani-h5 datasets)
* pytorch-lightning_ (for distributed training)

Interfacing codes:
* ASE_
@@ -40,7 +42,7 @@ Interfacing codes:
.. _ASE: https://wiki.fysik.dtu.dk/ase/
.. _LAMMPS: https://www.lammps.org/
.. _PYSEQM: https://github.com/lanl/PYSEQM

.. _pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning

Installation Instructions
^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -67,9 +69,6 @@ Clone the hippynn_ repository and navigate into it, e.g.::

.. _hippynn: https://github.com/lanl/hippynn/

.. note::
    If you wish to do a cpu-only install, you may need to comment
    out ``cupy`` from the conda_requirements.txt file.

Dependencies using conda
........................
@@ -78,6 +77,10 @@ Install dependencies from conda using recommended channels::

    $ conda install -c pytorch -c conda-forge --file conda_requirements.txt

.. note::
    If you wish to do a cpu-only install, you may need to comment
    out ``cupy`` from the conda_requirements.txt file.

Dependencies using pip
.......................

2 changes: 1 addition & 1 deletion docs/source/user_guide/settings.rst
@@ -31,7 +31,7 @@ The following settings are available:
- Dynamic
* - PROGRESS
- Progress bars function during training, evaluation, and prediction
- tqdm, none
- tqdm, none, or floating point string specifying default update rate in seconds (default 1).
- tqdm
- Yes, but assign this to a generator-wrapper such as ``tqdm.tqdm``, or with a python ``None`` to disable. The wrapper must accept ``tqdm`` arguments, although it technically doesn't have to do anything with them.
* - DEFAULT_PLOT_FILETYPE
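As an illustration of the new float-valued PROGRESS option: settings are read when hippynn is first imported, so (assuming the HIPPYNN_ environment-variable prefix defined in hippynn/_settings_setup.py) the update interval could be set like this; the value shown is arbitrary:

    import os

    # A float-valued string is interpreted as the tqdm progress-bar update interval
    # in seconds; it must be set before hippynn is imported.
    os.environ["HIPPYNN_PROGRESS"] = "0.5"

    import hippynn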
102 changes: 102 additions & 0 deletions examples/barebones_lightning.py
@@ -0,0 +1,102 @@
'''
To obtain the data files needed for this example, use the script process_QM7_data.py,
also located in this folder. The script contains further instructions for use.
'''

import torch

# Setup pytorch things
torch.set_default_dtype(torch.float32)

import hippynn

netname = "TEST_BAREBONES_LIGHTNING_SCRIPT"

# Hyperparameters for the network
# These are set deliberately small so that you can easily run the example on a laptop or similar.
network_params = {
    "possible_species": [0, 1, 6, 7, 8, 16],  # Z values of the elements in QM7
    "n_features": 20,  # Number of neurons at each layer
    "n_sensitivities": 20,  # Number of sensitivity functions in an interaction layer
    "dist_soft_min": 1.6,  # qm7 is in Bohr!
    "dist_soft_max": 10.0,
    "dist_hard_max": 12.5,
    "n_interaction_layers": 2,  # Number of interaction blocks
    "n_atom_layers": 3,  # Number of atom layers in an interaction block
}

# Define a model
from hippynn.graphs import inputs, networks, targets, physics

species = inputs.SpeciesNode(db_name="Z")
positions = inputs.PositionsNode(db_name="R")

network = networks.Hipnn("hipnn_model", (species, positions), module_kwargs=network_params)
henergy = targets.HEnergyNode("HEnergy", network, db_name="T")
# hierarchicality = henergy.hierarchicality

# define loss quantities
from hippynn.graphs import loss

mse_energy = loss.MSELoss.of_node(henergy)
mae_energy = loss.MAELoss.of_node(henergy)
rmse_energy = mse_energy ** (1 / 2)

# Validation losses are what we check on the data between epochs -- we can only train to
# a single loss, but we can check other metrics too to better understand how the model is training.
# There will also be plots of these things over time when training completes.
validation_losses = {
    "RMSE": rmse_energy,
    "MAE": mae_energy,
    "MSE": mse_energy,
}

# This piece of code glues the stuff together as a pytorch model,
# dropping things that are irrelevant for the losses defined.
training_modules, db_info = hippynn.experiment.assemble_for_training(mse_energy, validation_losses)

# Go to a directory for the model.
# hippynn will save training files in the current working directory.
with hippynn.tools.active_directory(netname):
    # Log the output of python to `training_log.txt`
    with hippynn.tools.log_terminal("training_log.txt", "wt"):
        database = hippynn.databases.DirectoryDatabase(
            name="data-qm7",  # Prefix for arrays in the directory
            directory="../../../datasets/qm7_processed",
            test_size=0.1,  # Fraction or number of samples to test on
            valid_size=0.1,  # Fraction or number of samples to validate on
            seed=2001,  # Random seed for splitting data
            **db_info,  # Adds the inputs and targets db_names from the model as things to load
            dataloader_kwargs=dict(persistent_workers=True, multiprocessing_context='fork'),
            num_workers=2,
        )

        # Now that we have a database and a model, we can
        # Fit the non-interacting energies by examining the database.
        # This tends to stabilize training a lot.
        from hippynn.pretraining import hierarchical_energy_initialization

        hierarchical_energy_initialization(henergy, database, trainable_after=False)

        # Parameters describing the training procedure.
        from hippynn.experiment import setup_and_train

        experiment_params = hippynn.experiment.SetupParams(
            stopping_key="MSE",  # The name in the validation_losses dictionary.
            batch_size=12,
            optimizer=torch.optim.Adam,
            max_epochs=100,
            learning_rate=0.001,
        )
        # setup_and_train(
        #     training_modules=training_modules,
        #     database=database,
        #     setup_params=experiment_params,
        # )
from hippynn.experiment import HippynnLightningModule

# lightning needs to run exactly where the script is located in distributed modes.
lightmod, datamodule = HippynnLightningModule.from_experiment_setup(training_modules, database, experiment_params)
import pytorch_lightning as pl
trainer = pl.Trainer(accelerator='cpu')  # 'auto' detects MPS which doesn't work.
trainer.fit(model=lightmod, datamodule=datamodule)
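The example above runs a single CPU process. For a multi-GPU or multi-node run, which is what this commit's cross-node metric reduction targets, the same module and datamodule would be handed to a distributed Trainer. A rough sketch only, using standard pytorch_lightning arguments with illustrative values; cluster launch configuration is not shown:

    # Sketch: distributed variant of the trainer above (values are illustrative).
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=4,  # GPUs per node
        num_nodes=2,
        strategy="ddp",
    )
    trainer.fit(model=lightmod, datamodule=datamodule)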
30 changes: 22 additions & 8 deletions hippynn/_settings_setup.py
@@ -29,16 +29,22 @@
TQDM_PROGRESS = None

if TQDM_PROGRESS is not None:
    TQDM_PROGRESS = partial(TQDM_PROGRESS, mininterval=1.0, leave=False)

    DEFAULT_PROGRESS = partial(TQDM_PROGRESS, mininterval=1.0, leave=False)
else:
    DEFAULT_PROGRESS = None
### Progress handlers


def progress_handler(prog_str):
    if prog_str == "tqdm":
        return TQDM_PROGRESS
    if prog_str.lower() == "none":
        return DEFAULT_PROGRESS
    elif prog_str.lower() == "none":
        return None
    else:
        try:
            prog_float = float(prog_str)
            return partial(TQDM_PROGRESS, mininterval=prog_float, leave=False)
        except:
            pass
    warnings.warn(f"Unrecognized progress setting: '{prog_str}'. Setting to none.")


@@ -63,7 +69,7 @@ def kernel_handler(kernel_string):

# keys: defaults, types, and handlers
default_settings = {
"PROGRESS": (TQDM_PROGRESS, progress_handler),
"PROGRESS": (DEFAULT_PROGRESS, progress_handler),
"DEFAULT_PLOT_FILETYPE": (".pdf", str),
"TRANSPARENT_PLOT": (False, strtobool),
"DEBUG_LOSS_BROADCAST": (False, strtobool),
@@ -85,11 +91,16 @@ def kernel_handler(kernel_string):
config_sources = {} # Dictionary of configuration variable sources mapping to dictionary of configuration.
# We add to this dictionary in order of application

SECTION_NAME = "GLOBALS"

rc_name = os.path.expanduser("~/.hippynnrc")
if os.path.exists(rc_name) and os.path.isfile(rc_name):
    config = configparser.ConfigParser(inline_comment_prefixes="#")
    config.read(rc_name)
    config_sources["~/.hippynnrc"] = config["GLOBALS"]
    if SECTION_NAME not in config:
        warnings.warn(f"Config file {rc_name} does not contain a {SECTION_NAME} section and will be ignored!")
    else:
        config_sources["~/.hippynnrc"] = config[SECTION_NAME]

SETTING_PREFIX = "HIPPYNN_"
hippynn_environment_variables = {
@@ -103,7 +114,10 @@ def kernel_handler(kernel_string):
if os.path.exists(local_rc_fname) and os.path.isfile(local_rc_fname):
    local_config = configparser.ConfigParser()
    local_config.read(local_rc_fname)
    config_sources[LOCAL_RC_FILE_KEY] = local_config["GLOBALS"]
    if SECTION_NAME not in local_config:
        warnings.warn(f"Config file {local_rc_fname} does not contain a {SECTION_NAME} section and will be ignored!")
    else:
        config_sources[LOCAL_RC_FILE_KEY] = local_config[SECTION_NAME]
else:
    warnings.warn(f"Local configuration file {local_rc_fname} not found.")

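For reference, the configuration files parsed here are plain configparser INI files. A minimal sketch of a ~/.hippynnrc that satisfies the new GLOBALS-section check; the setting names come from default_settings above, and the values are illustrative:

    [GLOBALS]
    PROGRESS = 0.5  # float string: tqdm update interval in seconds
    TRANSPARENT_PLOT = true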
8 changes: 6 additions & 2 deletions hippynn/custom_kernels/tensor_wrapper.py
@@ -38,8 +38,8 @@ def _numba_gpu_not_found(*args, **kwargs):
class NumbaCompatibleTensorFunction:
    def __init__(self):
        if numba.cuda.is_available():
            self.kernel64 = self.make_kernel(numba.float64)
            self.kernel32 = self.make_kernel(numba.float32)
            self.kernel64 = None
            self.kernel32 = None
        else:
            self.kernel64 = _numba_gpu_not_found
            self.kernel32 = _numba_gpu_not_found
@@ -59,8 +59,12 @@ def __call__(self, *args, **kwargs):
        with numba.cuda.gpus[dev.index]:
            numba_args = batch_convert_torch_to_numba(*args)
            if dtype == torch.float64:
                if self.kernel64 is None:
                    self.kernel64 = self.make_kernel(numba.float64)
                self.kernel64[launch_bounds](*numba_args)
            elif dtype == torch.float32:
                if self.kernel32 is None:
                    self.kernel32 = self.make_kernel(numba.float32)
                self.kernel32[launch_bounds](*numba_args)
            else:
                raise ValueError("Bad dtype: {}".format(dtype))
7 changes: 7 additions & 0 deletions hippynn/custom_kernels/test_env_numba.py
@@ -122,6 +122,7 @@ def get_simulated_data(n_molecules, n_atoms, atom_prob, n_features, n_nu, printi
TEST_LARGE_PARAMS = dict(n_molecules=1000, n_atoms=30, atom_prob=0.7, n_features=80, n_nu=20)
TEST_MEGA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=100)
TEST_ULTRA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=320)
TEST_GIGA_PARAMS = dict(n_molecules=32, n_atoms=30, atom_prob=0.7, n_features=512, n_nu=320)

# reference implementation

@@ -434,6 +435,12 @@ def main(env_impl, sense_impl, feat_impl, args=None):

    if use_verylarge_gpu:
        if use_ultra:

            print("-" * 80)
            print("Giga systems:", TEST_GIGA_PARAMS)
            tester.check_speed(
                n_repetitions=20, data_size=TEST_GIGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against
            )
            print("-" * 80)
            print("Ultra systems:", TEST_ULTRA_PARAMS)
            tester.check_speed(
9 changes: 9 additions & 0 deletions hippynn/databases/__init__.py
@@ -12,16 +12,25 @@
from .database import Database
from .ondisk import DirectoryDatabase, NPZDatabase
has_ase = False
has_h5 = False

try:
    import ase
    has_ase = True
    import h5py
    has_h5 = True
except ImportError:
    pass

if has_ase:
    from ..interfaces.ase_interface import AseDatabase
if has_h5:
    from .h5_pyanitools import PyAniFileDB, PyAniDirectoryDB

all_list = ["Database", "DirectoryDatabase", "NPZDatabase"]

if has_ase:
    all_list += ["AseDatabase"]
if has_h5:
    all_list += ["PyAniFileDB", "PyAniDirectoryDB"]
__all__ = all_list
