diff --git a/flare/scripts/otf_train.py b/flare/scripts/otf_train.py index 899eee185..8bae02691 100644 --- a/flare/scripts/otf_train.py +++ b/flare/scripts/otf_train.py @@ -285,6 +285,8 @@ def get_sgp_calc(flare_config): sae_dct ), "'single_atom_energies' should be the same length as 'species'" single_atom_energies = {i: sae_dct[i] for i in range(n_species)} + else: + single_atom_energies = {i: 0 for i in range(n_species)} sgp = SGP_Wrapper( kernels=kernels,