Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 23, 2023
1 parent 60a9175 commit 5eee206
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 15 deletions.
3 changes: 2 additions & 1 deletion kgcnn/models_core/force.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def __init__(self,
nested_model_config: bool = True,
is_physical_force: bool = True,
use_batch_jacobian: bool = True,
name: str = None
name: str = None,
outputs: Union[dict, list] = None
):

super().__init__()
Expand Down
9 changes: 7 additions & 2 deletions kgcnn/utils_core/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def plot_train_test_loss(histories: list, loss_name: str = None,

def plot_predict_true(y_predict, y_true, data_unit: list = None, model_name: str = "",
filepath: str = None, file_name: str = "", dataset_name: str = "", target_names: list = None,
figsize: list = None, dpi: float = None, show_fig: bool = False):
figsize: list = None, dpi: float = None, show_fig: bool = False,
scaled_predictions: bool = False):
r"""Make a scatter plot of predicted versus actual targets. Not for k-splits.
Args:
Expand All @@ -104,6 +105,7 @@ def plot_predict_true(y_predict, y_true, data_unit: list = None, model_name: str
figsize (list): Size of the figure. Default is None.
dpi (float): The resolution of the figure in dots-per-inch. Default is None.
show_fig (bool): Whether to show figure. Default is True.
scaled_predictions (bool): Whether predictions had been standardized. Default is False.
Returns:
matplotlib.pyplot.figure: Figure of the scatter plot.
Expand Down Expand Up @@ -141,7 +143,10 @@ def plot_predict_true(y_predict, y_true, data_unit: list = None, model_name: str
plt.plot(np.arange(*min_max, 0.05), np.arange(*min_max, 0.05), color='red')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title("Prediction of " + model_name + " for " + dataset_name)
plot_title = "Prediction of %s for %s " % (model_name, dataset_name)
if scaled_predictions:
plot_title = "(SCALED!) " + plot_title
plt.title(plot_title)
plt.legend(loc='upper left', fontsize='x-large')
if filepath is not None:
plt.savefig(os.path.join(filepath, model_name + "_" + dataset_name + "_" + file_name))
Expand Down
4 changes: 3 additions & 1 deletion training_core/hyper/hyper_md17.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
"use_output_mlp": False,
"output_mlp": None
}
}
},
"outputs": {"energy": {"name": "energy", "shape": (1,)},
"force": {"name": "force", "shape": (None, 3)}}
}
},
"training": {
Expand Down
17 changes: 8 additions & 9 deletions training_core/train_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
execute_folds = args["fold"] if "execute_folds" not in hyper["training"] else hyper["training"]["execute_folds"]
splits_done, current_split = 0, None
train_indices_all, test_indices_all = [], []
model = None
model, scaled_predictions = None, False
for current_split, (train_index, test_index) in enumerate(dataset.get_train_test_indices(train="train", test="test")):

# Keep list of train/test indices.
Expand Down Expand Up @@ -126,6 +126,7 @@
mae_metric_energy.set_scale(scaler_scale)
mae_metric_force.set_scale(scaler_scale)
scaled_metrics = {"energy": [mae_metric_energy], "force": [mae_metric_force]}
scaled_predictions = True

# Save scaler to file
scaler.save(os.path.join(filepath, f"scaler{postfix_file}_fold_{current_split}"))
Expand All @@ -145,12 +146,8 @@
print(" Compiled with jit: %s" % model._jit_compile) # noqa

# Convert targets into tensors.
labels_in_dataset = {
"energy": {"name": "energy"},
"force": {"name": "force", "shape": (None, 3)}
}
y_train = dataset_train.tensor(labels_in_dataset)
y_test = dataset_test.tensor(labels_in_dataset)
y_train = dataset_train.tensor(hyper["model"]["config"]["outputs"])
y_test = dataset_test.tensor(hyper["model"]["config"]["outputs"])

# Start and time training
start = time.time()
Expand All @@ -174,13 +171,15 @@
plot_predict_true(np.array(predicted_y["energy"]), np.array(true_y["energy"]),
filepath=filepath, data_unit=label_units,
model_name=hyper.model_name, dataset_name=hyper.dataset_class, target_names=label_names,
file_name=f"predict_energy{postfix_file}_fold_{splits_done}.png")
file_name=f"predict_energy{postfix_file}_fold_{splits_done}.png",
scaled_predictions=scaled_predictions)

plot_predict_true(np.concatenate([np.array(f) for f in predicted_y["force"]], axis=0),
np.concatenate([np.array(f) for f in true_y["force"]], axis=0),
filepath=filepath, data_unit=label_units,
model_name=hyper.model_name, dataset_name=hyper.dataset_class, target_names=label_names,
file_name=f"predict_force{postfix_file}_fold_{splits_done}.png")
file_name=f"predict_force{postfix_file}_fold_{splits_done}.png",
scaled_predictions=scaled_predictions)

# Save last keras-model to output-folder.
model.save(os.path.join(filepath, f"model{postfix_file}_fold_{current_split}.keras"))
Expand Down
6 changes: 4 additions & 2 deletions training_core/train_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
# Iterate over the cross-validation splits.
# Indices for train-test splits are stored in 'test_indices_list'.
execute_folds = args["fold"] if "execute_folds" not in hyper["training"] else hyper["training"]["execute_folds"]
model, current_split = None, None
model, current_split, scaled_predictions = None, None, False
train_indices_all, test_indices_all = [], []
for current_split, (train_index, test_index) in enumerate(dataset.get_train_test_indices(train="train", test="test")):

Expand Down Expand Up @@ -119,6 +119,7 @@
mae_metric.set_scale(scaler_scale)
rms_metric.set_scale(scaler_scale)
scaled_metrics = [mae_metric, rms_metric]
scaled_predictions = True

# Save scaler to file
scaler.save(os.path.join(filepath, f"scaler{postfix_file}_fold_{current_split}"))
Expand Down Expand Up @@ -162,7 +163,8 @@
plot_predict_true(predicted_y, true_y,
filepath=filepath, data_unit=label_units,
model_name=hyper.model_name, dataset_name=hyper.dataset_class, target_names=label_names,
file_name=f"predict{postfix_file}_fold_{current_split}.png", show_fig=False)
file_name=f"predict{postfix_file}_fold_{current_split}.png", show_fig=False,
scaled_predictions=scaled_predictions)

# Save last keras-model to output-folder.
model.save(os.path.join(filepath, f"model{postfix_file}_fold_{current_split}.keras"))
Expand Down

0 comments on commit 5eee206

Please sign in to comment.