Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 21, 2023
1 parent 425a829 commit 4c838d3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
3 changes: 3 additions & 0 deletions kgcnn/metrics_core/metrics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import keras_core as ks
import keras_core.metrics
from keras_core import ops
import keras_core.saving


@ks.saving.register_keras_serializable(package='kgcnn', name='ScaledMeanAbsoluteError')
class ScaledMeanAbsoluteError(ks.metrics.MeanAbsoluteError):
"""Metric for a scaled mean absolute error (MAE), which can undo a pre-scaling of the targets. Only intended as
metric this allows to info the MAE with correct units or absolute values during fit."""
Expand Down Expand Up @@ -39,6 +41,7 @@ def set_scale(self, scale):
self.scale.assign(ops.cast(scale, dtype=scale.dtype))


@ks.saving.register_keras_serializable(package='kgcnn', name='ScaledRootMeanSquaredError')
class ScaledRootMeanSquaredError(ks.metrics.RootMeanSquaredError):
"""Metric for a scaled root mean squared error (RMSE), which can undo a pre-scaling of the targets.
Only intended as metric this allows to info the MAE with correct units or absolute values during fit."""
Expand Down
8 changes: 5 additions & 3 deletions training_core/hyper/hyper_esol.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"output_embedding": "graph",
"output_mlp": {"use_bias": [True, True, False], "units": [140, 70, 1],
"activation": ["relu", "relu", "linear"]},
"output_scaling": {"name": "StandardLabelScaler"},
# "output_scaling": {"name": "StandardLabelScaler"},
}
},
"training": {
Expand All @@ -38,8 +38,10 @@
"compile": {
"optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}},
"loss": "mean_absolute_error",
"metrics": [{"class_name": "MeanAbsoluteError", "config": {"dtype": "float64"}},
{"class_name": "RootMeanSquaredError", "config": {"dtype": "float64"}}]
"metrics": [{"class_name": "kgcnn>ScaledMeanAbsoluteError",
"config": {"name": "mean_absolute_error"}},
{"class_name": "kgcnn>ScaledRootMeanSquaredError",
"config": {"name": "root_mean_squared_error"}}]
},
"scaler": {"class_name": "StandardLabelScaler", "module_name": "kgcnn.data.transform.scaler.standard",
"config": {"with_std": True, "with_mean": True, "copy": True}},
Expand Down
27 changes: 20 additions & 7 deletions training_core/train_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import time
import kgcnn.training_core.scheduler # noqa
from datetime import timedelta
import kgcnn.losses_core.losses
import kgcnn.metrics_core.metrics
from kgcnn.training_core.history import save_history_score, load_history_list, load_time_list
from kgcnn.data.transform.scaler.serial import deserialize
from kgcnn.utils_core.plots import plot_train_test_loss, plot_predict_true
Expand All @@ -19,8 +21,8 @@
# for training and model setup.
parser = argparse.ArgumentParser(description='Train a GNN on a graph regression or classification task.')
parser.add_argument("--hyper", required=False, help="Filepath to hyperparameter config file (.py or .json).",
default="hyper/hyper_qm9_energies.py")
parser.add_argument("--category", required=False, help="Graph model to train.", default="Schnet")
default="hyper/hyper_esol.py")
parser.add_argument("--category", required=False, help="Graph model to train.", default="GCN")
parser.add_argument("--model", required=False, help="Graph model to train.", default=None)
parser.add_argument("--dataset", required=False, help="Name of the dataset.", default=None)
parser.add_argument("--make", required=False, help="Name of the class for model.", default=None)
Expand Down Expand Up @@ -104,13 +106,23 @@
print("Using Scaler to adjust output scale of model.")
scaler = deserialize(hyper["training"]["scaler"])
scaler.fit_dataset(dataset_train)
# Model requires to have a set_scale methode that accepts a scaler class.
if hasattr(model, "set_scaler"):
if hasattr(model, "set_scale"):
model.set_scale(scaler)
else:
assert np.all(np.isclose(scaler.get_scaling(), 1.)), "Change scaling is not supported."
scaler.transform(dataset_train)
scaler.transform(dataset_test)
scaler.transform(dataset_train, copy_dataset=True, copy=True)
scaler.transform(dataset_test, copy_dataset=True, copy=True)
# If scaler was used we add rescaled standard metrics to compile, since otherwise the keras history will not
# directly log the original target values, but the scaled ones.
scaler_scale = scaler.get_scaling()
for metric in model.metrics:
print(metric)
# Don't use scaled metrics if model has scale already.
if scaler_scale is not None:
if hasattr(metric, "set_scale"):
metric.set_scale(scaler_scale)

# Save scaler to file
scaler.save(os.path.join(filepath, f"scaler{postfix_file}_fold_{current_split}"))

x_train = dataset_train.tensor(hyper["model"]["config"]["inputs"])
y_train = np.array(dataset_train.get("graph_labels"))
Expand All @@ -133,6 +145,7 @@
os.path.join(filepath, f"time{postfix_file}_fold_{current_split}.pickle"))

# Plot prediction for the last split.
# Note that predicted values will not be rescaled.
predicted_y = model.predict(x_test)
true_y = y_test

Expand Down

0 comments on commit 4c838d3

Please sign in to comment.