diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index a0170d7..689c617 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -18,7 +18,7 @@ jobs:
- name: Set-up python version
uses: actions/setup-python@v5
with:
- python-version: "3.9" # supported by all tensorflow versions
+ python-version: "3.10" # supported by all tensorflow versions
- name: Install dependencies
run: |
pip install --upgrade pip
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 41ab229..58c8bda 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -15,6 +15,7 @@ jobs:
strategy:
matrix:
version:
+ - 2.16.1
- 2.15.0
- 2.14.0
- 2.13.0
diff --git a/README.md b/README.md
index 207d5b8..a68dce4 100644
--- a/README.md
+++ b/README.md
@@ -11,9 +11,10 @@
-
-
-
+
+
+
+
@@ -38,17 +39,17 @@
### Generative Adversarial Networks
-| Algorithms* | Avail | Test | Lipschitzianity** | Design inspired by | Tutorial |
-|:-----------:|:-----:|:----:|:-----------------:|:------------------:|:--------:|
-| [`GAN`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/GAN.py) | ✅ | ✅ | ❌ | [1][1], [8][8], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-GAN-LHCb_RICH.ipynb) |
-| [`BceGAN`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/BceGAN.py) | ✅ | ✅ | ❌ | [2][2], [8][8], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-BceGAN-LHCb_RICH.ipynb) |
-| [`LSGAN`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/LSGAN.py) | ✅ | ✅ | ❌ | [3][3], [8][8], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-LSGAN-LHCb_RICH.ipynb) |
-| [`WGAN`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/WGAN.py) | ✅ | ✅ | ✅ | [4][4], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-WGAN-LHCb_RICH.ipynb) |
-| [`WGAN_GP`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/WGAN_GP.py) | ✅ | ✅ | ✅ | [5][5], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-WGAN_GP-LHCb_RICH.ipynb) |
-| [`CramerGAN`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/CramerGAN.py) | ✅ | ✅ | ✅ | [6][6], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-CramerGAN-LHCb_RICH.ipynb) |
-| [`WGAN_ALP`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/WGAN_ALP.py) | ✅ | ✅ | ✅ | [7][7], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-WGAN_ALP-LHCb_RICH.ipynb) |
-| [`BceGAN_GP`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/BceGAN_GP.py) | ✅ | ✅ | ✅ | [2][2], [5][5], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-BceGAN_GP-LHCb_RICH.ipynb) |
-| [`BceGAN_ALP`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/BceGAN_ALP.py) | ✅ | ✅ | ✅ | [2][2], [7][7], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-BceGAN_ALP-LHCb_RICH.ipynb) |
+| Algorithms* | Source | Avail | Test | Lipschitzianity** | Refs | Tutorial |
+|:-----------:|:------:|:-----:|:----:|:-----------------:|:----:|:--------:|
+| GAN | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k2/GAN.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k3/GAN.py) | ✅ | ✅ | ❌ | [1][1], [8][8], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-GAN-LHCb_RICH.ipynb) |
+| BceGAN | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k2/BceGAN.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k3/BceGAN.py) | ✅ | ✅ | ❌ | [2][2], [8][8], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-BceGAN-LHCb_RICH.ipynb) |
+| LSGAN | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k2/LSGAN.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k3/LSGAN.py) | ✅ | ✅ | ❌ | [3][3], [8][8], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-LSGAN-LHCb_RICH.ipynb) |
+| WGAN | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k2/WGAN.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k3/WGAN.py) | ✅ | ✅ | ✅ | [4][4], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-WGAN-LHCb_RICH.ipynb) |
+| WGAN-GP | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k2/WGAN_GP.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k3/WGAN_GP.py) | ✅ | ✅ | ✅ | [5][5], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-WGAN_GP-LHCb_RICH.ipynb) |
+| CramerGAN | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k2/CramerGAN.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k3/CramerGAN.py) | ✅ | ✅ | ✅ | [6][6], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-CramerGAN-LHCb_RICH.ipynb) |
+| WGAN-ALP | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k2/WGAN_ALP.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k3/WGAN_ALP.py) | ✅ | ✅ | ✅ | [7][7], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-WGAN_ALP-LHCb_RICH.ipynb) |
+| BceGAN-GP | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k2/BceGAN_GP.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k3/BceGAN_GP.py) | ✅ | ✅ | ✅ | [2][2], [5][5], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-BceGAN_GP-LHCb_RICH.ipynb) |
+| BceGAN-ALP | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k2/BceGAN_ALP.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/algorithms/k3/BceGAN_ALP.py) | ✅ | ✅ | ✅ | [2][2], [7][7], [9][9] | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbarbetti/pidgan-notebooks/blob/main/tutorial-BceGAN_ALP-LHCb_RICH.ipynb) |
*each GAN algorithm is designed to operate taking __conditions__ as input [[10][10]]
@@ -56,29 +57,29 @@
### Generators
-| Players | Avail | Test | Inherit from | Design inspired by |
-|:-------:|:-----:|:----:|:------------:|:------------------:|
-| [`Generator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/generators/Generator.py) | ✅ | ✅ | [`tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model) | [1][1], [10][10] |
-| [`ResGenerator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/generators/ResGenerator.py) | ✅ | ✅ | [`Generator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/generators/Generator.py) | [1][1], [10][10], [11][11] |
+| Players | Source | Avail | Test | Skip conn | Refs |
+|:-------:|:------:|:-----:|:----:|:---------:|:----:|
+| Generator | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/generators/k2/Generator.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/generators/k3/Generator.py) | ✅ | ✅ | ❌ | [1][1], [10][10] |
+| ResGenerator | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/generators/k2/ResGenerator.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/generators/k3/ResGenerator.py) | ✅ | ✅ | ✅ | [1][1], [10][10], [11][11] |
### Discriminators
-| Players | Avail | Test | Inherit from | Design inspired by |
-|:-------:|:-----:|:----:|:------------:|:------------------:|
-| [`Discriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/Discriminator.py) | ✅ | ✅ | [`tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model) | [1][1], [9][9], [10][10] |
-| [`ResDiscriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/ResDiscriminator.py) | ✅ | ✅ | [`Discriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/Discriminator.py) | [1][1], [9][9], [10][10], [11][11] |
-| [`AuxDiscriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/AuxDiscriminator.py) | ✅ | ✅ | [`ResDiscriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/ResDiscriminator.py) | [1][1], [9][9], [10][10], [11][11], [12][12] |
+| Players | Source | Avail | Test | Skip conn | Aux proc | Refs |
+|:-------:|:------:|:-----:|:----:|:---------:|:--------:|:----:|
+| Discriminator | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/k2/Discriminator.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/k3/Discriminator.py) | ✅ | ✅ | ❌ | ❌ | [1][1], [9][9], [10][10] |
+| ResDiscriminator | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/k2/ResDiscriminator.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/k3/ResDiscriminator.py) | ✅ | ✅ | ✅ | ❌ | [1][1], [9][9], [10][10], [11][11] |
+| AuxDiscriminator | [`k2`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/k2/AuxDiscriminator.py)/[`k3`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/k3/AuxDiscriminator.py) | ✅ | ✅ | ✅ | ✅ | [1][1], [9][9], [10][10], [11][11], [12][12] |
### Other players
-| Players | Avail | Test | Inherit from |
-|:-------:|:-----:|:----:|:------------:|
-| [`Classifier`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/Classifier.py) | ✅ | ✅ | [`Discriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/Discriminator.py) |
-| [`ResClassifier`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/ResClassifier.py) | ✅ | ✅ | [`ResDiscriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/ResDiscriminator.py) |
-| [`AuxClassifier`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/AuxClassifier.py) | ✅ | ✅ | [`AuxDiscriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/AuxDiscriminator.py) |
-| [`MultiClassifier`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/MultiClassifier.py) | ✅ | ✅ | [`Discriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/Discriminator.py) |
-| [`MultiResClassifier`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/MultiResClassifier.py) | ✅ | ✅ | [`ResDiscriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/ResDiscriminator.py) |
-| [`AuxMultiClassifier`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/AuxMultiClassifier.py) | ✅ | ✅ | [`AuxDiscriminator`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/discriminators/AuxDiscriminator.py) |
+| Players | Source | Avail | Test | Skip conn | Aux proc | Multiclass |
+|:-------:|:------:|:-----:|:----:|:---------:|:--------:|:---------:|
+| Classifier | [`src`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/Classifier.py) | ✅ | ✅ | ❌ | ❌ | ❌ |
+| ResClassifier | [`src`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/ResClassifier.py) | ✅ | ✅ | ✅ | ❌ | ❌ |
+| AuxClassifier | [`src`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/AuxClassifier.py) | ✅ | ✅ | ✅ | ✅ | ❌ |
+| MultiClassifier | [`src`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/MultiClassifier.py) | ✅ | ✅ | ❌ | ❌ | ✅ |
+| MultiResClassifier | [`src`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/MultiResClassifier.py) | ✅ | ✅ | ✅ | ❌ | ✅ |
+| AuxMultiClassifier | [`src`](https://github.com/mbarbetti/pidgan/blob/main/src/pidgan/players/classifiers/AuxMultiClassifier.py) | ✅ | ✅ | ✅ | ✅ | ✅ |
### References
1. I.J. Goodfellow _et al._, "Generative Adversarial Networks", [arXiv:1406.2661][1]
diff --git a/pyproject.toml b/pyproject.toml
index c203452..e3b862d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ name = "pidgan"
dynamic = ["version"]
description = "GAN-based models to flash-simulate the LHCb PID detectors"
readme = "README.md"
-requires-python = ">=3.7, <3.12"
+requires-python = ">=3.7, <3.13"
license = {text = "GPLv3 License"}
authors = [
{name = "Matteo Barbetti", email = "matteo.barbetti@cnaf.infn.it"},
@@ -15,6 +15,7 @@ authors = [
]
keywords = [
"tensorflow",
+ "keras",
"machine learning",
"deep learning",
"generative models",
@@ -25,32 +26,37 @@ keywords = [
"particle identification",
]
classifiers = [
- "Development Status :: 3 - Alpha",
+ "Development Status :: 4 - Beta",
+ "Environment :: GPU :: NVIDIA CUDA :: 11.2",
+ "Environment :: GPU :: NVIDIA CUDA :: 11.8",
+ "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.2",
+ "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.3",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
- "tensorflow>=2.7",
- "scikit-learn>=1.0",
+ "tensorflow>=2.8,<2.17",
+ "scikit-learn>=1.0,<1.6",
+ "numpy<2.0",
# "hopaas-client", # to be released on PyPI
]
[project.optional-dependencies]
hep = [
- "numpy",
- "scipy",
+ "matplotlib>=3.7,<4.0",
+ "html-reports>=0.2",
+ "scikinC>=0.2.6",
"pandas",
"uproot",
- "scikinC>=0.2.6",
- "matplotlib",
- "html-reports>=0.2",
"pyyaml",
+ "tqdm",
]
style = [
"ruff",
@@ -97,8 +103,8 @@ exclude = [
line-length = 88
indent-width = 4
-# Assume Python 3.8
-target-version = "py38"
+# Assume Python 3.10
+target-version = "py310"
# Enable linting and formatting for .ipynb files.
extend-include = ["*.ipynb"]
diff --git a/requirements/base.txt b/requirements/base.txt
index c5694b0..0511d46 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -1,3 +1,4 @@
-tensorflow>=2.7,<2.16
-scikit-learn>=1.0
+tensorflow>=2.8,<2.17
+scikit-learn>=1.0,<1.6
+numpy<2.0
#hopaas_client # to be released on PyPI
\ No newline at end of file
diff --git a/requirements/hep.txt b/requirements/hep.txt
index a9ae029..1870849 100644
--- a/requirements/hep.txt
+++ b/requirements/hep.txt
@@ -1,8 +1,7 @@
-tqdm
-numpy
+matplotlib>=3.7,<4.0
+html-reports>=0.2
+scikinC>=0.2.6
pandas
uproot
-scikinC>0.2.6
-matplotlib
-html-reports>=0.2
-pyyaml
\ No newline at end of file
+pyyaml
+tqdm
\ No newline at end of file
diff --git a/scripts/train_ANN_isMuon.py b/scripts/train_ANN_isMuon.py
index a9bcde3..d0b898f 100644
--- a/scripts/train_ANN_isMuon.py
+++ b/scripts/train_ANN_isMuon.py
@@ -1,22 +1,26 @@
import os
+import yaml
import pickle
-import socket
-from datetime import datetime
-
+import keras as k
import numpy as np
import tensorflow as tf
-import yaml
+
+from datetime import datetime
from html_reports import Report
from sklearn.utils import shuffle
-from tensorflow import keras
from utils.utils_argparser import argparser_training
-from utils.utils_training import prepare_training_plots, prepare_validation_plots
+from utils.utils_training import (
+ fill_html_report,
+ prepare_training_plots,
+ prepare_validation_plots,
+)
-import pidgan
from pidgan.callbacks.schedulers import LearnRateExpDecay
from pidgan.players.classifiers import ResClassifier
from pidgan.utils.preprocessing import invertColumnTransformer
-from pidgan.utils.reports import getSummaryHTML, initHPSingleton
+from pidgan.utils.reports import initHPSingleton
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
DTYPE = np.float32
BATCHSIZE = 2048
@@ -127,21 +131,18 @@
mlp_hidden_units=hp.get("mlp_hidden_units", 128),
mlp_hidden_activation=hp.get("mlp_hidden_activation", "relu"),
mlp_hidden_kernel_regularizer=hp.get(
- "mlp_hidden_kernel_regularizer", tf.keras.regularizers.L2(5e-5)
+ "mlp_hidden_kernel_regularizer", k.regularizers.L2(5e-5)
),
mlp_dropout_rates=hp.get("mlp_dropout_rates", 0.0),
name="classifier",
dtype=DTYPE,
)
-out = classifier(x[:BATCHSIZE])
-classifier.summary()
-
# +----------------------+
# | Optimizers setup |
# +----------------------+
-opt = keras.optimizers.Adam(hp.get("lr0", 0.001))
+opt = k.optimizers.Adam(hp.get("lr0", 0.001))
hp.get("optimizer", opt.name)
# +----------------------------+
@@ -149,10 +150,10 @@
# +----------------------------+
hp.get("loss", "binary cross-entropy")
-loss = keras.losses.BinaryCrossentropy(label_smoothing=hp.get("label_smoothing", 0.05))
+loss = k.losses.BinaryCrossentropy(label_smoothing=hp.get("label_smoothing", 0.05))
hp.get("metrics", ["auc"])
-metrics = [keras.metrics.AUC(name="auc")]
+metrics = [k.metrics.AUC(name="auc")]
classifier.compile(
optimizer=opt,
@@ -160,6 +161,9 @@
metrics=metrics,
)
+out = classifier(x[:BATCHSIZE])
+classifier.summary()
+
# +--------------------------+
# | Callbacks definition |
# +--------------------------+
@@ -239,8 +243,8 @@
os.makedirs(export_model_dirname)
if not os.path.exists(export_img_dirname):
os.makedirs(export_img_dirname) # need to save images
- keras.models.save_model(
- classifier.export_model,
+ k.models.save_model(
+ classifier.plain_keras,
filepath=f"{export_model_dirname}/saved_model",
save_format="tf",
)
@@ -264,50 +268,20 @@
# +---------------------+
report = Report()
-report.add_markdown('<h1 align="center">isMuonANN training report</h1>')
-
-info = [
- f"- Script executed on **{socket.gethostname()}**",
- f"- Model training completed in **{duration}**",
- f"- Model training executed with **TF{tf.__version__}** "
- f"and **pidgan v{pidgan.__version__}**",
- f"- Report generated on **{date}** at **{hour}**",
- f"- Model trained on **{args.particle}** tracks",
-]
-
-if "calib" not in args.data_sample:
- info += [f"- Model trained on **detailed simulated** samples ({args.data_sample})"]
-else:
- info += [f"- Model trained on **calibration** samples ({args.data_sample})"]
- if args.weights:
- info += ["- Any background components subtracted using **sWeights**"]
- else:
- info += ["- **sWeights not applied**"]
-
-report.add_markdown("\n".join([i for i in info]))
-report.add_markdown("---")
-
-## Hyperparameters and other details
-report.add_markdown('<h2 align="center">Hyperparameters and other details</h2>')
-hyperparams = ""
-for k, v in hp.get_dict().items():
- hyperparams += f"- **{k}:** {v}\n"
-report.add_markdown(hyperparams)
-
-report.add_markdown("---")
-
-## Classifier architecture
-report.add_markdown('<h2 align="center">Classifier architecture</h2>')
-report.add_markdown(f"**Model name:** {classifier.name}")
-html_table, params_details = getSummaryHTML(classifier.export_model)
-model_weights = ""
-for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
-report.add_markdown(html_table)
-report.add_markdown(model_weights)
-
-report.add_markdown("---")
+## Basic report info
+fill_html_report(
+ report=report,
+ title="isMuonANN training report",
+ train_duration=duration,
+ report_datetime=(date, hour),
+ particle=args.particle,
+ data_sample=args.data_sample,
+ trained_with_weights=args.weights,
+ hp_dict=hp.get_dict(),
+ model_arch=[classifier],
+ model_labels=["Classifier"],
+)
## Training plots
prepare_training_plots(
diff --git a/scripts/train_GAN_GlobalPID-im.py b/scripts/train_GAN_GlobalPID-im.py
index 02bffa7..95b88c9 100644
--- a/scripts/train_GAN_GlobalPID-im.py
+++ b/scripts/train_GAN_GlobalPID-im.py
@@ -1,25 +1,29 @@
import os
+import yaml
import pickle
-import socket
-from datetime import datetime
-
+import keras as k
import numpy as np
import tensorflow as tf
-import yaml
+
+from datetime import datetime
from html_reports import Report
from sklearn.utils import shuffle
-from tensorflow import keras
from utils.utils_argparser import argparser_training
-from utils.utils_training import prepare_training_plots, prepare_validation_plots
+from utils.utils_training import (
+ fill_html_report,
+ prepare_training_plots,
+ prepare_validation_plots,
+)
-import pidgan
from pidgan.algorithms import BceGAN
from pidgan.callbacks.schedulers import LearnRateExpDecay
from pidgan.players.classifiers import Classifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
from pidgan.utils.preprocessing import invertColumnTransformer
-from pidgan.utils.reports import getSummaryHTML, initHPSingleton
+from pidgan.utils.reports import initHPSingleton
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
DTYPE = np.float32
BATCHSIZE = 2048
@@ -159,7 +163,7 @@
mlp_hidden_units=hp.get("r_mlp_hidden_units", 128),
mlp_hidden_activation=hp.get("r_mlp_hidden_activation", "relu"),
mlp_hidden_kernel_regularizer=hp.get(
- "r_mlp_hidden_kernel_regularizer", tf.keras.regularizers.L2(5e-5)
+ "r_mlp_hidden_kernel_regularizer", k.regularizers.L2(5e-5)
),
mlp_dropout_rates=hp.get("r_mlp_dropout_rates", 0.0),
name="referee",
@@ -177,21 +181,18 @@
)
hp.get("gan_name", gan.name)
-out = gan(x[:BATCHSIZE], y[:BATCHSIZE])
-gan.summary()
-
# +----------------------+
# | Optimizers setup |
# +----------------------+
-g_opt = keras.optimizers.RMSprop(hp.get("g_lr0", 6e-4))
+g_opt = k.optimizers.RMSprop(hp.get("g_lr0", 6e-4))
hp.get("g_optimizer", g_opt.name)
-d_opt = keras.optimizers.RMSprop(hp.get("d_lr0", 1e-3))
+d_opt = k.optimizers.RMSprop(hp.get("d_lr0", 1e-3))
hp.get("d_optimizer", d_opt.name)
if gan.referee is not None:
- r_opt = keras.optimizers.RMSprop(hp.get("r_lr0", 1e-3))
+ r_opt = k.optimizers.RMSprop(hp.get("r_lr0", 1e-3))
hp.get("r_optimizer", r_opt.name)
# +----------------------------+
@@ -210,6 +211,9 @@
referee_upds_per_batch=hp.get("referee_upds_per_batch", 1) if gan.referee else None,
)
+out = gan(x[:BATCHSIZE], y[:BATCHSIZE])
+gan.summary()
+
# +--------------------------+
# | Callbacks definition |
# +--------------------------+
@@ -319,8 +323,8 @@
os.makedirs(export_model_dirname)
if not os.path.exists(export_img_dirname):
os.makedirs(export_img_dirname) # need to save images
- keras.models.save_model(
- gan.generator.export_model,
+ k.models.save_model(
+ gan.generator.plain_keras,
filepath=f"{export_model_dirname}/saved_generator",
save_format="tf",
)
@@ -344,77 +348,20 @@
# +---------------------+
report = Report()
-report.add_markdown(
- '<h1 align="center">GlobalPIDGAN (isMuon passed) training report</h1>'
-)
-info = [
- f"- Script executed on **{socket.gethostname()}**",
- f"- Model training completed in **{duration}**",
- f"- Model training executed with **TF{tf.__version__}** "
- f"and **pidgan v{pidgan.__version__}**",
- f"- Report generated on **{date}** at **{hour}**",
- f"- Model trained on **{args.particle}** tracks",
-]
-
-if "calib" not in args.data_sample:
- info += [f"- Model trained on **detailed simulated** samples ({args.data_sample})"]
-else:
- info += [f"- Model trained on **calibration** samples ({args.data_sample})"]
- if args.weights:
- info += ["- Any background components subtracted using **sWeights**"]
- else:
- info += ["- **sWeights not applied**"]
-
-report.add_markdown("\n".join([i for i in info]))
-
-report.add_markdown("---")
-
-## Hyperparameters and other details
-report.add_markdown('<h2 align="center">Hyperparameters and other details</h2>')
-hyperparams = ""
-for k, v in hp.get_dict().items():
- hyperparams += f"- **{k}:** {v}\n"
-report.add_markdown(hyperparams)
-
-report.add_markdown("---")
-
-## Generator architecture
-report.add_markdown('<h2 align="center">Generator architecture</h2>')
-report.add_markdown(f"**Model name:** {gan.generator.name}")
-html_table, params_details = getSummaryHTML(gan.generator.export_model)
-model_weights = ""
-for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
-report.add_markdown(html_table)
-report.add_markdown(model_weights)
-
-report.add_markdown("---")
-
-## Discriminator architecture
-report.add_markdown('<h2 align="center">Discriminator architecture</h2>')
-report.add_markdown(f"**Model name:** {gan.discriminator.name}")
-html_table, params_details = getSummaryHTML(gan.discriminator.export_model)
-model_weights = ""
-for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
-report.add_markdown(html_table)
-report.add_markdown(model_weights)
-
-report.add_markdown("---")
-
-## Referee architecture
-if gan.referee is not None:
- report.add_markdown('<h2 align="center">Referee architecture</h2>')
- report.add_markdown(f"**Model name:** {gan.referee.name}")
- html_table, params_details = getSummaryHTML(gan.referee.export_model)
- model_weights = ""
- for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
- report.add_markdown(html_table)
- report.add_markdown(model_weights)
-
- report.add_markdown("---")
+## Basic report info
+fill_html_report(
+ report=report,
+ title="GlobalPIDGAN (isMuon passed) training report",
+ train_duration=duration,
+ report_datetime=(date, hour),
+ particle=args.particle,
+ data_sample=args.data_sample,
+ trained_with_weights=args.weights,
+ hp_dict=hp.get_dict(),
+ model_arch=[gan.generator, gan.discriminator, gan.referee],
+ model_labels=["Generator", "Discriminator", "Referee"],
+)
## Training plots
prepare_training_plots(
diff --git a/scripts/train_GAN_GlobalPID-nm.py b/scripts/train_GAN_GlobalPID-nm.py
index 94421ff..1a0a445 100644
--- a/scripts/train_GAN_GlobalPID-nm.py
+++ b/scripts/train_GAN_GlobalPID-nm.py
@@ -1,25 +1,29 @@
import os
+import yaml
import pickle
-import socket
-from datetime import datetime
-
+import keras as k
import numpy as np
import tensorflow as tf
-import yaml
+
+from datetime import datetime
from html_reports import Report
from sklearn.utils import shuffle
-from tensorflow import keras
from utils.utils_argparser import argparser_training
-from utils.utils_training import prepare_training_plots, prepare_validation_plots
+from utils.utils_training import (
+ fill_html_report,
+ prepare_training_plots,
+ prepare_validation_plots,
+)
-import pidgan
from pidgan.algorithms import BceGAN
from pidgan.callbacks.schedulers import LearnRateExpDecay
from pidgan.players.classifiers import Classifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
from pidgan.utils.preprocessing import invertColumnTransformer
-from pidgan.utils.reports import getSummaryHTML, initHPSingleton
+from pidgan.utils.reports import initHPSingleton
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
DTYPE = np.float32
BATCHSIZE = 2048
@@ -159,7 +163,7 @@
mlp_hidden_units=hp.get("r_mlp_hidden_units", 128),
mlp_hidden_activation=hp.get("r_mlp_hidden_activation", "relu"),
mlp_hidden_kernel_regularizer=hp.get(
- "r_mlp_hidden_kernel_regularizer", tf.keras.regularizers.L2(5e-5)
+ "r_mlp_hidden_kernel_regularizer", k.regularizers.L2(5e-5)
),
mlp_dropout_rates=hp.get("r_mlp_dropout_rates", 0.0),
name="referee",
@@ -177,21 +181,18 @@
)
hp.get("gan_name", gan.name)
-out = gan(x[:BATCHSIZE], y[:BATCHSIZE])
-gan.summary()
-
# +----------------------+
# | Optimizers setup |
# +----------------------+
-g_opt = keras.optimizers.RMSprop(hp.get("g_lr0", 7e-4))
+g_opt = k.optimizers.RMSprop(hp.get("g_lr0", 7e-4))
hp.get("g_optimizer", g_opt.name)
-d_opt = keras.optimizers.RMSprop(hp.get("d_lr0", 5e-4))
+d_opt = k.optimizers.RMSprop(hp.get("d_lr0", 5e-4))
hp.get("d_optimizer", d_opt.name)
if gan.referee is not None:
- r_opt = keras.optimizers.RMSprop(hp.get("r_lr0", 1e-3))
+ r_opt = k.optimizers.RMSprop(hp.get("r_lr0", 1e-3))
hp.get("r_optimizer", r_opt.name)
# +----------------------------+
@@ -210,6 +211,9 @@
referee_upds_per_batch=hp.get("referee_upds_per_batch", 1) if gan.referee else None,
)
+out = gan(x[:BATCHSIZE], y[:BATCHSIZE])
+gan.summary()
+
# +--------------------------+
# | Callbacks definition |
# +--------------------------+
@@ -319,8 +323,8 @@
os.makedirs(export_model_dirname)
if not os.path.exists(export_img_dirname):
os.makedirs(export_img_dirname) # need to save images
- keras.models.save_model(
- gan.generator.export_model,
+ k.models.save_model(
+ gan.generator.plain_keras,
filepath=f"{export_model_dirname}/saved_generator",
save_format="tf",
)
@@ -344,77 +348,20 @@
# +---------------------+
report = Report()
-report.add_markdown(
- '<h1 align="center">GlobalPIDGAN (isMuon not passed) training report</h1>'
-)
-info = [
- f"- Script executed on **{socket.gethostname()}**",
- f"- Model training completed in **{duration}**",
- f"- Model training executed with **TF{tf.__version__}** "
- f"and **pidgan v{pidgan.__version__}**",
- f"- Report generated on **{date}** at **{hour}**",
- f"- Model trained on **{args.particle}** tracks",
-]
-
-if "calib" not in args.data_sample:
- info += [f"- Model trained on **detailed simulated** samples ({args.data_sample})"]
-else:
- info += [f"- Model trained on **calibration** samples ({args.data_sample})"]
- if args.weights:
- info += ["- Any background components subtracted using **sWeights**"]
- else:
- info += ["- **sWeights not applied**"]
-
-report.add_markdown("\n".join([i for i in info]))
-
-report.add_markdown("---")
-
-## Hyperparameters and other details
-report.add_markdown('<h2 align="center">Hyperparameters and other details</h2>')
-hyperparams = ""
-for k, v in hp.get_dict().items():
- hyperparams += f"- **{k}:** {v}\n"
-report.add_markdown(hyperparams)
-
-report.add_markdown("---")
-
-## Generator architecture
-report.add_markdown('<h2 align="center">Generator architecture</h2>')
-report.add_markdown(f"**Model name:** {gan.generator.name}")
-html_table, params_details = getSummaryHTML(gan.generator.export_model)
-model_weights = ""
-for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
-report.add_markdown(html_table)
-report.add_markdown(model_weights)
-
-report.add_markdown("---")
-
-## Discriminator architecture
-report.add_markdown('<h2 align="center">Discriminator architecture</h2>')
-report.add_markdown(f"**Model name:** {gan.discriminator.name}")
-html_table, params_details = getSummaryHTML(gan.discriminator.export_model)
-model_weights = ""
-for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
-report.add_markdown(html_table)
-report.add_markdown(model_weights)
-
-report.add_markdown("---")
-
-## Referee architecture
-if gan.referee is not None:
- report.add_markdown('<h2 align="center">Referee architecture</h2>')
- report.add_markdown(f"**Model name:** {gan.referee.name}")
- html_table, params_details = getSummaryHTML(gan.referee.export_model)
- model_weights = ""
- for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
- report.add_markdown(html_table)
- report.add_markdown(model_weights)
-
- report.add_markdown("---")
+## Basic report info
+fill_html_report(
+ report=report,
+ title="GlobalPIDGAN (isMuon not passed) training report",
+ train_duration=duration,
+ report_datetime=(date, hour),
+ particle=args.particle,
+ data_sample=args.data_sample,
+ trained_with_weights=args.weights,
+ hp_dict=hp.get_dict(),
+ model_arch=[gan.generator, gan.discriminator, gan.referee],
+ model_labels=["Generator", "Discriminator", "Referee"],
+)
## Training plots
prepare_training_plots(
diff --git a/scripts/train_GAN_Muon.py b/scripts/train_GAN_Muon.py
index a7527fd..d5510cd 100644
--- a/scripts/train_GAN_Muon.py
+++ b/scripts/train_GAN_Muon.py
@@ -1,25 +1,29 @@
import os
+import yaml
import pickle
-import socket
-from datetime import datetime
-
+import keras as k
import numpy as np
import tensorflow as tf
-import yaml
+
+from datetime import datetime
from html_reports import Report
from sklearn.utils import shuffle
-from tensorflow import keras
from utils.utils_argparser import argparser_training
-from utils.utils_training import prepare_training_plots, prepare_validation_plots
+from utils.utils_training import (
+ fill_html_report,
+ prepare_training_plots,
+ prepare_validation_plots,
+)
-import pidgan
from pidgan.algorithms import BceGAN
from pidgan.callbacks.schedulers import LearnRateExpDecay
from pidgan.players.classifiers import Classifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
from pidgan.utils.preprocessing import invertColumnTransformer
-from pidgan.utils.reports import getSummaryHTML, initHPSingleton
+from pidgan.utils.reports import initHPSingleton
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
DTYPE = np.float32
BATCHSIZE = 2048
@@ -155,7 +159,7 @@
mlp_hidden_units=hp.get("r_mlp_hidden_units", 128),
mlp_hidden_activation=hp.get("r_mlp_hidden_activation", "relu"),
mlp_hidden_kernel_regularizer=hp.get(
- "r_mlp_hidden_kernel_regularizer", tf.keras.regularizers.L2(5e-5)
+ "r_mlp_hidden_kernel_regularizer", k.regularizers.L2(5e-5)
),
mlp_dropout_rates=hp.get("r_mlp_dropout_rates", 0.0),
name="referee",
@@ -173,21 +177,18 @@
)
hp.get("gan_name", gan.name)
-out = gan(x[:BATCHSIZE], y[:BATCHSIZE])
-gan.summary()
-
# +----------------------+
# | Optimizers setup |
# +----------------------+
-g_opt = keras.optimizers.RMSprop(hp.get("g_lr0", 7e-4))
+g_opt = k.optimizers.RMSprop(hp.get("g_lr0", 7e-4))
hp.get("g_optimizer", g_opt.name)
-d_opt = keras.optimizers.RMSprop(hp.get("d_lr0", 5e-4))
+d_opt = k.optimizers.RMSprop(hp.get("d_lr0", 5e-4))
hp.get("d_optimizer", d_opt.name)
if gan.referee is not None:
- r_opt = keras.optimizers.RMSprop(hp.get("r_lr0", 1e-3))
+ r_opt = k.optimizers.RMSprop(hp.get("r_lr0", 1e-3))
hp.get("r_optimizer", r_opt.name)
# +----------------------------+
@@ -206,6 +207,9 @@
referee_upds_per_batch=hp.get("referee_upds_per_batch", 1) if gan.referee else None,
)
+out = gan(x[:BATCHSIZE], y[:BATCHSIZE])
+gan.summary()
+
# +--------------------------+
# | Callbacks definition |
# +--------------------------+
@@ -315,8 +319,8 @@
os.makedirs(export_model_dirname)
if not os.path.exists(export_img_dirname):
os.makedirs(export_img_dirname) # need to save images
- keras.models.save_model(
- gan.generator.export_model,
+ k.models.save_model(
+ gan.generator.plain_keras,
filepath=f"{export_model_dirname}/saved_generator",
save_format="tf",
)
@@ -340,75 +344,20 @@
# +---------------------+
report = Report()
-report.add_markdown('<h1 align="center">MuonGAN training report</h1>')
-
-info = [
- f"- Script executed on **{socket.gethostname()}**",
- f"- Model training completed in **{duration}**",
- f"- Model training executed with **TF{tf.__version__}** "
- f"and **pidgan v{pidgan.__version__}**",
- f"- Report generated on **{date}** at **{hour}**",
- f"- Model trained on **{args.particle}** tracks",
-]
-
-if "calib" not in args.data_sample:
- info += [f"- Model trained on **detailed simulated** samples ({args.data_sample})"]
-else:
- info += [f"- Model trained on **calibration** samples ({args.data_sample})"]
- if args.weights:
- info += ["- Any background components subtracted using **sWeights**"]
- else:
- info += ["- **sWeights not applied**"]
-
-report.add_markdown("\n".join([i for i in info]))
-
-report.add_markdown("---")
-
-## Hyperparameters and other details
-report.add_markdown('<h2 align="center">Hyperparameters and other details</h2>')
-hyperparams = ""
-for k, v in hp.get_dict().items():
- hyperparams += f"- **{k}:** {v}\n"
-report.add_markdown(hyperparams)
-
-report.add_markdown("---")
-
-## Generator architecture
-report.add_markdown('<h2 align="center">Generator architecture</h2>')
-report.add_markdown(f"**Model name:** {gan.generator.name}")
-html_table, params_details = getSummaryHTML(gan.generator.export_model)
-model_weights = ""
-for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
-report.add_markdown(html_table)
-report.add_markdown(model_weights)
-
-report.add_markdown("---")
-
-## Discriminator architecture
-report.add_markdown('<h2 align="center">Discriminator architecture</h2>')
-report.add_markdown(f"**Model name:** {gan.discriminator.name}")
-html_table, params_details = getSummaryHTML(gan.discriminator.export_model)
-model_weights = ""
-for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
-report.add_markdown(html_table)
-report.add_markdown(model_weights)
-
-report.add_markdown("---")
-
-## Referee architecture
-if gan.referee is not None:
- report.add_markdown('<h2 align="center">Referee architecture</h2>')
- report.add_markdown(f"**Model name:** {gan.referee.name}")
- html_table, params_details = getSummaryHTML(gan.referee.export_model)
- model_weights = ""
- for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
- report.add_markdown(html_table)
- report.add_markdown(model_weights)
-
- report.add_markdown("---")
+
+## Basic report info
+fill_html_report(
+ report=report,
+ title="MuonGAN training report",
+ train_duration=duration,
+ report_datetime=(date, hour),
+ particle=args.particle,
+ data_sample=args.data_sample,
+ trained_with_weights=args.weights,
+ hp_dict=hp.get_dict(),
+ model_arch=[gan.generator, gan.discriminator, gan.referee],
+ model_labels=["Generator", "Discriminator", "Referee"],
+)
## Training plots
prepare_training_plots(
diff --git a/scripts/train_GAN_Rich.py b/scripts/train_GAN_Rich.py
index 2bbed95..e74f391 100644
--- a/scripts/train_GAN_Rich.py
+++ b/scripts/train_GAN_Rich.py
@@ -1,25 +1,29 @@
import os
+import yaml
import pickle
-import socket
-from datetime import datetime
-
+import keras as k
import numpy as np
import tensorflow as tf
-import yaml
+
+from datetime import datetime
from html_reports import Report
from sklearn.utils import shuffle
-from tensorflow import keras
from utils.utils_argparser import argparser_training
-from utils.utils_training import prepare_training_plots, prepare_validation_plots
+from utils.utils_training import (
+ fill_html_report,
+ prepare_training_plots,
+ prepare_validation_plots,
+)
-import pidgan
from pidgan.algorithms import BceGAN
from pidgan.callbacks.schedulers import LearnRateExpDecay
from pidgan.players.classifiers import Classifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
from pidgan.utils.preprocessing import invertColumnTransformer
-from pidgan.utils.reports import getSummaryHTML, initHPSingleton
+from pidgan.utils.reports import initHPSingleton
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
DTYPE = np.float32
BATCHSIZE = 2048
@@ -159,7 +163,7 @@
mlp_hidden_units=hp.get("r_mlp_hidden_units", 128),
mlp_hidden_activation=hp.get("r_mlp_hidden_activation", "relu"),
mlp_hidden_kernel_regularizer=hp.get(
- "r_mlp_hidden_kernel_regularizer", tf.keras.regularizers.L2(5e-5)
+ "r_mlp_hidden_kernel_regularizer", k.regularizers.L2(5e-5)
),
mlp_dropout_rates=hp.get("r_mlp_dropout_rates", 0.0),
name="referee",
@@ -177,21 +181,18 @@
)
hp.get("gan_name", gan.name)
-out = gan(x[:BATCHSIZE], y[:BATCHSIZE])
-gan.summary()
-
# +----------------------+
# | Optimizers setup |
# +----------------------+
-g_opt = keras.optimizers.RMSprop(hp.get("g_lr0", 4e-4))
+g_opt = k.optimizers.RMSprop(hp.get("g_lr0", 4e-4))
hp.get("g_optimizer", g_opt.name)
-d_opt = keras.optimizers.RMSprop(hp.get("d_lr0", 5e-4))
+d_opt = k.optimizers.RMSprop(hp.get("d_lr0", 5e-4))
hp.get("d_optimizer", d_opt.name)
if gan.referee is not None:
- r_opt = keras.optimizers.RMSprop(hp.get("r_lr0", 1e-3))
+ r_opt = k.optimizers.RMSprop(hp.get("r_lr0", 1e-3))
hp.get("r_optimizer", r_opt.name)
# +----------------------------+
@@ -210,6 +211,9 @@
referee_upds_per_batch=hp.get("referee_upds_per_batch", 1) if gan.referee else None,
)
+out = gan(x[:BATCHSIZE], y[:BATCHSIZE])
+gan.summary()
+
# +--------------------------+
# | Callbacks definition |
# +--------------------------+
@@ -319,8 +323,8 @@
os.makedirs(export_model_dirname)
if not os.path.exists(export_img_dirname):
os.makedirs(export_img_dirname) # need to save images
- keras.models.save_model(
- gan.generator.export_model,
+ k.models.save_model(
+ gan.generator.plain_keras,
filepath=f"{export_model_dirname}/saved_generator",
save_format="tf",
)
@@ -344,75 +348,20 @@
# +---------------------+
report = Report()
-report.add_markdown('<h1 align="center">RichGAN training report</h1>')
-
-info = [
- f"- Script executed on **{socket.gethostname()}**",
- f"- Model training completed in **{duration}**",
- f"- Model training executed with **TF{tf.__version__}** "
- f"and **pidgan v{pidgan.__version__}**",
- f"- Report generated on **{date}** at **{hour}**",
- f"- Model trained on **{args.particle}** tracks",
-]
-
-if "calib" not in args.data_sample:
- info += [f"- Model trained on **detailed simulated** samples ({args.data_sample})"]
-else:
- info += [f"- Model trained on **calibration** samples ({args.data_sample})"]
- if args.weights:
- info += ["- Any background components subtracted using **sWeights**"]
- else:
- info += ["- **sWeights not applied**"]
-
-report.add_markdown("\n".join([i for i in info]))
-
-report.add_markdown("---")
-
-## Hyperparameters and other details
-report.add_markdown('<h2 align="center">Hyperparameters and other details</h2>')
-hyperparams = ""
-for k, v in hp.get_dict().items():
- hyperparams += f"- **{k}:** {v}\n"
-report.add_markdown(hyperparams)
-
-report.add_markdown("---")
-
-## Generator architecture
-report.add_markdown('<h2 align="center">Generator architecture</h2>')
-report.add_markdown(f"**Model name:** {gan.generator.name}")
-html_table, params_details = getSummaryHTML(gan.generator.export_model)
-model_weights = ""
-for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
-report.add_markdown(html_table)
-report.add_markdown(model_weights)
-
-report.add_markdown("---")
-
-## Discriminator architecture
-report.add_markdown('<h2 align="center">Discriminator architecture</h2>')
-report.add_markdown(f"**Model name:** {gan.discriminator.name}")
-html_table, params_details = getSummaryHTML(gan.discriminator.export_model)
-model_weights = ""
-for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
-report.add_markdown(html_table)
-report.add_markdown(model_weights)
-
-report.add_markdown("---")
-
-## Referee architecture
-if gan.referee is not None:
- report.add_markdown('<h2 align="center">Referee architecture</h2>')
- report.add_markdown(f"**Model name:** {gan.referee.name}")
- html_table, params_details = getSummaryHTML(gan.referee.export_model)
- model_weights = ""
- for k, n in zip(["Total", "Trainable", "Non-trainable"], params_details):
- model_weights += f"- **{k} params:** {n}\n"
- report.add_markdown(html_table)
- report.add_markdown(model_weights)
-
- report.add_markdown("---")
+
+## Basic report info
+fill_html_report(
+ report=report,
+ title="RichGAN training report",
+ train_duration=duration,
+ report_datetime=(date, hour),
+ particle=args.particle,
+ data_sample=args.data_sample,
+ trained_with_weights=args.weights,
+ hp_dict=hp.get_dict(),
+ model_arch=[gan.generator, gan.discriminator, gan.referee],
+ model_labels=["Generator", "Discriminator", "Referee"],
+)
## Training plots
prepare_training_plots(
diff --git a/scripts/utils/utils_plot.py b/scripts/utils/utils_plot.py
index c0aa688..16d21a1 100644
--- a/scripts/utils/utils_plot.py
+++ b/scripts/utils/utils_plot.py
@@ -1,10 +1,8 @@
-import copy
-
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
-my_cmap = copy.copy(mpl.cm.get_cmap("magma"))
+my_cmap = mpl.colormaps["magma"]
my_cmap.set_bad((0, 0, 0))
diff --git a/scripts/utils/utils_training.py b/scripts/utils/utils_training.py
index 11154c6..6e72bd7 100644
--- a/scripts/utils/utils_training.py
+++ b/scripts/utils/utils_training.py
@@ -1,4 +1,11 @@
+import socket
+import pidgan
+import keras as k
import numpy as np
+import tensorflow as tf
+
+from html_reports import Report
+from pidgan.utils.reports import getSummaryHTML
from .utils_plot import (
binned_validation_histogram,
@@ -24,6 +31,69 @@
}
+def fill_html_report(
+ report,
+ title,
+ train_duration,
+ report_datetime,
+ particle,
+ data_sample,
+ trained_with_weights,
+ hp_dict,
+ model_arch,
+ model_labels,
+) -> Report:
+ report.add_markdown(f'<h1 align="center">{title}</h1>')
+
+ ## General information
+ date, hour = report_datetime
+ info = [
+ f"- Script executed on **{socket.gethostname()}**",
+ f"- Model training completed in **{train_duration}**",
+ f"- Model training executed with **TF{tf.__version__}** "
+ f"(Keras {k.__version__}) and **pidgan {pidgan.__version__}**",
+ f"- Report generated on **{date}** at **{hour}**",
+ f"- Model trained on **{particle}** tracks",
+ ]
+
+ if "calib" not in data_sample:
+ info += [f"- Model trained on **detailed simulated** samples ({data_sample})"]
+ else:
+ info += [f"- Model trained on **calibration** samples ({data_sample})"]
+ if trained_with_weights:
+ info += ["- Any background components subtracted using **sWeights**"]
+ else:
+ info += ["- **sWeights not applied**"]
+
+ report.add_markdown("\n".join([i for i in info]))
+ report.add_markdown("---")
+
+ ## Hyperparameters and other details
+ report.add_markdown('<h2 align="center">Hyperparameters and other details</h2>')
+ hyperparams = ""
+ for key, val in hp_dict.items():
+ hyperparams += f"- **{key}:** {val}\n"
+ report.add_markdown(hyperparams)
+ report.add_markdown("---")
+
+ ## Models architecture
+ for model, label in zip(model_arch, model_labels):
+ if model is not None:
+ report.add_markdown(f'<h2 align="center">{label} architecture</h2>')
+ report.add_markdown(f"**Model name:** {model.name}")
+ html_table, params_details = getSummaryHTML(model.plain_keras)
+ model_weights = ""
+ for key, num in zip(
+ ["Total", "Trainable", "Non-trainable"], params_details
+ ):
+ model_weights += f"- **{key} params:** {num}\n"
+ report.add_markdown(html_table)
+ report.add_markdown(model_weights)
+ report.add_markdown("---")
+
+ return report
+
+
def prepare_training_plots(
report,
model,
diff --git a/src/pidgan/algorithms/__init__.py b/src/pidgan/algorithms/__init__.py
index 7ce94b3..2efe00b 100644
--- a/src/pidgan/algorithms/__init__.py
+++ b/src/pidgan/algorithms/__init__.py
@@ -1,9 +1,24 @@
-from .BceGAN import BceGAN
-from .BceGAN_ALP import BceGAN_ALP
-from .BceGAN_GP import BceGAN_GP
-from .CramerGAN import CramerGAN
-from .GAN import GAN
-from .LSGAN import LSGAN
-from .WGAN import WGAN
-from .WGAN_ALP import WGAN_ALP
-from .WGAN_GP import WGAN_GP
+import keras as k
+
+v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+
+if v_major == 3 and v_minor >= 0:
+ from .k3.BceGAN import BceGAN
+ from .k3.BceGAN_ALP import BceGAN_ALP
+ from .k3.BceGAN_GP import BceGAN_GP
+ from .k3.CramerGAN import CramerGAN
+ from .k3.GAN import GAN
+ from .k3.LSGAN import LSGAN
+ from .k3.WGAN import WGAN
+ from .k3.WGAN_ALP import WGAN_ALP
+ from .k3.WGAN_GP import WGAN_GP
+else:
+ from .k2.BceGAN import BceGAN
+ from .k2.BceGAN_ALP import BceGAN_ALP
+ from .k2.BceGAN_GP import BceGAN_GP
+ from .k2.CramerGAN import CramerGAN
+ from .k2.GAN import GAN
+ from .k2.LSGAN import LSGAN
+ from .k2.WGAN import WGAN
+ from .k2.WGAN_ALP import WGAN_ALP
+ from .k2.WGAN_GP import WGAN_GP
diff --git a/src/pidgan/algorithms/BceGAN.py b/src/pidgan/algorithms/k2/BceGAN.py
similarity index 96%
rename from src/pidgan/algorithms/BceGAN.py
rename to src/pidgan/algorithms/k2/BceGAN.py
index a577c5c..b514f81 100644
--- a/src/pidgan/algorithms/BceGAN.py
+++ b/src/pidgan/algorithms/k2/BceGAN.py
@@ -1,7 +1,7 @@
+import keras as k
import tensorflow as tf
-from tensorflow import keras
-from pidgan.algorithms.GAN import GAN
+from pidgan.algorithms.k2.GAN import GAN
class BceGAN(GAN):
@@ -30,7 +30,7 @@ def __init__(
self._use_original_loss = None
# Keras BinaryCrossentropy
- self._bce_loss = keras.losses.BinaryCrossentropy(
+ self._bce_loss = k.losses.BinaryCrossentropy(
from_logits=from_logits, label_smoothing=label_smoothing
)
self._from_logits = bool(from_logits)
diff --git a/src/pidgan/algorithms/BceGAN_ALP.py b/src/pidgan/algorithms/k2/BceGAN_ALP.py
similarity index 95%
rename from src/pidgan/algorithms/BceGAN_ALP.py
rename to src/pidgan/algorithms/k2/BceGAN_ALP.py
index cd3d185..0199d50 100644
--- a/src/pidgan/algorithms/BceGAN_ALP.py
+++ b/src/pidgan/algorithms/k2/BceGAN_ALP.py
@@ -1,7 +1,7 @@
import tensorflow as tf
-from pidgan.algorithms.BceGAN_GP import BceGAN_GP
-from pidgan.algorithms.lipschitz_regularizations import (
+from pidgan.algorithms.k2.BceGAN_GP import BceGAN_GP
+from pidgan.algorithms.k2.lipschitz_regularizations import (
compute_AdversarialLipschitzPenalty,
)
diff --git a/src/pidgan/algorithms/BceGAN_GP.py b/src/pidgan/algorithms/k2/BceGAN_GP.py
similarity index 96%
rename from src/pidgan/algorithms/BceGAN_GP.py
rename to src/pidgan/algorithms/k2/BceGAN_GP.py
index 1e55972..41149c4 100644
--- a/src/pidgan/algorithms/BceGAN_GP.py
+++ b/src/pidgan/algorithms/k2/BceGAN_GP.py
@@ -1,7 +1,7 @@
import tensorflow as tf
-from pidgan.algorithms import BceGAN
-from pidgan.algorithms.lipschitz_regularizations import (
+from pidgan.algorithms.k2.BceGAN import BceGAN
+from pidgan.algorithms.k2.lipschitz_regularizations import (
PENALTY_STRATEGIES,
compute_GradientPenalty,
)
diff --git a/src/pidgan/algorithms/CramerGAN.py b/src/pidgan/algorithms/k2/CramerGAN.py
similarity index 94%
rename from src/pidgan/algorithms/CramerGAN.py
rename to src/pidgan/algorithms/k2/CramerGAN.py
index 728d5c2..53da11a 100644
--- a/src/pidgan/algorithms/CramerGAN.py
+++ b/src/pidgan/algorithms/k2/CramerGAN.py
@@ -1,7 +1,7 @@
import tensorflow as tf
-from pidgan.algorithms.lipschitz_regularizations import compute_CriticGradientPenalty
-from pidgan.algorithms.WGAN_GP import WGAN_GP
+from pidgan.algorithms.k2.lipschitz_regularizations import compute_CriticGradientPenalty
+from pidgan.algorithms.k2.WGAN_GP import WGAN_GP
LIPSCHITZ_CONSTANT = 1.0
@@ -44,10 +44,13 @@ def __init__(
self._critic = Critic(lambda x, t: self._discriminator(x, training=t))
def _update_metric_states(self, x, y, sample_weight) -> None:
- metric_states = dict(g_loss=self._g_loss.result(), d_loss=self._d_loss.result())
+ metric_states = {
+ "g_loss": self._g_loss_state.result(),
+ "d_loss": self._d_loss_state.result(),
+ }
if self._referee is not None:
- metric_states.update(dict(r_loss=self._r_loss.result()))
- if self._metrics is not None:
+ metric_states.update({"r_loss": self._r_loss_state.result()})
+ if self._train_metrics is not None:
batch_size = tf.cast(tf.shape(x)[0] / 2, tf.int32)
x_1, x_2 = tf.split(x[: batch_size * 2], 2, axis=0)
y_1 = y[:batch_size]
@@ -63,7 +66,7 @@ def _update_metric_states(self, x, y, sample_weight) -> None:
(x_concat_1, y_concat_1), (x_concat_2, y_concat_2), training=False
)
c_ref, c_gen = tf.split(c_out, 2, axis=0)
- for metric in self._metrics:
+ for metric in self._train_metrics:
if sample_weight is not None:
w_1, w_2 = tf.split(sample_weight[: batch_size * 2], 2, axis=0)
weights = w_1 * w_2
diff --git a/src/pidgan/algorithms/GAN.py b/src/pidgan/algorithms/k2/GAN.py
similarity index 88%
rename from src/pidgan/algorithms/GAN.py
rename to src/pidgan/algorithms/k2/GAN.py
index f9854d2..aef9676 100644
--- a/src/pidgan/algorithms/GAN.py
+++ b/src/pidgan/algorithms/k2/GAN.py
@@ -1,5 +1,7 @@
+import warnings
+
+import keras as k
import tensorflow as tf
-from tensorflow import keras
from pidgan.players.classifiers import Classifier
from pidgan.players.discriminators import Discriminator
@@ -10,7 +12,7 @@
MAX_LOG_VALUE = 1.0
-class GAN(keras.Model):
+class GAN(k.Model):
def __init__(
self,
generator,
@@ -25,21 +27,26 @@ def __init__(
super().__init__(name=name, dtype=dtype)
self._loss_name = "GAN original loss"
+ self._train_metrics = None
+ self._model_is_built = False
+
# Generator
if not isinstance(generator, Generator):
raise TypeError(
- f"`generator` should be a pidgan's `Generator`, "
+ f"`generator` should be a pidgan's Generator, "
f"instead {type(generator)} passed"
)
self._generator = generator
+ self._g_loss_state = k.metrics.Mean(name="g_loss")
# Discriminator
if not isinstance(discriminator, Discriminator):
raise TypeError(
- f"`discriminator` should be a pidgan's `Discriminator`, "
+ f"`discriminator` should be a pidgan's Discriminator, "
f"instead {type(discriminator)} passed"
)
self._discriminator = discriminator
+ self._d_loss_state = k.metrics.Mean(name="d_loss")
# Flag to use the original loss
assert isinstance(use_original_loss, bool)
@@ -50,12 +57,13 @@ def __init__(
if not isinstance(referee, Classifier):
if not isinstance(referee, Discriminator):
raise TypeError(
- f"`referee` should be a pidgan's `Classifier` "
- f"(or `Discriminator`), instead "
+ f"`referee` should be a pidgan's Classifier "
+ f"(or Discriminator), instead "
f"{type(referee)} passed"
)
self._referee = referee
- self._referee_loss = keras.losses.BinaryCrossentropy()
+ self._referee_loss = k.losses.BinaryCrossentropy()
+ self._r_loss_state = k.metrics.Mean(name="r_loss")
else:
self._referee = None
self._referee_loss = None
@@ -70,30 +78,9 @@ def __init__(
assert feature_matching_penalty >= 0.0
self._feature_matching_penalty = float(feature_matching_penalty)
- def call(self, x, y=None) -> tuple:
- g_out = self._generator(x)
- d_out_gen = self._discriminator((x, g_out))
- if y is None:
- if self._referee is not None:
- r_out_gen = self._referee((x, g_out))
- return g_out, d_out_gen, r_out_gen
- else:
- return g_out, d_out_gen
- else:
- d_out_ref = self._discriminator((x, y))
- if self._referee is not None:
- r_out_gen = self._referee((x, g_out))
- r_out_ref = self._referee((x, y))
- return g_out, (d_out_gen, d_out_ref), (r_out_gen, r_out_ref)
- else:
- return g_out, (d_out_gen, d_out_ref)
-
- def summary(self, **kwargs) -> None:
- print("_" * 65)
- self._generator.summary(**kwargs)
- self._discriminator.summary(**kwargs)
- if self._referee is not None:
- self._referee.summary(**kwargs)
+ def build(self, input_shape) -> None:
+ super().build(input_shape=input_shape)
+ self._model_is_built = True
def compile(
self,
@@ -107,12 +94,19 @@ def compile(
) -> None:
super().compile(weighted_metrics=[])
- # Loss metrics
- self._g_loss = keras.metrics.Mean(name="g_loss")
- self._d_loss = keras.metrics.Mean(name="d_loss")
- if self._referee is not None:
- self._r_loss = keras.metrics.Mean(name="r_loss")
- self._metrics = checkMetrics(metrics)
+ # Metrics
+ if not self._model_is_built:
+ self._train_metrics = checkMetrics(metrics)
+ self.build(input_shape=[])
+ else:
+ if metrics is not None:
+ warnings.warn(
+ "The `metrics` argument is ignored when the model is "
+ "built before to be compiled. Consider to move the first model "
+ "calling after the `compile()` method to fix this issue.",
+ category=UserWarning,
+ stacklevel=1,
+ )
# Gen/Disc optimizers
self._g_opt = checkOptimizer(generator_optimizer)
@@ -144,6 +138,31 @@ def compile(
self._r_opt = None
self._r_upds_per_batch = None
+ def call(self, x, y=None) -> tuple:
+ g_out = self._generator(x)
+ d_out_gen = self._discriminator((x, g_out))
+ if y is None:
+ if self._referee is not None:
+ r_out_gen = self._referee((x, g_out))
+ return g_out, d_out_gen, r_out_gen
+ else:
+ return g_out, d_out_gen
+ else:
+ d_out_ref = self._discriminator((x, y))
+ if self._referee is not None:
+ r_out_gen = self._referee((x, g_out))
+ r_out_ref = self._referee((x, y))
+ return g_out, (d_out_gen, d_out_ref), (r_out_gen, r_out_ref)
+ else:
+ return g_out, (d_out_gen, d_out_ref)
+
+ def summary(self, **kwargs) -> None:
+ print("_" * 65)
+ self._generator.summary(**kwargs)
+ self._discriminator.summary(**kwargs)
+ if self._referee is not None:
+ self._referee.summary(**kwargs)
+
def train_step(self, data) -> dict:
x, y, sample_weight = self._unpack_data(data)
@@ -175,7 +194,7 @@ def _g_train_step(self, x, y, sample_weight=None) -> None:
self._g_opt.apply_gradients(zip(gradients, trainable_vars))
threshold = self._compute_threshold(self._discriminator, x, y, sample_weight)
- self._g_loss.update_state(loss + threshold)
+ self._g_loss_state.update_state(loss + threshold)
def _d_train_step(self, x, y, sample_weight=None) -> None:
with tf.GradientTape() as tape:
@@ -186,7 +205,7 @@ def _d_train_step(self, x, y, sample_weight=None) -> None:
self._d_opt.apply_gradients(zip(gradients, trainable_vars))
threshold = self._compute_threshold(self._discriminator, x, y, sample_weight)
- self._d_loss.update_state(loss - threshold)
+ self._d_loss_state.update_state(loss - threshold)
def _r_train_step(self, x, y, sample_weight=None) -> None:
with tf.GradientTape() as tape:
@@ -196,19 +215,22 @@ def _r_train_step(self, x, y, sample_weight=None) -> None:
gradients = tape.gradient(loss, trainable_vars)
self._r_opt.apply_gradients(zip(gradients, trainable_vars))
- self._r_loss.update_state(loss)
+ self._r_loss_state.update_state(loss)
def _update_metric_states(self, x, y, sample_weight) -> None:
- metric_states = dict(g_loss=self._g_loss.result(), d_loss=self._d_loss.result())
+ metric_states = {
+ "g_loss": self._g_loss_state.result(),
+ "d_loss": self._d_loss_state.result(),
+ }
if self._referee is not None:
- metric_states.update(dict(r_loss=self._r_loss.result()))
- if self._metrics is not None:
+ metric_states.update({"r_loss": self._r_loss_state.result()})
+ if self._train_metrics is not None:
g_out = self._generator(x, training=False)
x_concat = tf.concat([x, x], axis=0)
y_concat = tf.concat([y, g_out], axis=0)
d_out = self._discriminator((x_concat, y_concat), training=False)
d_ref, d_gen = tf.split(d_out, 2, axis=0)
- for metric in self._metrics:
+ for metric in self._train_metrics:
metric.update_state(
y_true=d_ref, y_pred=d_gen, sample_weight=sample_weight
)
@@ -399,16 +421,16 @@ def test_step(self, data) -> dict:
threshold = self._compute_threshold(self._discriminator, x, y, sample_weight)
g_loss = self._compute_g_loss(x, y, sample_weight, training=False, test=True)
- self._g_loss.update_state(g_loss + threshold)
+ self._g_loss_state.update_state(g_loss + threshold)
d_loss = self._compute_d_loss(x, y, sample_weight, training=False, test=True)
- self._d_loss.update_state(d_loss - threshold)
+ self._d_loss_state.update_state(d_loss - threshold)
if self._referee is not None:
r_loss = self._compute_r_loss(
x, y, sample_weight, training=False, test=True
)
- self._r_loss.update_state(r_loss)
+ self._r_loss_state.update_state(r_loss)
return self._update_metric_states(x, y, sample_weight)
@@ -445,19 +467,14 @@ def referee(self): # TODO: add Union[None, Discriminator]
@property
def metrics(self) -> list:
- reset_states = [self._g_loss, self._d_loss]
- if self._referee is not None:
- reset_states += [self._r_loss]
- if self._metrics is not None:
- reset_states += self._metrics
- return reset_states
+ return self._metrics
@property
- def generator_optimizer(self) -> keras.optimizers.Optimizer:
+ def generator_optimizer(self) -> k.optimizers.Optimizer:
return self._g_opt
@property
- def discriminator_optimizer(self) -> keras.optimizers.Optimizer:
+ def discriminator_optimizer(self) -> k.optimizers.Optimizer:
return self._d_opt
@property
diff --git a/src/pidgan/algorithms/LSGAN.py b/src/pidgan/algorithms/k2/LSGAN.py
similarity index 98%
rename from src/pidgan/algorithms/LSGAN.py
rename to src/pidgan/algorithms/k2/LSGAN.py
index dd4ae5a..93fcf7f 100644
--- a/src/pidgan/algorithms/LSGAN.py
+++ b/src/pidgan/algorithms/k2/LSGAN.py
@@ -1,6 +1,6 @@
import tensorflow as tf
-from pidgan.algorithms.GAN import GAN
+from pidgan.algorithms.k2.GAN import GAN
class LSGAN(GAN):
diff --git a/src/pidgan/algorithms/WGAN.py b/src/pidgan/algorithms/k2/WGAN.py
similarity index 98%
rename from src/pidgan/algorithms/WGAN.py
rename to src/pidgan/algorithms/k2/WGAN.py
index 15a7efb..ecb2e11 100644
--- a/src/pidgan/algorithms/WGAN.py
+++ b/src/pidgan/algorithms/k2/WGAN.py
@@ -1,6 +1,6 @@
import tensorflow as tf
-from pidgan.algorithms.GAN import GAN
+from pidgan.algorithms.k2.GAN import GAN
class WGAN(GAN):
diff --git a/src/pidgan/algorithms/WGAN_ALP.py b/src/pidgan/algorithms/k2/WGAN_ALP.py
similarity index 96%
rename from src/pidgan/algorithms/WGAN_ALP.py
rename to src/pidgan/algorithms/k2/WGAN_ALP.py
index 3c64ac6..8b0a63e 100644
--- a/src/pidgan/algorithms/WGAN_ALP.py
+++ b/src/pidgan/algorithms/k2/WGAN_ALP.py
@@ -1,9 +1,9 @@
import tensorflow as tf
-from pidgan.algorithms.lipschitz_regularizations import (
+from pidgan.algorithms.k2.lipschitz_regularizations import (
compute_AdversarialLipschitzPenalty,
)
-from pidgan.algorithms.WGAN_GP import WGAN_GP
+from pidgan.algorithms.k2.WGAN_GP import WGAN_GP
LIPSCHITZ_CONSTANT = 1.0
XI_MIN = 0.8
diff --git a/src/pidgan/algorithms/WGAN_GP.py b/src/pidgan/algorithms/k2/WGAN_GP.py
similarity index 96%
rename from src/pidgan/algorithms/WGAN_GP.py
rename to src/pidgan/algorithms/k2/WGAN_GP.py
index 7499ac8..019f5ac 100644
--- a/src/pidgan/algorithms/WGAN_GP.py
+++ b/src/pidgan/algorithms/k2/WGAN_GP.py
@@ -1,10 +1,10 @@
import tensorflow as tf
-from pidgan.algorithms.lipschitz_regularizations import (
+from pidgan.algorithms.k2.lipschitz_regularizations import (
PENALTY_STRATEGIES,
compute_GradientPenalty,
)
-from pidgan.algorithms.WGAN import WGAN
+from pidgan.algorithms.k2.WGAN import WGAN
LIPSCHITZ_CONSTANT = 1.0
diff --git a/src/pidgan/algorithms/k2/__init__.py b/src/pidgan/algorithms/k2/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pidgan/algorithms/lipschitz_regularizations.py b/src/pidgan/algorithms/k2/lipschitz_regularizations.py
similarity index 100%
rename from src/pidgan/algorithms/lipschitz_regularizations.py
rename to src/pidgan/algorithms/k2/lipschitz_regularizations.py
diff --git a/src/pidgan/algorithms/k3/BceGAN.py b/src/pidgan/algorithms/k3/BceGAN.py
new file mode 100644
index 0000000..09c750b
--- /dev/null
+++ b/src/pidgan/algorithms/k3/BceGAN.py
@@ -0,0 +1,95 @@
+import keras as k
+
+from pidgan.algorithms.k3.GAN import GAN
+
+
+class BceGAN(GAN):
+ def __init__(
+ self,
+ generator,
+ discriminator,
+ from_logits=False,
+ label_smoothing=0.0,
+ injected_noise_stddev=0.0,
+ feature_matching_penalty=0.0,
+ referee=None,
+ name="BceGAN",
+ dtype=None,
+ ):
+ super().__init__(
+ generator=generator,
+ discriminator=discriminator,
+ injected_noise_stddev=injected_noise_stddev,
+ feature_matching_penalty=feature_matching_penalty,
+ referee=referee,
+ name=name,
+ dtype=dtype,
+ )
+ self._loss_name = "Binary cross-entropy"
+ self._use_original_loss = None
+
+ # Keras BinaryCrossentropy
+ self._bce_loss = k.losses.BinaryCrossentropy(
+ from_logits=from_logits, label_smoothing=label_smoothing
+ )
+ self._from_logits = bool(from_logits)
+ self._label_smoothing = float(label_smoothing)
+
+ def _compute_g_loss(self, x, y, sample_weight=None, training=True, test=False):
+ _, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=training
+ )
+ x_gen, y_gen, w_gen = trainset_gen
+
+ if self._inj_noise_std > 0.0:
+ rnd_gen = k.random.normal(
+ shape=k.ops.shape(y_gen),
+ mean=0.0,
+ stddev=self._inj_noise_std,
+ dtype=y_gen.dtype,
+ )
+ y_gen += rnd_gen
+
+ d_out_gen = self._discriminator((x_gen, y_gen), training=False)
+
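+        # Non-saturating generator loss: generated samples are labeled as
+        # "real" (ones) so the generator maximizes log D(G(x)) rather than
+        # minimizing log(1 - D(G(x))).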
+ fake_loss = self._bce_loss(
+ k.ops.ones_like(d_out_gen), d_out_gen, sample_weight=w_gen
+ )
+ return fake_loss
+
+ def _compute_d_loss(self, x, y, sample_weight=None, training=True, test=False):
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ x_ref, y_ref, w_ref = trainset_ref
+ x_gen, y_gen, w_gen = trainset_gen
+
+ x_concat = k.ops.concatenate([x_ref, x_gen], axis=0)
+ y_concat = k.ops.concatenate([y_ref, y_gen], axis=0)
+
+ if self._inj_noise_std > 0.0:
+ rnd_noise = k.random.normal(
+ shape=k.ops.shape(y_concat),
+ mean=0.0,
+ stddev=self._inj_noise_std,
+ dtype=y_concat.dtype,
+ )
+ y_concat += rnd_noise
+
+ d_out = self._discriminator((x_concat, y_concat), training=training)
+ d_ref, d_gen = k.ops.split(d_out, 2, axis=0)
+
+ real_loss = self._bce_loss(k.ops.ones_like(d_ref), d_ref, sample_weight=w_ref)
+ fake_loss = self._bce_loss(k.ops.zeros_like(d_gen), d_gen, sample_weight=w_gen)
+ return (real_loss + fake_loss) / 2.0
+
+ def _compute_threshold(self, model, x, y, sample_weight=None):
+ return 0.0
+
+ @property
+ def from_logits(self) -> bool:
+ return self._from_logits
+
+ @property
+ def label_smoothing(self) -> float:
+ return self._label_smoothing
diff --git a/src/pidgan/algorithms/k3/BceGAN_ALP.py b/src/pidgan/algorithms/k3/BceGAN_ALP.py
new file mode 100644
index 0000000..c7a69a8
--- /dev/null
+++ b/src/pidgan/algorithms/k3/BceGAN_ALP.py
@@ -0,0 +1,81 @@
+from pidgan.algorithms.k3.BceGAN_GP import BceGAN_GP
+from pidgan.algorithms.k3.lipschitz_regularizations import (
+ compute_AdversarialLipschitzPenalty,
+)
+
+LIPSCHITZ_CONSTANT = 1.0
+XI_MIN = 0.8
+XI_MAX = 1.2
+
+
+class BceGAN_ALP(BceGAN_GP):
+ def __init__(
+ self,
+ generator,
+ discriminator,
+ lipschitz_penalty=1.0,
+ lipschitz_penalty_strategy="one-sided",
+ feature_matching_penalty=0.0,
+ referee=None,
+ name="BceGAN-ALP",
+ dtype=None,
+ ):
+ super().__init__(
+ generator=generator,
+ discriminator=discriminator,
+ lipschitz_penalty=lipschitz_penalty,
+ lipschitz_penalty_strategy=lipschitz_penalty_strategy,
+ feature_matching_penalty=feature_matching_penalty,
+ referee=referee,
+ name=name,
+ dtype=dtype,
+ )
+
+ def compile(
+ self,
+ metrics=None,
+ generator_optimizer="rmsprop",
+ discriminator_optimizer="rmsprop",
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ virtual_adv_direction_upds=1,
+ referee_optimizer=None,
+ referee_upds_per_batch=None,
+ ) -> None:
+ super().compile(
+ metrics=metrics,
+ generator_optimizer=generator_optimizer,
+ discriminator_optimizer=discriminator_optimizer,
+ generator_upds_per_batch=generator_upds_per_batch,
+ discriminator_upds_per_batch=discriminator_upds_per_batch,
+ referee_optimizer=referee_optimizer,
+ referee_upds_per_batch=referee_upds_per_batch,
+ )
+
+ # Virtual adversarial direction updates
+ assert isinstance(virtual_adv_direction_upds, (int, float))
+ assert virtual_adv_direction_upds > 0
+ self._vir_adv_dir_upds = int(virtual_adv_direction_upds)
+
+ def _lipschitz_regularization(
+ self, discriminator, x, y, sample_weight=None, training_discriminator=True
+ ):
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ return compute_AdversarialLipschitzPenalty(
+ discriminator=discriminator,
+ trainset_ref=trainset_ref,
+ trainset_gen=trainset_gen,
+ training_discriminator=training_discriminator,
+ vir_adv_dir_upds=self._vir_adv_dir_upds,
+ xi_min=XI_MIN,
+ xi_max=XI_MAX,
+ lipschitz_penalty=self._lipschitz_penalty,
+ lipschitz_penalty_strategy=self._lipschitz_penalty_strategy,
+ lipschitz_constant=LIPSCHITZ_CONSTANT,
+ )
+
+ @property
+ def virtual_adv_direction_upds(self) -> int:
+ return self._vir_adv_dir_upds
diff --git a/src/pidgan/algorithms/k3/BceGAN_GP.py b/src/pidgan/algorithms/k3/BceGAN_GP.py
new file mode 100644
index 0000000..966619f
--- /dev/null
+++ b/src/pidgan/algorithms/k3/BceGAN_GP.py
@@ -0,0 +1,83 @@
+from pidgan.algorithms.k3.BceGAN import BceGAN
+from pidgan.algorithms.k3.lipschitz_regularizations import (
+ PENALTY_STRATEGIES,
+ compute_GradientPenalty,
+)
+
+LIPSCHITZ_CONSTANT = 1.0
+
+
+class BceGAN_GP(BceGAN):
+ def __init__(
+ self,
+ generator,
+ discriminator,
+ lipschitz_penalty=1.0,
+ lipschitz_penalty_strategy="two-sided",
+ feature_matching_penalty=0.0,
+ referee=None,
+ name="BceGAN-GP",
+ dtype=None,
+ ):
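+        # The gradient-penalty variant works on raw discriminator outputs,
+        # so the BCE loss is forced to `from_logits=True` and no label
+        # smoothing nor noise injection is applied.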
+ super().__init__(
+ generator=generator,
+ discriminator=discriminator,
+ from_logits=True,
+ label_smoothing=0.0,
+ injected_noise_stddev=0.0,
+ feature_matching_penalty=feature_matching_penalty,
+ referee=referee,
+ name=name,
+ dtype=dtype,
+ )
+
+ # Lipschitz penalty
+ assert isinstance(lipschitz_penalty, (int, float))
+ assert lipschitz_penalty > 0.0
+ self._lipschitz_penalty = float(lipschitz_penalty)
+
+ # Penalty strategy
+ assert isinstance(lipschitz_penalty_strategy, str)
+ if lipschitz_penalty_strategy not in PENALTY_STRATEGIES:
+ raise ValueError(
+ "`lipschitz_penalty_strategy` should be selected "
+ f"in {PENALTY_STRATEGIES}, instead "
+ f"'{lipschitz_penalty_strategy}' passed"
+ )
+ self._lipschitz_penalty_strategy = lipschitz_penalty_strategy
+
+ def _compute_d_loss(self, x, y, sample_weight=None, training=True, test=False):
+ d_loss = super()._compute_d_loss(x, y, sample_weight, training)
+ if not test:
+ d_loss += self._lipschitz_regularization(
+ self._discriminator,
+ x,
+ y,
+ sample_weight,
+ training_discriminator=training,
+ )
+ return d_loss
+
+ def _lipschitz_regularization(
+ self, discriminator, x, y, sample_weight=None, training_discriminator=True
+ ):
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ return compute_GradientPenalty(
+ discriminator=discriminator,
+ trainset_ref=trainset_ref,
+ trainset_gen=trainset_gen,
+ training_discriminator=training_discriminator,
+ lipschitz_penalty=self._lipschitz_penalty,
+ lipschitz_penalty_strategy=self._lipschitz_penalty_strategy,
+ lipschitz_constant=LIPSCHITZ_CONSTANT,
+ )
+
+ @property
+ def lipschitz_penalty(self) -> float:
+ return self._lipschitz_penalty
+
+ @property
+ def lipschitz_penalty_strategy(self) -> str:
+ return self._lipschitz_penalty_strategy
diff --git a/src/pidgan/algorithms/k3/CramerGAN.py b/src/pidgan/algorithms/k3/CramerGAN.py
new file mode 100644
index 0000000..cbf82df
--- /dev/null
+++ b/src/pidgan/algorithms/k3/CramerGAN.py
@@ -0,0 +1,203 @@
+import keras as k
+
+from pidgan.algorithms.k3.lipschitz_regularizations import compute_CriticGradientPenalty
+from pidgan.algorithms.k3.WGAN_GP import WGAN_GP
+
+LIPSCHITZ_CONSTANT = 1.0
+
+
+class Critic:
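+    # Cramér (energy-distance) critic built on top of the discriminator h:
+    # f(a, b) = ||h(a) - h(b)|| - ||h(a)||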
+ def __init__(self, nn) -> None:
+ self._nn = nn
+
+ def __call__(self, input_1, input_2, training=True):
+ return k.ops.norm(
+ self._nn(input_1, training) - self._nn(input_2, training), axis=-1
+ ) - k.ops.norm(self._nn(input_1, training), axis=-1)
+
+
+class CramerGAN(WGAN_GP):
+ def __init__(
+ self,
+ generator,
+ discriminator,
+ lipschitz_penalty=1.0,
+ lipschitz_penalty_strategy="two-sided",
+ feature_matching_penalty=0.0,
+ referee=None,
+ name="CramerGAN",
+ dtype=None,
+ ):
+ super().__init__(
+ generator=generator,
+ discriminator=discriminator,
+ lipschitz_penalty=lipschitz_penalty,
+ lipschitz_penalty_strategy=lipschitz_penalty_strategy,
+ feature_matching_penalty=feature_matching_penalty,
+ referee=referee,
+ name=name,
+ dtype=dtype,
+ )
+ self._loss_name = "Energy distance"
+
+ # Critic function
+ self._critic = Critic(lambda x, t: self._discriminator(x, training=t))
+
+ def _update_metric_states(self, x, y, sample_weight) -> None:
+ metric_states = {
+ "g_loss": self._g_loss_state.result(),
+ "d_loss": self._d_loss_state.result(),
+ }
+ if self._referee is not None:
+ metric_states.update({"r_loss": self._r_loss_state.result()})
+ if self._train_metrics is not None:
+ batch_size = k.ops.cast(k.ops.shape(x)[0] / 2, dtype="int32")
+ x_1, x_2 = k.ops.split(x[: batch_size * 2], 2, axis=0)
+ y_1 = y[:batch_size]
+ g_out = self._generator(x[: batch_size * 2], training=False)
+ g_out_1, g_out_2 = k.ops.split(g_out, 2, axis=0)
+
+ x_concat_1 = k.ops.concatenate([x_1, x_1], axis=0)
+ y_concat_1 = k.ops.concatenate([y_1, g_out_1], axis=0)
+ x_concat_2 = k.ops.concatenate([x_2, x_2], axis=0)
+ y_concat_2 = k.ops.concatenate([g_out_2, g_out_2], axis=0)
+
+ c_out = self._critic(
+ (x_concat_1, y_concat_1), (x_concat_2, y_concat_2), training=False
+ )
+ c_ref, c_gen = k.ops.split(c_out, 2, axis=0)
+ for metric in self._train_metrics:
+ if sample_weight is not None:
+ w_1, w_2 = k.ops.split(sample_weight[: batch_size * 2], 2, axis=0)
+ weights = w_1 * w_2
+ else:
+ weights = None
+ metric.update_state(y_true=c_ref, y_pred=c_gen, sample_weight=weights)
+ metric_states.update({metric.name: metric.result()})
+ return metric_states
+
+ def _prepare_trainset(
+ self, x, y, sample_weight=None, training_generator=True
+ ) -> tuple:
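+        # Each batch is split into three equal slices: one kept as reference
+        # and two mapped through the generator, since the Cramér critic
+        # compares two independent sets of generated samples.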
+ batch_size = k.ops.cast(k.ops.shape(x)[0] / 4, dtype="int32")
+ x_ref, x_gen_1, x_gen_2 = k.ops.split(x[: batch_size * 3], 3, axis=0)
+ y_ref = y[:batch_size]
+
+ x_gen_concat = k.ops.concatenate([x_gen_1, x_gen_2], axis=0)
+ y_gen = self._generator(x_gen_concat, training=training_generator)
+ y_gen_1, y_gen_2 = k.ops.split(y_gen, 2, axis=0)
+
+ if sample_weight is not None:
+ w_ref, w_gen_1, w_gen_2 = k.ops.split(
+ sample_weight[: batch_size * 3], 3, axis=0
+ )
+ else:
+ w_ref, w_gen_1, w_gen_2 = k.ops.split(
+ k.ops.ones(shape=(batch_size * 3,)), 3, axis=0
+ )
+
+ return (
+ (x_ref, y_ref, w_ref),
+ (x_gen_1, y_gen_1, w_gen_1),
+ (x_gen_2, y_gen_2, w_gen_2),
+ )
+
+ @staticmethod
+ def _standard_loss_func(
+ critic,
+ trainset_ref,
+ trainset_gen_1,
+ trainset_gen_2,
+ training_critic=False,
+ generator_loss=True,
+ ):
+ x_ref, y_ref, w_ref = trainset_ref
+ x_gen_1, y_gen_1, w_gen_1 = trainset_gen_1
+ x_gen_2, y_gen_2, w_gen_2 = trainset_gen_2
+
+ x_concat_1 = k.ops.concatenate([x_ref, x_gen_1], axis=0)
+ y_concat_1 = k.ops.concatenate([y_ref, y_gen_1], axis=0)
+ x_concat_2 = k.ops.concatenate([x_gen_2, x_gen_2], axis=0)
+ y_concat_2 = k.ops.concatenate([y_gen_2, y_gen_2], axis=0)
+
+ c_out = critic(
+ (x_concat_1, y_concat_1), (x_concat_2, y_concat_2), training=training_critic
+ )
+ c_ref, c_gen = k.ops.split(c_out, 2, axis=0)
+
+ real_loss = k.ops.sum(w_ref * w_gen_2 * c_ref) / k.ops.sum(w_ref * w_gen_2)
+ fake_loss = k.ops.sum(w_gen_1 * w_gen_2 * c_gen) / k.ops.sum(w_gen_1 * w_gen_2)
+
+ if generator_loss:
+ return real_loss - fake_loss
+ else:
+ return fake_loss - real_loss
+
+ def _compute_g_loss(self, x, y, sample_weight=None, training=True, test=False):
+ trainset_ref, trainset_gen_1, trainset_gen_2 = self._prepare_trainset(
+ x, y, sample_weight, training_generator=training
+ )
+ return self._standard_loss_func(
+ critic=self._critic,
+ trainset_ref=trainset_ref,
+ trainset_gen_1=trainset_gen_1,
+ trainset_gen_2=trainset_gen_2,
+ training_critic=False,
+ generator_loss=True,
+ )
+
+ def _compute_d_loss(self, x, y, sample_weight=None, training=True, test=False):
+ trainset_ref, trainset_gen_1, trainset_gen_2 = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ d_loss = self._standard_loss_func(
+ critic=self._critic,
+ trainset_ref=trainset_ref,
+ trainset_gen_1=trainset_gen_1,
+ trainset_gen_2=trainset_gen_2,
+ training_critic=training,
+ generator_loss=False,
+ )
+ if not test:
+ d_loss += self._lipschitz_regularization(
+ self._critic, x, y, sample_weight, training_critic=training
+ )
+ return d_loss
+
+ def _compute_r_loss(self, x, y, sample_weight=None, training=True, test=False):
+ trainset_ref, trainset_gen, _ = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ x_ref, y_ref, w_ref = trainset_ref
+ x_gen, y_gen, w_gen = trainset_gen
+
+ x_concat = k.ops.concatenate([x_ref, x_gen], axis=0)
+ y_concat = k.ops.concatenate([y_ref, y_gen], axis=0)
+
+ r_out = self._referee((x_concat, y_concat), training=training)
+ r_ref, r_gen = k.ops.split(r_out, 2, axis=0)
+
+ real_loss = self._referee_loss(
+ k.ops.ones_like(r_ref), r_ref, sample_weight=w_ref
+ )
+ fake_loss = self._referee_loss(
+ k.ops.zeros_like(r_gen), r_gen, sample_weight=w_gen
+ )
+ return (real_loss + fake_loss) / 2.0
+
+ def _lipschitz_regularization(
+ self, critic, x, y, sample_weight=None, training_critic=True
+ ):
+ trainset_ref, trainset_gen_1, trainset_gen_2 = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ return compute_CriticGradientPenalty(
+ critic=critic,
+ trainset_ref=trainset_ref,
+ trainset_gen_1=trainset_gen_1,
+ trainset_gen_2=trainset_gen_2,
+ training_critic=training_critic,
+ lipschitz_penalty=self._lipschitz_penalty,
+ lipschitz_penalty_strategy=self._lipschitz_penalty_strategy,
+ lipschitz_constant=LIPSCHITZ_CONSTANT,
+ )
diff --git a/src/pidgan/algorithms/k3/GAN.py b/src/pidgan/algorithms/k3/GAN.py
new file mode 100644
index 0000000..853f610
--- /dev/null
+++ b/src/pidgan/algorithms/k3/GAN.py
@@ -0,0 +1,516 @@
+import warnings
+
+import keras as k
+import tensorflow as tf
+
+from pidgan.players.classifiers import Classifier
+from pidgan.players.discriminators import Discriminator
+from pidgan.players.generators import Generator
+from pidgan.utils.checks import checkMetrics, checkOptimizer
+
+MIN_LOG_VALUE = 1e-6
+MAX_LOG_VALUE = 1.0
+
+
+class GAN(k.Model):
+ def __init__(
+ self,
+ generator,
+ discriminator,
+ use_original_loss=True,
+ injected_noise_stddev=0.0,
+ feature_matching_penalty=0.0,
+ referee=None,
+ name="GAN",
+ dtype=None,
+ ) -> None:
+ super().__init__(name=name, dtype=dtype)
+ self._loss_name = "GAN original loss"
+
+ self._metrics = list()
+ self._train_metrics = None
+ self._model_is_built = False
+
+ # Generator
+ if not isinstance(generator, Generator):
+ raise TypeError(
+ f"`generator` should be a pidgan's Generator, "
+ f"instead {type(generator)} passed"
+ )
+ self._generator = generator
+ self._g_loss_state = k.metrics.Mean(name="g_loss")
+ self._metrics += [self._g_loss_state]
+
+ # Discriminator
+ if not isinstance(discriminator, Discriminator):
+ raise TypeError(
+ f"`discriminator` should be a pidgan's Discriminator, "
+ f"instead {type(discriminator)} passed"
+ )
+ self._discriminator = discriminator
+ self._d_loss_state = k.metrics.Mean(name="d_loss")
+ self._metrics += [self._d_loss_state]
+
+ # Flag to use the original loss
+ assert isinstance(use_original_loss, bool)
+ self._use_original_loss = use_original_loss
+
+ # Referee network and loss
+ if referee is not None:
+ if not isinstance(referee, Classifier):
+ if not isinstance(referee, Discriminator):
+ raise TypeError(
+ f"`referee` should be a pidgan's Classifier "
+ f"(or Discriminator), instead "
+ f"{type(referee)} passed"
+ )
+ self._referee = referee
+ self._referee_loss = k.losses.BinaryCrossentropy()
+ self._r_loss_state = k.metrics.Mean(name="r_loss")
+ self._metrics += [self._r_loss_state]
+ else:
+ self._referee = None
+ self._referee_loss = None
+
+ # Noise standard deviation
+ assert isinstance(injected_noise_stddev, (int, float))
+ assert injected_noise_stddev >= 0.0
+ self._inj_noise_std = float(injected_noise_stddev)
+
+ # Feature matching penalty
+ assert isinstance(feature_matching_penalty, (int, float))
+ assert feature_matching_penalty >= 0.0
+ self._feature_matching_penalty = float(feature_matching_penalty)
+
+ def build(self, input_shape) -> None:
+ super().build(input_shape=input_shape)
+ self._train_metrics = checkMetrics(self._train_metrics)
+ self._model_is_built = True
+
+ def compile(
+ self,
+ metrics=None,
+ generator_optimizer="rmsprop",
+ discriminator_optimizer="rmsprop",
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=None,
+ referee_upds_per_batch=None,
+ ) -> None:
+ super().compile(weighted_metrics=[])
+
+ # Metrics
+ if not self._model_is_built:
+ self._train_metrics = metrics
+ self.build(input_shape=[])
+ if self._train_metrics is not None:
+ self._metrics += self._train_metrics
+ else:
+ if metrics is not None:
+ warnings.warn(
+ "The `metrics` argument is ignored when the model is "
+                    "built before being compiled. Consider moving the first "
+                    "call of the model after `compile()` to fix this issue.",
+ category=UserWarning,
+ stacklevel=1,
+ )
+
+ # Gen/Disc optimizers
+ self._g_opt = checkOptimizer(generator_optimizer)
+ self._d_opt = checkOptimizer(discriminator_optimizer)
+
+ # Generator updates per batch
+ assert isinstance(generator_upds_per_batch, (int, float))
+ assert generator_upds_per_batch >= 1
+ self._g_upds_per_batch = int(generator_upds_per_batch)
+
+ # Discriminator updates per batch
+ assert isinstance(discriminator_upds_per_batch, (int, float))
+ assert discriminator_upds_per_batch >= 1
+ self._d_upds_per_batch = int(discriminator_upds_per_batch)
+
+ # Referee settings
+ if self._referee is not None:
+ referee_optimizer = (
+ referee_optimizer if referee_optimizer is not None else "rmsprop"
+ )
+ self._r_opt = checkOptimizer(referee_optimizer)
+ if referee_upds_per_batch is not None:
+ assert isinstance(referee_upds_per_batch, (int, float))
+ assert referee_upds_per_batch >= 1
+ else:
+ referee_upds_per_batch = 1
+ self._r_upds_per_batch = int(referee_upds_per_batch)
+ else:
+ self._r_opt = None
+ self._r_upds_per_batch = None
+
+ def call(self, x, y=None) -> tuple:
+ g_out = self._generator(x)
+ d_out_gen = self._discriminator((x, g_out))
+ if y is None:
+ if self._referee is not None:
+ r_out_gen = self._referee((x, g_out))
+ return g_out, d_out_gen, r_out_gen
+ else:
+ return g_out, d_out_gen
+ else:
+ d_out_ref = self._discriminator((x, y))
+ if self._referee is not None:
+ r_out_gen = self._referee((x, g_out))
+ r_out_ref = self._referee((x, y))
+ return g_out, (d_out_gen, d_out_ref), (r_out_gen, r_out_ref)
+ else:
+ return g_out, (d_out_gen, d_out_ref)
+
+ def summary(self, **kwargs) -> None:
+ print("_" * 65)
+ self._generator.summary(**kwargs)
+ self._discriminator.summary(**kwargs)
+ if self._referee is not None:
+ self._referee.summary(**kwargs)
+
+ def train_step(self, *args, **kwargs):
+ if k.backend.backend() == "tensorflow":
+ return self._tf_train_step(*args, **kwargs)
+ elif k.backend.backend() == "torch":
+ raise NotImplementedError(
+ "`train_step()` not implemented for the PyTorch backend"
+ )
+ elif k.backend.backend() == "jax":
+ raise NotImplementedError(
+ "`train_step()` not implemented for the Jax backend"
+ )
+
+ @staticmethod
+ def _unpack_data(data) -> tuple:
+ if len(data) == 3:
+ x, y, sample_weight = data
+ else:
+ x, y = data
+ sample_weight = None
+ return x, y, sample_weight
+
+ def _tf_train_step(self, data) -> dict:
+ x, y, sample_weight = self._unpack_data(data)
+
+ for _ in range(self._d_upds_per_batch):
+ self._tf_d_train_step(x, y, sample_weight)
+ for _ in range(self._g_upds_per_batch):
+ self._tf_g_train_step(x, y, sample_weight)
+ if self._referee is not None:
+ for _ in range(self._r_upds_per_batch):
+ self._tf_r_train_step(x, y, sample_weight)
+
+ return self._update_metric_states(x, y, sample_weight)
+
+ def _tf_g_train_step(self, x, y, sample_weight=None) -> None:
+ with tf.GradientTape() as tape:
+ loss = self._compute_g_loss(x, y, sample_weight, training=True, test=False)
+
+ trainable_vars = self._generator.trainable_weights
+ gradients = tape.gradient(loss, trainable_vars)
+ self._g_opt.apply_gradients(zip(gradients, trainable_vars))
+
+ threshold = self._compute_threshold(self._discriminator, x, y, sample_weight)
+ self._g_loss_state.update_state(loss + threshold)
+
+ def _tf_d_train_step(self, x, y, sample_weight=None) -> None:
+ with tf.GradientTape() as tape:
+ loss = self._compute_d_loss(x, y, sample_weight, training=True, test=False)
+
+ trainable_vars = self._discriminator.trainable_weights
+ gradients = tape.gradient(loss, trainable_vars)
+ self._d_opt.apply_gradients(zip(gradients, trainable_vars))
+
+ threshold = self._compute_threshold(self._discriminator, x, y, sample_weight)
+ self._d_loss_state.update_state(loss - threshold)
+
+ def _tf_r_train_step(self, x, y, sample_weight=None) -> None:
+ with tf.GradientTape() as tape:
+ loss = self._compute_r_loss(x, y, sample_weight, training=True, test=False)
+
+ trainable_vars = self._referee.trainable_weights
+ gradients = tape.gradient(loss, trainable_vars)
+ self._r_opt.apply_gradients(zip(gradients, trainable_vars))
+
+ self._r_loss_state.update_state(loss)
+
+ def _update_metric_states(self, x, y, sample_weight) -> None:
+ metric_states = {
+ "g_loss": self._g_loss_state.result(),
+ "d_loss": self._d_loss_state.result(),
+ }
+ if self._referee is not None:
+ metric_states.update({"r_loss": self._r_loss_state.result()})
+ if self._train_metrics is not None:
+ g_out = self._generator(x, training=False)
+ x_concat = k.ops.concatenate([x, x], axis=0)
+ y_concat = k.ops.concatenate([y, g_out], axis=0)
+ d_out = self._discriminator((x_concat, y_concat), training=False)
+ d_ref, d_gen = k.ops.split(d_out, 2, axis=0)
+ for metric in self._train_metrics:
+ metric.update_state(
+ y_true=d_ref, y_pred=d_gen, sample_weight=sample_weight
+ )
+ metric_states.update({metric.name: metric.result()})
+ return metric_states
+
+ def _prepare_trainset(
+ self, x, y, sample_weight=None, training_generator=True
+ ) -> tuple:
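+        # Split each batch in two halves: the first is kept as reference
+        # data, the second is used as conditioning input for the generator.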
+ batch_size = k.ops.cast(k.ops.shape(x)[0] / 2, dtype="int32")
+ x_ref, x_gen = k.ops.split(x[: batch_size * 2], 2, axis=0)
+ y_ref = y[:batch_size]
+ y_gen = self._generator(x_gen, training=training_generator)
+
+ if sample_weight is not None:
+ w_ref, w_gen = k.ops.split(sample_weight[: batch_size * 2], 2, axis=0)
+ else:
+ w_ref, w_gen = k.ops.split(k.ops.ones(shape=(batch_size * 2,)), 2, axis=0)
+
+ return (x_ref, y_ref, w_ref), (x_gen, y_gen, w_gen)
+
+ @staticmethod
+ def _standard_loss_func(
+ discriminator,
+ trainset_ref,
+ trainset_gen,
+ inj_noise_std=0.0,
+ training_discriminator=False,
+ original_loss=True,
+ generator_loss=True,
+ ) -> tf.Tensor:
+ x_ref, y_ref, w_ref = trainset_ref
+ x_gen, y_gen, w_gen = trainset_gen
+
+ x_concat = k.ops.concatenate([x_ref, x_gen], axis=0)
+ y_concat = k.ops.concatenate([y_ref, y_gen], axis=0)
+
+ if inj_noise_std > 0.0:
+ rnd_noise = k.random.normal(
+ shape=k.ops.shape(y_concat),
+ mean=0.0,
+ stddev=inj_noise_std,
+ dtype=y_concat.dtype,
+ )
+ y_concat += rnd_noise
+
+ d_out = discriminator((x_concat, y_concat), training=training_discriminator)
+ d_ref, d_gen = k.ops.split(d_out, 2, axis=0)
+
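+        # Weighted log-likelihood terms of the original minimax loss; when
+        # `original_loss` is False the generator uses the non-saturating
+        # variant -log D(G(x)) instead of log(1 - D(G(x))).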
+ real_loss = k.ops.sum(
+ w_ref[:, None] * k.ops.log(k.ops.clip(d_ref, MIN_LOG_VALUE, MAX_LOG_VALUE))
+ ) / k.ops.sum(w_ref)
+ if original_loss:
+ fake_loss = k.ops.sum(
+ w_gen[:, None]
+ * k.ops.log(k.ops.clip(1.0 - d_gen, MIN_LOG_VALUE, MAX_LOG_VALUE))
+ ) / k.ops.sum(w_gen)
+ else:
+ fake_loss = k.ops.sum(
+ -w_gen[:, None]
+ * k.ops.log(k.ops.clip(d_gen, MIN_LOG_VALUE, MAX_LOG_VALUE))
+            ) / k.ops.sum(w_gen)
+
+ if generator_loss:
+ return tf.stop_gradient(real_loss) + fake_loss
+ else:
+ return -(real_loss + fake_loss)
+
+ def _compute_g_loss(
+ self, x, y, sample_weight=None, training=True, test=False
+ ) -> tf.Tensor:
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=training
+ )
+ g_loss = self._standard_loss_func(
+ discriminator=self._discriminator,
+ trainset_ref=trainset_ref,
+ trainset_gen=trainset_gen,
+ inj_noise_std=self._inj_noise_std,
+ training_discriminator=False,
+ original_loss=self._use_original_loss,
+ generator_loss=True,
+ )
+ if not test:
+ g_loss += self._compute_feature_matching(
+ x, y, sample_weight, training_generator=training
+ )
+ return g_loss
+
+ def _compute_d_loss(
+ self, x, y, sample_weight=None, training=True, test=False
+ ) -> tf.Tensor:
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ return self._standard_loss_func(
+ discriminator=self._discriminator,
+ trainset_ref=trainset_ref,
+ trainset_gen=trainset_gen,
+ inj_noise_std=self._inj_noise_std,
+ training_discriminator=training,
+ original_loss=True,
+ generator_loss=False,
+ )
+
+ def _compute_r_loss(
+ self, x, y, sample_weight=None, training=True, test=False
+ ) -> tf.Tensor:
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ x_ref, y_ref, w_ref = trainset_ref
+ x_gen, y_gen, w_gen = trainset_gen
+
+ x_concat = k.ops.concatenate([x_ref, x_gen], axis=0)
+ y_concat = k.ops.concatenate([y_ref, y_gen], axis=0)
+
+ r_out = self._referee((x_concat, y_concat), training=training)
+ r_ref, r_gen = k.ops.split(r_out, 2, axis=0)
+
+ real_loss = self._referee_loss(
+ k.ops.ones_like(r_ref), r_ref, sample_weight=w_ref
+ )
+ fake_loss = self._referee_loss(
+ k.ops.zeros_like(r_gen), r_gen, sample_weight=w_gen
+ )
+ return (real_loss + fake_loss) / 2.0
+
+ def _compute_feature_matching(
+ self, x, y, sample_weight=None, training_generator=True
+ ) -> tf.Tensor:
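+        # Feature-matching penalty: squared L2 distance between the
+        # discriminator hidden features of reference and generated samples.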
+ if self._feature_matching_penalty > 0.0:
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=training_generator
+ )
+ x_ref, y_ref, _ = trainset_ref
+ x_gen, y_gen, _ = trainset_gen
+
+ x_concat = k.ops.concatenate([x_ref, x_gen], axis=0)
+ y_concat = k.ops.concatenate([y_ref, y_gen], axis=0)
+
+ if self._inj_noise_std is not None:
+ if self._inj_noise_std > 0.0:
+ rnd_noise = k.random.normal(
+ shape=(k.ops.shape(y_ref)[0] * 2, k.ops.shape(y_ref)[1]),
+ mean=0.0,
+ stddev=self._inj_noise_std,
+ dtype=y_ref.dtype,
+ )
+ y_concat += rnd_noise
+
+ d_feat_out = self._discriminator.hidden_feature((x_concat, y_concat))
+ d_feat_ref, d_feat_gen = k.ops.split(d_feat_out, 2, axis=0)
+
+ feat_match_term = k.ops.norm(d_feat_ref - d_feat_gen, axis=-1) ** 2
+ return self._feature_matching_penalty * k.ops.mean(feat_match_term)
+ else:
+ return 0.0
+
+ def _prepare_trainset_threshold(self, x, y, sample_weight=None) -> tuple:
+ batch_size = k.ops.cast(k.ops.shape(x)[0] / 2, dtype="int32")
+ x_ref_1, x_ref_2 = k.ops.split(x[: batch_size * 2], 2, axis=0)
+ y_ref_1, y_ref_2 = k.ops.split(y[: batch_size * 2], 2, axis=0)
+
+ if sample_weight is not None:
+ w_ref_1, w_ref_2 = k.ops.split(sample_weight[: batch_size * 2], 2, axis=0)
+ else:
+ w_ref_1, w_ref_2 = k.ops.split(
+ k.ops.ones(shape=(batch_size * 2,)), 2, axis=0
+ )
+
+ return (x_ref_1, y_ref_1, w_ref_1), (x_ref_2, y_ref_2, w_ref_2)
+
+ def _compute_threshold(self, discriminator, x, y, sample_weight=None) -> tf.Tensor:
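+        # Baseline discriminator loss evaluated on two disjoint slices of
+        # reference data only; it is added to g_loss and subtracted from
+        # d_loss so that the reported losses are measured relative to it.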
+ trainset_ref_1, trainset_ref_2 = self._prepare_trainset_threshold(
+ x, y, sample_weight
+ )
+ return self._standard_loss_func(
+ discriminator=discriminator,
+ trainset_ref=trainset_ref_1,
+ trainset_gen=trainset_ref_2,
+ inj_noise_std=0.0,
+ training_discriminator=False,
+ original_loss=True,
+ generator_loss=False,
+ )
+
+ def test_step(self, data) -> dict:
+ x, y, sample_weight = self._unpack_data(data)
+
+ threshold = self._compute_threshold(self._discriminator, x, y, sample_weight)
+
+ g_loss = self._compute_g_loss(x, y, sample_weight, training=False, test=True)
+ self._g_loss_state.update_state(g_loss + threshold)
+
+ d_loss = self._compute_d_loss(x, y, sample_weight, training=False, test=True)
+ self._d_loss_state.update_state(d_loss - threshold)
+
+ if self._referee is not None:
+ r_loss = self._compute_r_loss(
+ x, y, sample_weight, training=False, test=True
+ )
+ self._r_loss_state.update_state(r_loss)
+
+ return self._update_metric_states(x, y, sample_weight)
+
+ def generate(self, x, seed=None) -> tf.Tensor:
+ return self._generator.generate(x, seed=seed)
+
+ @property
+ def loss_name(self) -> str:
+ return self._loss_name
+
+ @property
+ def generator(self) -> Generator:
+ return self._generator
+
+ @property
+ def discriminator(self) -> Discriminator:
+ return self._discriminator
+
+ @property
+ def use_original_loss(self) -> bool:
+ return self._use_original_loss
+
+ @property
+ def injected_noise_stddev(self) -> float:
+ return self._inj_noise_std
+
+ @property
+ def feature_matching_penalty(self) -> float:
+ return self._feature_matching_penalty
+
+ @property
+ def referee(self): # TODO: add Union[None, Discriminator]
+ return self._referee
+
+ @property
+ def metrics(self) -> list:
+ return self._metrics
+
+ @property
+ def generator_optimizer(self) -> k.optimizers.Optimizer:
+ return self._g_opt
+
+ @property
+ def discriminator_optimizer(self) -> k.optimizers.Optimizer:
+ return self._d_opt
+
+ @property
+ def generator_upds_per_batch(self) -> int:
+ return self._g_upds_per_batch
+
+ @property
+ def discriminator_upds_per_batch(self) -> int:
+ return self._d_upds_per_batch
+
+ @property
+ def referee_optimizer(self): # TODO: add Union[None, Optimizer]
+ return self._r_opt
+
+ @property
+ def referee_upds_per_batch(self): # TODO: add Union[None, int]
+ return self._r_upds_per_batch
diff --git a/src/pidgan/algorithms/k3/LSGAN.py b/src/pidgan/algorithms/k3/LSGAN.py
new file mode 100644
index 0000000..79bf3bf
--- /dev/null
+++ b/src/pidgan/algorithms/k3/LSGAN.py
@@ -0,0 +1,110 @@
+import keras as k
+
+from pidgan.algorithms.k3.GAN import GAN
+
+
+class LSGAN(GAN):
+ def __init__(
+ self,
+ generator,
+ discriminator,
+ minimize_pearson_chi2=False,
+ injected_noise_stddev=0.0,
+ feature_matching_penalty=0.0,
+ referee=None,
+ name="LSGAN",
+ dtype=None,
+ ) -> None:
+ super().__init__(
+ generator=generator,
+ discriminator=discriminator,
+ injected_noise_stddev=injected_noise_stddev,
+ feature_matching_penalty=feature_matching_penalty,
+ referee=referee,
+ name=name,
+ dtype=dtype,
+ )
+ self._loss_name = "Least squares loss"
+ self._use_original_loss = None
+
+ # Flag to minimize the Pearson chi2 divergence
+ assert isinstance(minimize_pearson_chi2, bool)
+ self._minimize_pearson_chi2 = minimize_pearson_chi2
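+        # Following the LSGAN paper, minimizing the Pearson chi2 divergence
+        # corresponds to labels (a, b, c) = (-1, 1, 0); otherwise the 0-1
+        # coding with c = 1 is used in the losses below.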
+
+ @staticmethod
+ def _standard_loss_func(
+ discriminator,
+ trainset_ref,
+ trainset_gen,
+ a_param,
+ b_param,
+ inj_noise_std=0.0,
+ training_discriminator=False,
+ generator_loss=True,
+ ):
+ x_ref, y_ref, w_ref = trainset_ref
+ x_gen, y_gen, w_gen = trainset_gen
+
+ x_concat = k.ops.concatenate([x_ref, x_gen], axis=0)
+ y_concat = k.ops.concatenate([y_ref, y_gen], axis=0)
+
+ if inj_noise_std > 0.0:
+ rnd_noise = k.random.normal(
+ shape=(k.ops.shape(y_ref)[0] * 2, k.ops.shape(y_ref)[1]),
+ mean=0.0,
+ stddev=inj_noise_std,
+ dtype=y_ref.dtype,
+ )
+ y_concat += rnd_noise
+
+ d_out = discriminator((x_concat, y_concat), training=training_discriminator)
+ d_ref, d_gen = k.ops.split(d_out, 2, axis=0)
+
+ real_loss = k.ops.sum(w_ref[:, None] * (d_ref - b_param) ** 2) / k.ops.sum(
+ w_ref
+ )
+ fake_loss = k.ops.sum(w_gen[:, None] * (d_gen - a_param) ** 2) / k.ops.sum(
+ w_gen
+ )
+
+ if generator_loss:
+ return (k.ops.stop_gradient(real_loss) + fake_loss) / 2.0
+ else:
+ return (real_loss + fake_loss) / 2.0
+
+ def _compute_g_loss(self, x, y, sample_weight=None, training=True, test=False):
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=training
+ )
+ return self._standard_loss_func(
+ discriminator=self._discriminator,
+ trainset_ref=trainset_ref,
+ trainset_gen=trainset_gen,
+ a_param=0.0 if self._minimize_pearson_chi2 else 1.0,
+ b_param=0.0 if self._minimize_pearson_chi2 else 1.0,
+ inj_noise_std=self._inj_noise_std,
+ training_discriminator=False,
+ generator_loss=True,
+ )
+
+ def _compute_d_loss(self, x, y, sample_weight=None, training=True, test=False):
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ return self._standard_loss_func(
+ discriminator=self._discriminator,
+ trainset_ref=trainset_ref,
+ trainset_gen=trainset_gen,
+ a_param=-1.0 if self._minimize_pearson_chi2 else 0.0,
+ b_param=1.0,
+ inj_noise_std=self._inj_noise_std,
+ training_discriminator=training,
+ generator_loss=False,
+ )
+
+ def _compute_threshold(self, discriminator, x, y, sample_weight=None):
+ return 0.0
+
+ @property
+ def minimize_pearson_chi2(self) -> bool:
+ return self._minimize_pearson_chi2
diff --git a/src/pidgan/algorithms/k3/WGAN.py b/src/pidgan/algorithms/k3/WGAN.py
new file mode 100644
index 0000000..6748f88
--- /dev/null
+++ b/src/pidgan/algorithms/k3/WGAN.py
@@ -0,0 +1,93 @@
+import keras as k
+
+from pidgan.algorithms.k3.GAN import GAN
+
+
+class WGAN(GAN):
+ def __init__(
+ self,
+ generator,
+ discriminator,
+ clip_param=0.01,
+ feature_matching_penalty=0.0,
+ referee=None,
+ name="WGAN",
+ dtype=None,
+ ):
+ super().__init__(
+ generator=generator,
+ discriminator=discriminator,
+ feature_matching_penalty=feature_matching_penalty,
+ referee=referee,
+ name=name,
+ dtype=dtype,
+ )
+ self._loss_name = "Wasserstein distance"
+ self._use_original_loss = None
+ self._inj_noise_std = None
+
+ # Clipping parameter
+ assert isinstance(clip_param, (int, float))
+ assert clip_param > 0.0
+ self._clip_param = float(clip_param)
+
+ def _tf_d_train_step(self, x, y, sample_weight=None) -> None:
+ super()._tf_d_train_step(x, y, sample_weight)
+ for w in self._discriminator.trainable_weights:
+            w.assign(k.ops.clip(w, -self._clip_param, self._clip_param))
+
+ @staticmethod
+ def _standard_loss_func(
+ discriminator,
+ trainset_ref,
+ trainset_gen,
+ training_discriminator=False,
+ generator_loss=True,
+ ):
+ x_ref, y_ref, w_ref = trainset_ref
+ x_gen, y_gen, w_gen = trainset_gen
+
+ x_concat = k.ops.concatenate([x_ref, x_gen], axis=0)
+ y_concat = k.ops.concatenate([y_ref, y_gen], axis=0)
+
+ d_out = discriminator((x_concat, y_concat), training=training_discriminator)
+ d_ref, d_gen = k.ops.split(d_out, 2, axis=0)
+
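+        # Weighted means of the critic outputs: their difference is the
+        # Wasserstein-distance estimate (sign flipped for the discriminator).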
+ real_loss = k.ops.sum(w_ref[:, None] * d_ref) / k.ops.sum(w_ref)
+ fake_loss = k.ops.sum(w_gen[:, None] * d_gen) / k.ops.sum(w_gen)
+
+ if generator_loss:
+ return k.ops.stop_gradient(real_loss) - fake_loss
+ else:
+ return fake_loss - real_loss
+
+ def _compute_g_loss(self, x, y, sample_weight=None, training=True, test=False):
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=training
+ )
+ return self._standard_loss_func(
+ discriminator=self._discriminator,
+ trainset_ref=trainset_ref,
+ trainset_gen=trainset_gen,
+ training_discriminator=False,
+ generator_loss=True,
+ )
+
+ def _compute_d_loss(self, x, y, sample_weight=None, training=True, test=False):
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ return self._standard_loss_func(
+ discriminator=self._discriminator,
+ trainset_ref=trainset_ref,
+ trainset_gen=trainset_gen,
+ training_discriminator=training,
+ generator_loss=False,
+ )
+
+ def _compute_threshold(self, discriminator, x, y, sample_weight=None):
+ return 0.0
+
+ @property
+ def clip_param(self) -> float:
+ return self._clip_param
diff --git a/src/pidgan/algorithms/k3/WGAN_ALP.py b/src/pidgan/algorithms/k3/WGAN_ALP.py
new file mode 100644
index 0000000..e9784d3
--- /dev/null
+++ b/src/pidgan/algorithms/k3/WGAN_ALP.py
@@ -0,0 +1,81 @@
+from pidgan.algorithms.k3.lipschitz_regularizations import (
+ compute_AdversarialLipschitzPenalty,
+)
+from pidgan.algorithms.k3.WGAN_GP import WGAN_GP
+
+LIPSCHITZ_CONSTANT = 1.0
+XI_MIN = 0.8
+XI_MAX = 1.2
+
+
+class WGAN_ALP(WGAN_GP):
+ def __init__(
+ self,
+ generator,
+ discriminator,
+ lipschitz_penalty=1.0,
+ lipschitz_penalty_strategy="one-sided",
+ feature_matching_penalty=0.0,
+ referee=None,
+ name="WGAN-ALP",
+ dtype=None,
+ ):
+ super().__init__(
+ generator=generator,
+ discriminator=discriminator,
+ lipschitz_penalty=lipschitz_penalty,
+ lipschitz_penalty_strategy=lipschitz_penalty_strategy,
+ feature_matching_penalty=feature_matching_penalty,
+ referee=referee,
+ name=name,
+ dtype=dtype,
+ )
+
+ def compile(
+ self,
+ metrics=None,
+ generator_optimizer="rmsprop",
+ discriminator_optimizer="rmsprop",
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ virtual_adv_direction_upds=1,
+ referee_optimizer=None,
+ referee_upds_per_batch=None,
+ ) -> None:
+ super().compile(
+ metrics=metrics,
+ generator_optimizer=generator_optimizer,
+ discriminator_optimizer=discriminator_optimizer,
+ generator_upds_per_batch=generator_upds_per_batch,
+ discriminator_upds_per_batch=discriminator_upds_per_batch,
+ referee_optimizer=referee_optimizer,
+ referee_upds_per_batch=referee_upds_per_batch,
+ )
+
+ # Virtual adversarial direction updates
+ assert isinstance(virtual_adv_direction_upds, (int, float))
+ assert virtual_adv_direction_upds > 0
+ self._vir_adv_dir_upds = int(virtual_adv_direction_upds)
+
+ def _lipschitz_regularization(
+ self, discriminator, x, y, sample_weight=None, training_discriminator=True
+ ):
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ return compute_AdversarialLipschitzPenalty(
+ discriminator=discriminator,
+ trainset_ref=trainset_ref,
+ trainset_gen=trainset_gen,
+ training_discriminator=training_discriminator,
+ vir_adv_dir_upds=self._vir_adv_dir_upds,
+ xi_min=XI_MIN,
+ xi_max=XI_MAX,
+ lipschitz_penalty=self._lipschitz_penalty,
+ lipschitz_penalty_strategy=self._lipschitz_penalty_strategy,
+ lipschitz_constant=LIPSCHITZ_CONSTANT,
+ )
+
+ @property
+ def virtual_adv_direction_upds(self) -> int:
+ return self._vir_adv_dir_upds
diff --git a/src/pidgan/algorithms/k3/WGAN_GP.py b/src/pidgan/algorithms/k3/WGAN_GP.py
new file mode 100644
index 0000000..e893a6c
--- /dev/null
+++ b/src/pidgan/algorithms/k3/WGAN_GP.py
@@ -0,0 +1,84 @@
+from pidgan.algorithms.k3.lipschitz_regularizations import (
+ PENALTY_STRATEGIES,
+ compute_GradientPenalty,
+)
+from pidgan.algorithms.k3.WGAN import WGAN
+
+LIPSCHITZ_CONSTANT = 1.0
+
+
+class WGAN_GP(WGAN):
+ def __init__(
+ self,
+ generator,
+ discriminator,
+ lipschitz_penalty=1.0,
+ lipschitz_penalty_strategy="two-sided",
+ feature_matching_penalty=0.0,
+ referee=None,
+ name="WGAN-GP",
+ dtype=None,
+ ):
+ super().__init__(
+ generator=generator,
+ discriminator=discriminator,
+ feature_matching_penalty=feature_matching_penalty,
+ referee=referee,
+ name=name,
+ dtype=dtype,
+ )
+ self._clip_param = None
+
+ # Lipschitz penalty
+ assert isinstance(lipschitz_penalty, (int, float))
+ assert lipschitz_penalty > 0.0
+ self._lipschitz_penalty = float(lipschitz_penalty)
+
+ # Penalty strategy
+ assert isinstance(lipschitz_penalty_strategy, str)
+ if lipschitz_penalty_strategy not in PENALTY_STRATEGIES:
+ raise ValueError(
+ "`lipschitz_penalty_strategy` should be selected "
+ f"in {PENALTY_STRATEGIES}, instead "
+ f"'{lipschitz_penalty_strategy}' passed"
+ )
+ self._lipschitz_penalty_strategy = lipschitz_penalty_strategy
+
+ def _tf_d_train_step(self, x, y, sample_weight=None) -> None:
+ super(WGAN, self)._tf_d_train_step(x, y, sample_weight)
+
+ def _compute_d_loss(self, x, y, sample_weight=None, training=True, test=False):
+ d_loss = super()._compute_d_loss(x, y, sample_weight, training)
+ if not test:
+ d_loss += self._lipschitz_regularization(
+ self._discriminator,
+ x,
+ y,
+ sample_weight,
+ training_discriminator=training,
+ )
+ return d_loss
+
+ def _lipschitz_regularization(
+ self, discriminator, x, y, sample_weight=None, training_discriminator=True
+ ):
+ trainset_ref, trainset_gen = self._prepare_trainset(
+ x, y, sample_weight, training_generator=False
+ )
+ return compute_GradientPenalty(
+ discriminator=discriminator,
+ trainset_ref=trainset_ref,
+ trainset_gen=trainset_gen,
+ training_discriminator=training_discriminator,
+ lipschitz_penalty=self._lipschitz_penalty,
+ lipschitz_penalty_strategy=self._lipschitz_penalty_strategy,
+ lipschitz_constant=LIPSCHITZ_CONSTANT,
+ )
+
+ @property
+ def lipschitz_penalty(self) -> float:
+ return self._lipschitz_penalty
+
+ @property
+ def lipschitz_penalty_strategy(self) -> str:
+ return self._lipschitz_penalty_strategy
diff --git a/src/pidgan/algorithms/k3/__init__.py b/src/pidgan/algorithms/k3/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pidgan/algorithms/k3/lipschitz_regularizations.py b/src/pidgan/algorithms/k3/lipschitz_regularizations.py
new file mode 100644
index 0000000..ad0fded
--- /dev/null
+++ b/src/pidgan/algorithms/k3/lipschitz_regularizations.py
@@ -0,0 +1,213 @@
+import keras as k
+import tensorflow as tf
+
+PENALTY_STRATEGIES = ["two-sided", "one-sided"]
+LIPSCHITZ_CONSTANT = 1.0
+MIN_STABLE_VALUE = 1e-6
+
+
+def compute_GradientPenalty(
+ discriminator,
+ trainset_ref,
+ trainset_gen,
+ training_discriminator=True,
+ lipschitz_penalty=1.0,
+ lipschitz_penalty_strategy="two-sided",
+ lipschitz_constant=LIPSCHITZ_CONSTANT,
+):
+ x_ref, y_ref, _ = trainset_ref
+ x_gen, y_gen, _ = trainset_gen
+
+ x_concat = k.ops.concatenate([x_ref, x_gen], axis=0)
+ y_concat = k.ops.concatenate([y_ref, y_gen], axis=0)
+
+ if k.backend.backend() == "tensorflow":
+ with tf.GradientTape() as tape:
+ # Compute interpolated points
+ eps = k.random.uniform(
+ shape=(k.ops.shape(y_ref)[0],),
+ minval=0.0,
+ maxval=1.0,
+ dtype=y_ref.dtype,
+ )[:, None]
+ x_hat = k.ops.clip(
+ x_gen + eps * (x_ref - x_gen),
+ x_min=k.ops.min(x_concat, axis=0),
+ x_max=k.ops.max(x_concat, axis=0),
+ )
+ y_hat = k.ops.clip(
+ y_gen + eps * (y_ref - y_gen),
+ x_min=k.ops.min(y_concat, axis=0),
+ x_max=k.ops.max(y_concat, axis=0),
+ )
+ d_in_hat = k.ops.concatenate((x_hat, y_hat), axis=-1)
+ tape.watch(d_in_hat)
+
+ # Value of the discriminator on interpolated points
+ x_hat = d_in_hat[:, : k.ops.shape(x_hat)[1]]
+ y_hat = d_in_hat[:, k.ops.shape(x_hat)[1] :]
+ d_out_hat = discriminator((x_hat, y_hat), training=training_discriminator)
+ grad = tape.gradient(d_out_hat, d_in_hat)
+ norm = k.ops.norm(grad, axis=-1)
+
+ elif k.backend.backend() == "torch":
+ raise NotImplementedError(
+ '"compute_GradientPenalty()" not implemented for the PyTorch backend'
+ )
+ elif k.backend.backend() == "jax":
+ raise NotImplementedError(
+ '"compute_GradientPenalty()" not implemented for the Jax backend'
+ )
+
+ if lipschitz_penalty_strategy == "two-sided":
+ gp_term = (norm - lipschitz_constant) ** 2
+ else:
+ gp_term = (k.ops.maximum(0.0, norm - lipschitz_constant)) ** 2
+ return lipschitz_penalty * k.ops.mean(gp_term)
+
+
+def compute_CriticGradientPenalty(
+ critic,
+ trainset_ref,
+ trainset_gen_1,
+ trainset_gen_2,
+ training_critic=True,
+ lipschitz_penalty=1.0,
+ lipschitz_penalty_strategy="two-sided",
+ lipschitz_constant=LIPSCHITZ_CONSTANT,
+):
+ x_ref, y_ref, _ = trainset_ref
+ x_gen_1, y_gen_1, _ = trainset_gen_1
+ x_gen_2, y_gen_2, _ = trainset_gen_2
+
+ x_concat = k.ops.concatenate([x_ref, x_gen_1], axis=0)
+ y_concat = k.ops.concatenate([y_ref, y_gen_1], axis=0)
+
+ if k.backend.backend() == "tensorflow":
+ with tf.GradientTape() as tape:
+ # Compute interpolated points
+ eps = k.random.uniform(
+ shape=(k.ops.shape(y_ref)[0],),
+ minval=0.0,
+ maxval=1.0,
+ dtype=y_ref.dtype,
+ )[:, None]
+ x_hat = k.ops.clip(
+ x_gen_1 + eps * (x_ref - x_gen_1),
+ x_min=k.ops.min(x_concat, axis=0),
+ x_max=k.ops.max(x_concat, axis=0),
+ )
+ y_hat = k.ops.clip(
+ y_gen_1 + eps * (y_ref - y_gen_1),
+ x_min=k.ops.min(y_concat, axis=0),
+ x_max=k.ops.max(y_concat, axis=0),
+ )
+ c_in_hat = k.ops.concatenate((x_hat, y_hat), axis=-1)
+ tape.watch(c_in_hat)
+
+ # Value of the critic on interpolated points
+ x_hat = c_in_hat[:, : k.ops.shape(x_hat)[1]]
+ y_hat = c_in_hat[:, k.ops.shape(x_hat)[1] :]
+ c_out_hat = critic(
+ (x_hat, y_hat), (x_gen_2, y_gen_2), training=training_critic
+ )
+ grad = tape.gradient(c_out_hat, c_in_hat)
+ norm = k.ops.norm(grad, axis=-1)
+
+ elif k.backend.backend() == "torch":
+ raise NotImplementedError(
+ '"compute_CriticGradientPenalty()" not implemented for the PyTorch backend'
+ )
+ elif k.backend.backend() == "jax":
+ raise NotImplementedError(
+ '"compute_CriticGradientPenalty()" not implemented for the Jax backend'
+ )
+
+ if lipschitz_penalty_strategy == "two-sided":
+ gp_term = (norm - lipschitz_constant) ** 2
+ else:
+ gp_term = (k.ops.maximum(0.0, norm - lipschitz_constant)) ** 2
+ return lipschitz_penalty * k.ops.mean(gp_term)
+
+
+def compute_AdversarialLipschitzPenalty(
+ discriminator,
+ trainset_ref,
+ trainset_gen,
+ training_discriminator=True,
+ vir_adv_dir_upds=1,
+ xi_min=0.1,
+ xi_max=10.0,
+ lipschitz_penalty=1.0,
+ lipschitz_penalty_strategy="one-sided",
+ lipschitz_constant=LIPSCHITZ_CONSTANT,
+):
+ x_ref, y_ref, _ = trainset_ref
+ x_gen, y_gen, _ = trainset_gen
+
+ x_concat = k.ops.concatenate([x_ref, x_gen], axis=0)
+ y_concat = k.ops.concatenate([y_ref, y_gen], axis=0)
+ d_out = discriminator((x_concat, y_concat), training=training_discriminator)
+
+ # Initial virtual adversarial direction
+ adv_dir = k.random.uniform(
+ shape=k.ops.shape(y_concat), minval=-1.0, maxval=1.0, dtype=y_concat.dtype
+ )
+ adv_dir /= k.ops.norm(adv_dir, axis=-1, keepdims=True)
+
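+    # Refine the adversarial direction with a few power-iteration-like
+    # updates: perturb y along the current direction, measure the change in
+    # the discriminator output and follow its gradient w.r.t. the direction.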
+ if k.backend.backend() == "tensorflow":
+ for _ in range(vir_adv_dir_upds):
+ with tf.GradientTape() as tape:
+ tape.watch(adv_dir)
+ xi = k.ops.std(y_concat, axis=0, keepdims=True)
+ y_hat = k.ops.clip(
+ y_concat + xi * adv_dir,
+ x_min=k.ops.min(y_concat, axis=0),
+ x_max=k.ops.max(y_concat, axis=0),
+ )
+ d_out_hat = discriminator(
+ (x_concat, y_hat), training=training_discriminator
+ )
+ d_diff = k.ops.mean(k.ops.abs(d_out - d_out_hat))
+ grad = tape.gradient(d_diff, adv_dir)
+ adv_dir = grad / k.ops.maximum(
+ k.ops.norm(grad, axis=-1, keepdims=True), MIN_STABLE_VALUE
+ )
+
+ elif k.backend.backend() == "torch":
+ raise NotImplementedError(
+ '"compute_AdversarialLipschitzPenalty()" not '
+ "implemented for the PyTorch backend"
+ )
+ elif k.backend.backend() == "jax":
+ raise NotImplementedError(
+ '"compute_AdversarialLipschitzPenalty()" not '
+ "implemented for the Jax backend"
+ )
+
+ # Virtual adversarial direction
+ xi = k.random.uniform(
+ shape=(k.ops.shape(y_concat)[0],),
+ minval=xi_min,
+ maxval=xi_max,
+ dtype=y_concat.dtype,
+ )
+ xi = k.ops.tile(xi[:, None], (1, k.ops.shape(y_concat)[1]))
+ y_hat = k.ops.clip(
+ y_concat + xi * adv_dir,
+ x_min=k.ops.min(y_concat, axis=0),
+ x_max=k.ops.max(y_concat, axis=0),
+ )
+ d_out_hat = discriminator((x_concat, y_hat), training=training_discriminator)
+
+ d_diff = k.ops.abs(d_out - d_out_hat)
+ y_diff = k.ops.norm(y_concat - y_hat, axis=-1, keepdims=True)
+ d_diff_stable = d_diff[y_diff > MIN_STABLE_VALUE]
+ y_diff_stable = y_diff[y_diff > MIN_STABLE_VALUE]
+ K = d_diff_stable / y_diff_stable # lipschitz constant
+
+ if lipschitz_penalty_strategy == "two-sided":
+ alp_term = k.ops.abs(K - lipschitz_constant)
+ else:
+ alp_term = k.ops.maximum(0.0, K - lipschitz_constant)
+ return lipschitz_penalty * k.ops.mean(alp_term) ** 2
diff --git a/src/pidgan/callbacks/schedulers/__init__.py b/src/pidgan/callbacks/schedulers/__init__.py
index 3a493c4..075ec9f 100644
--- a/src/pidgan/callbacks/schedulers/__init__.py
+++ b/src/pidgan/callbacks/schedulers/__init__.py
@@ -1,5 +1,18 @@
-from .LearnRateCosineDecay import LearnRateCosineDecay
-from .LearnRateExpDecay import LearnRateExpDecay
-from .LearnRateInvTimeDecay import LearnRateInvTimeDecay
-from .LearnRatePiecewiseConstDecay import LearnRatePiecewiseConstDecay
-from .LearnRatePolynomialDecay import LearnRatePolynomialDecay
+import keras as k
+
+v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+
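+# Dispatch on the installed Keras major version: the k3 modules target
+# Keras 3, while the k2 modules keep the previous Keras 2 implementations.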
+if v_major == 3 and v_minor >= 0:
+ from .k3.LearnRateBaseScheduler import LearnRateBaseScheduler
+ from .k3.LearnRateCosineDecay import LearnRateCosineDecay
+ from .k3.LearnRateExpDecay import LearnRateExpDecay
+ from .k3.LearnRateInvTimeDecay import LearnRateInvTimeDecay
+ from .k3.LearnRatePiecewiseConstDecay import LearnRatePiecewiseConstDecay
+ from .k3.LearnRatePolynomialDecay import LearnRatePolynomialDecay
+else:
+ from .k2.LearnRateBaseScheduler import LearnRateBaseScheduler
+ from .k2.LearnRateCosineDecay import LearnRateCosineDecay
+ from .k2.LearnRateExpDecay import LearnRateExpDecay
+ from .k2.LearnRateInvTimeDecay import LearnRateInvTimeDecay
+ from .k2.LearnRatePiecewiseConstDecay import LearnRatePiecewiseConstDecay
+ from .k2.LearnRatePolynomialDecay import LearnRatePolynomialDecay
diff --git a/src/pidgan/callbacks/schedulers/LearnRateBaseScheduler.py b/src/pidgan/callbacks/schedulers/k2/LearnRateBaseScheduler.py
similarity index 77%
rename from src/pidgan/callbacks/schedulers/LearnRateBaseScheduler.py
rename to src/pidgan/callbacks/schedulers/k2/LearnRateBaseScheduler.py
index 606fcca..ee6ff53 100644
--- a/src/pidgan/callbacks/schedulers/LearnRateBaseScheduler.py
+++ b/src/pidgan/callbacks/schedulers/k2/LearnRateBaseScheduler.py
@@ -1,16 +1,16 @@
+import keras as k
import tensorflow as tf
-from tensorflow import keras
-K = keras.backend
+K = k.backend
-class LearnRateBaseScheduler(keras.callbacks.Callback):
+class LearnRateBaseScheduler(k.callbacks.Callback):
def __init__(self, optimizer, verbose=False, key="lr") -> None:
super().__init__()
self._name = "LearnRateBaseScheduler"
# Optimizer
- assert isinstance(optimizer, keras.optimizers.Optimizer)
+ assert isinstance(optimizer, k.optimizers.Optimizer)
self._optimizer = optimizer
# Verbose
@@ -25,13 +25,12 @@ def on_train_begin(self, logs=None) -> None:
init_lr = K.get_value(self._optimizer.learning_rate)
self._init_lr = tf.identity(init_lr)
self._dtype = self._init_lr.dtype
- self._step = tf.cast(-1.0, self._dtype)
+ self._step = tf.cast(0.0, self._dtype)
def on_batch_begin(self, batch, logs=None) -> None:
self._step += 1.0
- K.set_value(
- self._optimizer.learning_rate, self._scheduled_lr(self._init_lr, self._step)
- )
+ sched_lr = self._scheduled_lr(self._init_lr, self._step)
+ K.set_value(self._optimizer.learning_rate, sched_lr)
def _scheduled_lr(self, init_lr, step) -> tf.Tensor:
return init_lr
@@ -51,7 +50,7 @@ def name(self) -> str:
return self._name
@property
- def optimizer(self) -> keras.optimizers.Optimizer:
+ def optimizer(self) -> k.optimizers.Optimizer:
return self._optimizer
@property
diff --git a/src/pidgan/callbacks/schedulers/LearnRateCosineDecay.py b/src/pidgan/callbacks/schedulers/k2/LearnRateCosineDecay.py
similarity index 89%
rename from src/pidgan/callbacks/schedulers/LearnRateCosineDecay.py
rename to src/pidgan/callbacks/schedulers/k2/LearnRateCosineDecay.py
index 12586a4..ac69412 100644
--- a/src/pidgan/callbacks/schedulers/LearnRateCosineDecay.py
+++ b/src/pidgan/callbacks/schedulers/k2/LearnRateCosineDecay.py
@@ -1,7 +1,7 @@
import numpy as np
import tensorflow as tf
-from pidgan.callbacks.schedulers.LearnRateBaseScheduler import LearnRateBaseScheduler
+from pidgan.callbacks.schedulers.k2.LearnRateBaseScheduler import LearnRateBaseScheduler
class LearnRateCosineDecay(LearnRateBaseScheduler):
@@ -41,10 +41,9 @@ def on_train_begin(self, logs=None) -> None:
self._tf_alpha = tf.cast(self._alpha, self._dtype)
def _scheduled_lr(self, init_lr, step) -> tf.Tensor:
- step = tf.minimum(step, self._tf_decay_steps)
p = tf.divide(step, self._tf_decay_steps)
cosine_decay = 0.5 * (1 + tf.cos(tf.constant(np.pi) * p))
- decayed = tf.multiply(1 - self._tf_alpha, cosine_decay + self._tf_alpha)
+ decayed = tf.multiply(1 - self._tf_alpha, cosine_decay) + self._tf_alpha
sched_lr = tf.multiply(init_lr, decayed)
if self._min_learning_rate is not None:
return tf.maximum(sched_lr, self._min_learning_rate)
diff --git a/src/pidgan/callbacks/schedulers/LearnRateExpDecay.py b/src/pidgan/callbacks/schedulers/k2/LearnRateExpDecay.py
similarity index 95%
rename from src/pidgan/callbacks/schedulers/LearnRateExpDecay.py
rename to src/pidgan/callbacks/schedulers/k2/LearnRateExpDecay.py
index 2467f31..1c4ad04 100644
--- a/src/pidgan/callbacks/schedulers/LearnRateExpDecay.py
+++ b/src/pidgan/callbacks/schedulers/k2/LearnRateExpDecay.py
@@ -1,6 +1,6 @@
import tensorflow as tf
-from pidgan.callbacks.schedulers.LearnRateBaseScheduler import LearnRateBaseScheduler
+from pidgan.callbacks.schedulers.k2.LearnRateBaseScheduler import LearnRateBaseScheduler
class LearnRateExpDecay(LearnRateBaseScheduler):
diff --git a/src/pidgan/callbacks/schedulers/LearnRateInvTimeDecay.py b/src/pidgan/callbacks/schedulers/k2/LearnRateInvTimeDecay.py
similarity index 95%
rename from src/pidgan/callbacks/schedulers/LearnRateInvTimeDecay.py
rename to src/pidgan/callbacks/schedulers/k2/LearnRateInvTimeDecay.py
index b563e3a..c4a01b7 100644
--- a/src/pidgan/callbacks/schedulers/LearnRateInvTimeDecay.py
+++ b/src/pidgan/callbacks/schedulers/k2/LearnRateInvTimeDecay.py
@@ -1,6 +1,6 @@
import tensorflow as tf
-from pidgan.callbacks.schedulers.LearnRateBaseScheduler import LearnRateBaseScheduler
+from pidgan.callbacks.schedulers.k2.LearnRateBaseScheduler import LearnRateBaseScheduler
class LearnRateInvTimeDecay(LearnRateBaseScheduler):
diff --git a/src/pidgan/callbacks/schedulers/LearnRatePiecewiseConstDecay.py b/src/pidgan/callbacks/schedulers/k2/LearnRatePiecewiseConstDecay.py
similarity index 93%
rename from src/pidgan/callbacks/schedulers/LearnRatePiecewiseConstDecay.py
rename to src/pidgan/callbacks/schedulers/k2/LearnRatePiecewiseConstDecay.py
index 7df354b..d052977 100644
--- a/src/pidgan/callbacks/schedulers/LearnRatePiecewiseConstDecay.py
+++ b/src/pidgan/callbacks/schedulers/k2/LearnRatePiecewiseConstDecay.py
@@ -1,7 +1,7 @@
import numpy as np
import tensorflow as tf
-from pidgan.callbacks.schedulers.LearnRateBaseScheduler import LearnRateBaseScheduler
+from pidgan.callbacks.schedulers.k2.LearnRateBaseScheduler import LearnRateBaseScheduler
class LearnRatePiecewiseConstDecay(LearnRateBaseScheduler):
diff --git a/src/pidgan/callbacks/schedulers/LearnRatePolynomialDecay.py b/src/pidgan/callbacks/schedulers/k2/LearnRatePolynomialDecay.py
similarity index 95%
rename from src/pidgan/callbacks/schedulers/LearnRatePolynomialDecay.py
rename to src/pidgan/callbacks/schedulers/k2/LearnRatePolynomialDecay.py
index b3d3c52..9dc0167 100644
--- a/src/pidgan/callbacks/schedulers/LearnRatePolynomialDecay.py
+++ b/src/pidgan/callbacks/schedulers/k2/LearnRatePolynomialDecay.py
@@ -1,6 +1,6 @@
import tensorflow as tf
-from pidgan.callbacks.schedulers.LearnRateBaseScheduler import LearnRateBaseScheduler
+from pidgan.callbacks.schedulers.k2.LearnRateBaseScheduler import LearnRateBaseScheduler
class LearnRatePolynomialDecay(LearnRateBaseScheduler):
diff --git a/src/pidgan/callbacks/schedulers/k2/__init__.py b/src/pidgan/callbacks/schedulers/k2/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pidgan/callbacks/schedulers/k3/LearnRateBaseScheduler.py b/src/pidgan/callbacks/schedulers/k3/LearnRateBaseScheduler.py
new file mode 100644
index 0000000..fee4d89
--- /dev/null
+++ b/src/pidgan/callbacks/schedulers/k3/LearnRateBaseScheduler.py
@@ -0,0 +1,59 @@
+import keras as k
+import numpy as np
+
+
+class LearnRateBaseScheduler(k.callbacks.Callback):
+ def __init__(self, optimizer, verbose=False, key="lr") -> None:
+ super().__init__()
+ self._name = "LearnRateBaseScheduler"
+
+ # Optimizer
+ assert isinstance(optimizer, k.optimizers.Optimizer)
+ self._optimizer = optimizer
+
+ # Verbose
+ assert isinstance(verbose, bool)
+ self._verbose = verbose
+
+ # Key name
+ assert isinstance(key, str)
+ self._key = key
+
+ def on_train_begin(self, logs=None) -> None:
+ self._init_lr = k.ops.copy(self._optimizer.learning_rate)
+ self._dtype = self._init_lr.dtype
+ self._step = k.ops.cast(0.0, self._dtype)
+
+ def on_batch_begin(self, batch, logs=None) -> None:
+ self._step += 1.0
+ sched_lr = self._scheduled_lr(self._init_lr, self._step)
+ self._optimizer.learning_rate = sched_lr
+
+ def _scheduled_lr(self, init_lr, step):
+ return init_lr
+
+ def on_batch_end(self, batch, logs=None) -> None:
+ logs = logs or {}
+ if self._verbose:
+ logs[self._key] = float(np.array(self._optimizer.learning_rate))
+
+ def on_epoch_end(self, epoch, logs=None) -> None:
+ logs = logs or {}
+ if self._verbose:
+ logs[self._key] = float(np.array(self._optimizer.learning_rate))
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def optimizer(self) -> k.optimizers.Optimizer:
+ return self._optimizer
+
+ @property
+ def verbose(self) -> bool:
+ return self._verbose
+
+ @property
+ def key(self) -> str:
+ return self._key
diff --git a/src/pidgan/callbacks/schedulers/k3/LearnRateCosineDecay.py b/src/pidgan/callbacks/schedulers/k3/LearnRateCosineDecay.py
new file mode 100644
index 0000000..ecab74d
--- /dev/null
+++ b/src/pidgan/callbacks/schedulers/k3/LearnRateCosineDecay.py
@@ -0,0 +1,63 @@
+import math
+import keras as k
+
+from pidgan.callbacks.schedulers.k3.LearnRateBaseScheduler import LearnRateBaseScheduler
+
+
+class LearnRateCosineDecay(LearnRateBaseScheduler):
+ def __init__(
+ self,
+ optimizer,
+ decay_steps,
+ alpha=0.0,
+ min_learning_rate=None,
+ verbose=False,
+ key="lr",
+ ) -> None:
+ super().__init__(optimizer, verbose, key)
+ self._name = "LearnRateCosineDecay"
+
+ # Decay steps
+ assert isinstance(decay_steps, (int, float))
+ assert decay_steps >= 1
+ self._decay_steps = int(decay_steps)
+
+ # Alpha
+ assert isinstance(alpha, (int, float))
+ assert alpha >= 0.0 and alpha <= 1.0
+ self._alpha = float(alpha)
+
+ # Minimum learning-rate
+ if min_learning_rate is not None:
+ assert isinstance(min_learning_rate, (int, float))
+ assert min_learning_rate > 0.0
+ self._min_learning_rate = float(min_learning_rate)
+ else:
+ self._min_learning_rate = None
+
+ def on_train_begin(self, logs=None) -> None:
+ super().on_train_begin(logs=logs)
+ self._tf_decay_steps = k.ops.cast(self._decay_steps, self._dtype)
+ self._tf_alpha = k.ops.cast(self._alpha, self._dtype)
+
+ def _scheduled_lr(self, init_lr, step):
+ p = k.ops.divide(step, self._tf_decay_steps)
+ cosine_decay = 0.5 * (1 + k.ops.cos(math.pi * p))
+ decayed = k.ops.multiply(1 - self._tf_alpha, cosine_decay) + self._tf_alpha
+ sched_lr = k.ops.multiply(init_lr, decayed)
+ if self._min_learning_rate is not None:
+ return k.ops.maximum(sched_lr, self._min_learning_rate)
+ else:
+ return sched_lr
+
+ @property
+ def decay_steps(self) -> int:
+ return self._decay_steps
+
+ @property
+ def alpha(self) -> float:
+ return self._alpha
+
+ @property
+ def min_learning_rate(self) -> float:
+ return self._min_learning_rate
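Editor's note: a quick numeric check of the cosine-decay formula above (pure NumPy, illustrative values). Note that without the old `minimum(step, decay_steps)` clamp the cosine term is periodic beyond `decay_steps`, so `min_learning_rate` acts as the floor.

```python
import numpy as np

init_lr, decay_steps, alpha = 1e-3, 1000, 0.1
steps = np.array([0, 250, 500, 1000])

p = steps / decay_steps
cosine_decay = 0.5 * (1 + np.cos(np.pi * p))
decayed = (1 - alpha) * cosine_decay + alpha
print(init_lr * decayed)   # [1.00e-03, 8.68e-04, 5.50e-04, 1.00e-04]
```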
diff --git a/src/pidgan/callbacks/schedulers/k3/LearnRateExpDecay.py b/src/pidgan/callbacks/schedulers/k3/LearnRateExpDecay.py
new file mode 100644
index 0000000..e6f5569
--- /dev/null
+++ b/src/pidgan/callbacks/schedulers/k3/LearnRateExpDecay.py
@@ -0,0 +1,71 @@
+import keras as k
+
+from pidgan.callbacks.schedulers.k3.LearnRateBaseScheduler import LearnRateBaseScheduler
+
+
+class LearnRateExpDecay(LearnRateBaseScheduler):
+ def __init__(
+ self,
+ optimizer,
+ decay_rate,
+ decay_steps,
+ staircase=False,
+ min_learning_rate=None,
+ verbose=False,
+ key="lr",
+ ) -> None:
+ super().__init__(optimizer, verbose, key)
+ self._name = "LearnRateExpDecay"
+
+ # Decay rate
+ assert isinstance(decay_rate, (int, float))
+ assert decay_rate > 0.0
+ self._decay_rate = float(decay_rate)
+
+ # Decay steps
+ assert isinstance(decay_steps, (int, float))
+ assert decay_steps >= 1
+ self._decay_steps = int(decay_steps)
+
+ # Staircase
+ assert isinstance(staircase, bool)
+ self._staircase = staircase
+
+ # Minimum learning-rate
+ if min_learning_rate is not None:
+ assert isinstance(min_learning_rate, (int, float))
+ assert min_learning_rate > 0.0
+ self._min_learning_rate = float(min_learning_rate)
+ else:
+ self._min_learning_rate = None
+
+ def on_train_begin(self, logs=None) -> None:
+ super().on_train_begin(logs=logs)
+ self._tf_decay_rate = k.ops.cast(self._decay_rate, self._dtype)
+ self._tf_decay_steps = k.ops.cast(self._decay_steps, self._dtype)
+
+ def _scheduled_lr(self, init_lr, step):
+ p = k.ops.divide(step, self._tf_decay_steps)
+ if self._staircase:
+ p = k.ops.floor(p)
+ sched_lr = k.ops.multiply(init_lr, k.ops.power(self._tf_decay_rate, p))
+ if self._min_learning_rate is not None:
+ return k.ops.maximum(sched_lr, self._min_learning_rate)
+ else:
+ return sched_lr
+
+ @property
+ def decay_rate(self) -> float:
+ return self._decay_rate
+
+ @property
+ def decay_steps(self) -> int:
+ return self._decay_steps
+
+ @property
+ def staircase(self) -> bool:
+ return self._staircase
+
+ @property
+ def min_learning_rate(self) -> float:
+ return self._min_learning_rate
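Editor's note: a hand check of the exponential schedule above, with and without staircase mode (illustrative values).

```python
import numpy as np

init_lr, decay_rate, decay_steps, step = 1e-3, 0.9, 100, 250

p = step / decay_steps
print(init_lr * decay_rate ** p)             # smooth decay: ~7.68e-04
print(init_lr * decay_rate ** np.floor(p))   # staircase:     8.10e-04
```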
diff --git a/src/pidgan/callbacks/schedulers/k3/LearnRateInvTimeDecay.py b/src/pidgan/callbacks/schedulers/k3/LearnRateInvTimeDecay.py
new file mode 100644
index 0000000..59b8155
--- /dev/null
+++ b/src/pidgan/callbacks/schedulers/k3/LearnRateInvTimeDecay.py
@@ -0,0 +1,71 @@
+import keras as k
+
+from pidgan.callbacks.schedulers.k3.LearnRateBaseScheduler import LearnRateBaseScheduler
+
+
+class LearnRateInvTimeDecay(LearnRateBaseScheduler):
+ def __init__(
+ self,
+ optimizer,
+ decay_rate,
+ decay_steps,
+ staircase=False,
+ min_learning_rate=None,
+ verbose=False,
+ key="lr",
+ ) -> None:
+ super().__init__(optimizer, verbose, key)
+ self._name = "LearnRateInvTimeDecay"
+
+ # Decay rate
+ assert isinstance(decay_rate, (int, float))
+ assert decay_rate > 0.0
+ self._decay_rate = float(decay_rate)
+
+ # Decay steps
+ assert isinstance(decay_steps, (int, float))
+ assert decay_steps >= 1
+ self._decay_steps = int(decay_steps)
+
+ # Staircase
+ assert isinstance(staircase, bool)
+ self._staircase = staircase
+
+ # Minimum learning-rate
+ if min_learning_rate is not None:
+ assert isinstance(min_learning_rate, (int, float))
+ assert min_learning_rate > 0.0
+ self._min_learning_rate = float(min_learning_rate)
+ else:
+ self._min_learning_rate = None
+
+ def on_train_begin(self, logs=None) -> None:
+ super().on_train_begin(logs=logs)
+ self._tf_decay_rate = k.ops.cast(self._decay_rate, self._dtype)
+ self._tf_decay_steps = k.ops.cast(self._decay_steps, self._dtype)
+
+ def _scheduled_lr(self, init_lr, step):
+ p = k.ops.divide(step, self._tf_decay_steps)
+ if self._staircase:
+ p = k.ops.floor(p)
+ sched_lr = k.ops.divide(init_lr, 1 + k.ops.multiply(self._tf_decay_rate, p))
+ if self._min_learning_rate is not None:
+ return k.ops.maximum(sched_lr, self._min_learning_rate)
+ else:
+ return sched_lr
+
+ @property
+ def decay_rate(self) -> float:
+ return self._decay_rate
+
+ @property
+ def decay_steps(self) -> int:
+ return self._decay_steps
+
+ @property
+ def staircase(self) -> bool:
+ return self._staircase
+
+ @property
+ def min_learning_rate(self) -> float:
+ return self._min_learning_rate
diff --git a/src/pidgan/callbacks/schedulers/k3/LearnRatePiecewiseConstDecay.py b/src/pidgan/callbacks/schedulers/k3/LearnRatePiecewiseConstDecay.py
new file mode 100644
index 0000000..6afbbd8
--- /dev/null
+++ b/src/pidgan/callbacks/schedulers/k3/LearnRatePiecewiseConstDecay.py
@@ -0,0 +1,38 @@
+import keras as k
+import numpy as np
+
+from pidgan.callbacks.schedulers.k3.LearnRateBaseScheduler import LearnRateBaseScheduler
+
+
+class LearnRatePiecewiseConstDecay(LearnRateBaseScheduler):
+ def __init__(self, optimizer, boundaries, values, verbose=False, key="lr") -> None:
+ super().__init__(optimizer, verbose, key)
+ self._name = "LearnRatePiecewiseConstDecay"
+
+ # Boundaries and values
+ assert isinstance(boundaries, (list, tuple, np.ndarray))
+ assert isinstance(values, (list, tuple, np.ndarray))
+ assert len(boundaries) >= 1
+ assert len(values) >= 2
+ assert len(boundaries) == len(values) - 1
+ self._boundaries = [0] + [int(b) for b in boundaries]
+ self._values = [float(v) for v in values]
+
+ def on_train_begin(self, logs=None) -> None:
+ super().on_train_begin(logs=logs)
+ self._tf_boundaries = k.ops.cast(self._boundaries, self._dtype)
+ self._tf_values = k.ops.cast(self._values, self._dtype)
+
+ def _scheduled_lr(self, init_lr, step):
+ for i in range(len(self._boundaries) - 1):
+ if (step >= self._tf_boundaries[i]) and (step < self._tf_boundaries[i + 1]):
+ return self._tf_values[i]
+ return self._tf_values[-1]
+
+ @property
+ def boundaries(self) -> list:
+ return self._boundaries
+
+ @property
+ def values(self) -> list:
+ return self._values
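Editor's note: the scheduler above prepends 0 to the user-given boundaries; the plain-Python sketch below shows the resulting step-to-rate mapping for illustrative values.

```python
boundaries = [0, 100, 300]        # internal boundaries for boundaries=[100, 300]
values = [1e-3, 1e-4, 1e-5]

def piecewise_lr(step):
    for i in range(len(boundaries) - 1):
        if boundaries[i] <= step < boundaries[i + 1]:
            return values[i]
    return values[-1]

print([piecewise_lr(s) for s in (0, 99, 100, 299, 300, 1000)])
# [0.001, 0.001, 0.0001, 0.0001, 1e-05, 1e-05]
```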
diff --git a/src/pidgan/callbacks/schedulers/k3/LearnRatePolynomialDecay.py b/src/pidgan/callbacks/schedulers/k3/LearnRatePolynomialDecay.py
new file mode 100644
index 0000000..67cea51
--- /dev/null
+++ b/src/pidgan/callbacks/schedulers/k3/LearnRatePolynomialDecay.py
@@ -0,0 +1,73 @@
+import keras as k
+
+from pidgan.callbacks.schedulers.k3.LearnRateBaseScheduler import LearnRateBaseScheduler
+
+
+class LearnRatePolynomialDecay(LearnRateBaseScheduler):
+ def __init__(
+ self,
+ optimizer,
+ decay_steps,
+ end_learning_rate=0.0001,
+ power=1.0,
+ cycle=False,
+ verbose=False,
+ key="lr",
+ ) -> None:
+ super().__init__(optimizer, verbose, key)
+ self._name = "LearnRatePolynomialDecay"
+
+ # Decay steps
+ assert isinstance(decay_steps, (int, float))
+ assert decay_steps >= 1
+ self._decay_steps = int(decay_steps)
+
+ # End learning-rate
+ assert isinstance(end_learning_rate, (int, float))
+ assert end_learning_rate > 0.0
+ self._end_learning_rate = float(end_learning_rate)
+
+ # Power
+ assert isinstance(power, (int, float))
+ assert power > 0.0
+ self._power = float(power)
+
+ # Cycle
+ assert isinstance(cycle, bool)
+ self._cycle = cycle
+
+ def on_train_begin(self, logs=None) -> None:
+ super().on_train_begin(logs=logs)
+ self._tf_decay_steps = k.ops.cast(self._decay_steps, self._dtype)
+ self._tf_end_learning_rate = k.ops.cast(self._end_learning_rate, self._dtype)
+ self._tf_power = k.ops.cast(self._power, self._dtype)
+
+ def _scheduled_lr(self, init_lr, step):
+ if not self._cycle:
+ step = k.ops.minimum(step, self._tf_decay_steps)
+ decay_steps = self._tf_decay_steps
+ else:
+ decay_steps = k.ops.multiply(
+ self._tf_decay_steps,
+ k.ops.ceil(k.ops.divide(step, self._tf_decay_steps)),
+ )
+ return (
+ (init_lr - self._tf_end_learning_rate)
+ * k.ops.power(1 - step / decay_steps, self._tf_power)
+ ) + self._tf_end_learning_rate
+
+ @property
+ def decay_steps(self) -> int:
+ return self._decay_steps
+
+ @property
+ def end_learning_rate(self) -> float:
+ return self._end_learning_rate
+
+ @property
+ def power(self) -> float:
+ return self._power
+
+ @property
+ def cycle(self) -> bool:
+ return self._cycle
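Editor's note: a small sketch of the polynomial schedule above, highlighting what `cycle=True` changes (illustrative values; not the packaged callback).

```python
import numpy as np

init_lr, end_lr, power, decay_steps = 1e-3, 1e-4, 1.0, 100

def poly_lr(step, cycle=False):
    if not cycle:
        step = min(step, decay_steps)
        dsteps = decay_steps
    else:
        dsteps = decay_steps * np.ceil(step / decay_steps)
    return (init_lr - end_lr) * (1 - step / dsteps) ** power + end_lr

print(poly_lr(150))              # clamped at end_lr: 1.00e-04
print(poly_lr(150, cycle=True))  # decay window restarted over 200 steps: 3.25e-04
```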
diff --git a/src/pidgan/callbacks/schedulers/k3/__init__.py b/src/pidgan/callbacks/schedulers/k3/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pidgan/metrics/__init__.py b/src/pidgan/metrics/__init__.py
index e91e45f..c5e0f33 100644
--- a/src/pidgan/metrics/__init__.py
+++ b/src/pidgan/metrics/__init__.py
@@ -1,8 +1,24 @@
-from .Accuracy import Accuracy
-from .BinaryCrossentropy import BinaryCrossentropy
-from .JSDivergence import JSDivergence
-from .KLDivergence import KLDivergence
-from .MeanAbsoluteError import MeanAbsoluteError
-from .MeanSquaredError import MeanSquaredError
-from .RootMeanSquaredError import RootMeanSquaredError
-from .WassersteinDistance import WassersteinDistance
+import keras as k
+
+v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+
+if v_major == 3 and v_minor >= 0:
+ from .k3.BaseMetric import BaseMetric
+ from .k3.Accuracy import Accuracy
+ from .k3.BinaryCrossentropy import BinaryCrossentropy
+ from .k3.JSDivergence import JSDivergence
+ from .k3.KLDivergence import KLDivergence
+ from .k3.MeanAbsoluteError import MeanAbsoluteError
+ from .k3.MeanSquaredError import MeanSquaredError
+ from .k3.RootMeanSquaredError import RootMeanSquaredError
+ from .k3.WassersteinDistance import WassersteinDistance
+else:
+ from .k2.BaseMetric import BaseMetric
+ from .k2.Accuracy import Accuracy
+ from .k2.BinaryCrossentropy import BinaryCrossentropy
+ from .k2.JSDivergence import JSDivergence
+ from .k2.KLDivergence import KLDivergence
+ from .k2.MeanAbsoluteError import MeanAbsoluteError
+ from .k2.MeanSquaredError import MeanSquaredError
+ from .k2.RootMeanSquaredError import RootMeanSquaredError
+ from .k2.WassersteinDistance import WassersteinDistance
diff --git a/src/pidgan/metrics/Accuracy.py b/src/pidgan/metrics/k2/Accuracy.py
similarity index 80%
rename from src/pidgan/metrics/Accuracy.py
rename to src/pidgan/metrics/k2/Accuracy.py
index a74ecbe..896b98d 100644
--- a/src/pidgan/metrics/Accuracy.py
+++ b/src/pidgan/metrics/k2/Accuracy.py
@@ -1,12 +1,12 @@
import tensorflow as tf
-from tensorflow import keras
+import keras
-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric
class Accuracy(BaseMetric):
def __init__(self, name="accuracy", dtype=None, threshold=0.5) -> None:
- super().__init__(name, dtype)
+ super().__init__(name=name, dtype=dtype)
self._accuracy = keras.metrics.BinaryAccuracy(
name=name, dtype=dtype, threshold=threshold
)
diff --git a/src/pidgan/metrics/BaseMetric.py b/src/pidgan/metrics/k2/BaseMetric.py
similarity index 62%
rename from src/pidgan/metrics/BaseMetric.py
rename to src/pidgan/metrics/k2/BaseMetric.py
index d79c7b9..b80ee02 100644
--- a/src/pidgan/metrics/BaseMetric.py
+++ b/src/pidgan/metrics/k2/BaseMetric.py
@@ -1,18 +1,21 @@
-from tensorflow import keras
+import keras as k
-class BaseMetric(keras.metrics.Metric):
+class BaseMetric(k.metrics.Metric):
def __init__(self, name="metric", dtype=None) -> None:
- super().__init__(name, dtype)
+ super().__init__(name=name, dtype=dtype)
self._metric_values = self.add_weight(
name=f"{name}_values", initializer="zeros"
)
def update_state(self, y_true, y_pred, sample_weight=None) -> None:
raise NotImplementedError(
- "Only `BaseMetric` subclasses have the "
+ "Only the pidgan's BaseMetric subclasses have the "
"`update_state()` method implemented."
)
def result(self):
return self._metric_values
+
+ def reset_state(self):
+ self._metric_values.assign(0.0)
diff --git a/src/pidgan/metrics/BinaryCrossentropy.py b/src/pidgan/metrics/k2/BinaryCrossentropy.py
similarity index 83%
rename from src/pidgan/metrics/BinaryCrossentropy.py
rename to src/pidgan/metrics/k2/BinaryCrossentropy.py
index 9a514f6..4a6d2ea 100644
--- a/src/pidgan/metrics/BinaryCrossentropy.py
+++ b/src/pidgan/metrics/k2/BinaryCrossentropy.py
@@ -1,14 +1,14 @@
import tensorflow as tf
-from tensorflow import keras
+import keras
-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric
class BinaryCrossentropy(BaseMetric):
def __init__(
self, name="bce", dtype=None, from_logits=False, label_smoothing=0.0
) -> None:
- super().__init__(name, dtype)
+ super().__init__(name=name, dtype=dtype)
self._bce = keras.metrics.BinaryCrossentropy(
name=name,
dtype=dtype,
diff --git a/src/pidgan/metrics/JSDivergence.py b/src/pidgan/metrics/k2/JSDivergence.py
similarity index 85%
rename from src/pidgan/metrics/JSDivergence.py
rename to src/pidgan/metrics/k2/JSDivergence.py
index b53e2a7..a6d2827 100644
--- a/src/pidgan/metrics/JSDivergence.py
+++ b/src/pidgan/metrics/k2/JSDivergence.py
@@ -1,12 +1,12 @@
import tensorflow as tf
-from tensorflow import keras
+import keras
-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric
class JSDivergence(BaseMetric):
def __init__(self, name="js_div", dtype=None) -> None:
- super().__init__(name, dtype)
+ super().__init__(name=name, dtype=dtype)
self._kl_div = keras.metrics.KLDivergence(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None) -> None:
diff --git a/src/pidgan/metrics/KLDivergence.py b/src/pidgan/metrics/k2/KLDivergence.py
similarity index 75%
rename from src/pidgan/metrics/KLDivergence.py
rename to src/pidgan/metrics/k2/KLDivergence.py
index 91180a6..e60198e 100644
--- a/src/pidgan/metrics/KLDivergence.py
+++ b/src/pidgan/metrics/k2/KLDivergence.py
@@ -1,11 +1,11 @@
-from tensorflow import keras
+import keras
-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric
class KLDivergence(BaseMetric):
def __init__(self, name="kl_div", dtype=None) -> None:
- super().__init__(name, dtype)
+ super().__init__(name=name, dtype=dtype)
self._kl_div = keras.metrics.KLDivergence(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None) -> None:
diff --git a/src/pidgan/metrics/MeanAbsoluteError.py b/src/pidgan/metrics/k2/MeanAbsoluteError.py
similarity index 61%
rename from src/pidgan/metrics/MeanAbsoluteError.py
rename to src/pidgan/metrics/k2/MeanAbsoluteError.py
index c9a5437..ba357a0 100644
--- a/src/pidgan/metrics/MeanAbsoluteError.py
+++ b/src/pidgan/metrics/k2/MeanAbsoluteError.py
@@ -1,11 +1,11 @@
-from tensorflow import keras
+import keras
-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric
class MeanAbsoluteError(BaseMetric):
- def __init__(self, name="mae", dtype=None, **kwargs) -> None:
- super().__init__(name, dtype, **kwargs)
+ def __init__(self, name="mae", dtype=None) -> None:
+ super().__init__(name=name, dtype=dtype)
self._mae = keras.metrics.MeanAbsoluteError(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None) -> None:
diff --git a/src/pidgan/metrics/MeanSquaredError.py b/src/pidgan/metrics/k2/MeanSquaredError.py
similarity index 60%
rename from src/pidgan/metrics/MeanSquaredError.py
rename to src/pidgan/metrics/k2/MeanSquaredError.py
index 9f6e484..7d18bef 100644
--- a/src/pidgan/metrics/MeanSquaredError.py
+++ b/src/pidgan/metrics/k2/MeanSquaredError.py
@@ -1,11 +1,11 @@
-from tensorflow import keras
+import keras
-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric
class MeanSquaredError(BaseMetric):
- def __init__(self, name="mse", dtype=None, **kwargs) -> None:
- super().__init__(name, dtype, **kwargs)
+ def __init__(self, name="mse", dtype=None) -> None:
+ super().__init__(name=name, dtype=dtype)
self._mse = keras.metrics.MeanSquaredError(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None) -> None:
diff --git a/src/pidgan/metrics/RootMeanSquaredError.py b/src/pidgan/metrics/k2/RootMeanSquaredError.py
similarity index 62%
rename from src/pidgan/metrics/RootMeanSquaredError.py
rename to src/pidgan/metrics/k2/RootMeanSquaredError.py
index 0f27e59..6fd238d 100644
--- a/src/pidgan/metrics/RootMeanSquaredError.py
+++ b/src/pidgan/metrics/k2/RootMeanSquaredError.py
@@ -1,11 +1,11 @@
-from tensorflow import keras
+import keras
-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric
class RootMeanSquaredError(BaseMetric):
- def __init__(self, name="rmse", dtype=None, **kwargs):
- super().__init__(name, dtype, **kwargs)
+ def __init__(self, name="rmse", dtype=None):
+ super().__init__(name=name, dtype=dtype)
self._rmse = keras.metrics.RootMeanSquaredError(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None):
diff --git a/src/pidgan/metrics/WassersteinDistance.py b/src/pidgan/metrics/k2/WassersteinDistance.py
similarity index 83%
rename from src/pidgan/metrics/WassersteinDistance.py
rename to src/pidgan/metrics/k2/WassersteinDistance.py
index 51f85d6..7bbcc4d 100644
--- a/src/pidgan/metrics/WassersteinDistance.py
+++ b/src/pidgan/metrics/k2/WassersteinDistance.py
@@ -1,11 +1,11 @@
import tensorflow as tf
-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric
class WassersteinDistance(BaseMetric):
def __init__(self, name="wass_dist", dtype=None) -> None:
- super().__init__(name, dtype)
+ super().__init__(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None) -> None:
if sample_weight is not None:
diff --git a/src/pidgan/metrics/k2/__init__.py b/src/pidgan/metrics/k2/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pidgan/metrics/k3/Accuracy.py b/src/pidgan/metrics/k3/Accuracy.py
new file mode 100644
index 0000000..e3069bf
--- /dev/null
+++ b/src/pidgan/metrics/k3/Accuracy.py
@@ -0,0 +1,17 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class Accuracy(BaseMetric):
+ def __init__(self, name="accuracy", dtype=None, threshold=0.5) -> None:
+ super().__init__(name=name, dtype=dtype)
+ self._accuracy = k.metrics.BinaryAccuracy(
+ name=name, dtype=dtype, threshold=threshold
+ )
+
+ def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+ state = self._accuracy(
+ k.ops.ones_like(y_pred), y_pred, sample_weight=sample_weight
+ )
+ self._metric_values.assign(state)
diff --git a/src/pidgan/metrics/k3/BaseMetric.py b/src/pidgan/metrics/k3/BaseMetric.py
new file mode 100644
index 0000000..508f2a2
--- /dev/null
+++ b/src/pidgan/metrics/k3/BaseMetric.py
@@ -0,0 +1,21 @@
+import keras as k
+
+
+class BaseMetric(k.metrics.Metric):
+ def __init__(self, name="metric", dtype=None) -> None:
+ super().__init__(name=name, dtype=dtype)
+ self._metric_values = self.add_weight(
+ name=f"{name}_values", initializer="zeros"
+ )
+
+ def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+ raise NotImplementedError(
+ "Only the pidgan's BaseMetric subclasses have the "
+ "`update_state()` method implemented."
+ )
+
+ def result(self):
+ return self._metric_values.value
+
+ def reset_state(self):
+ self._metric_values.assign(0.0)
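Editor's note: the pattern expected from `BaseMetric` subclasses is "compute a scalar, assign it to the single weight"; the hypothetical `MeanError` metric below illustrates it and is not part of pidgan.

```python
import keras as k

from pidgan.metrics import BaseMetric

class MeanError(BaseMetric):
    """Hypothetical metric, shown only to illustrate the subclass pattern."""

    def __init__(self, name="mean_err", dtype=None) -> None:
        super().__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None) -> None:
        state = k.ops.mean(y_pred - y_true)
        self._metric_values.assign(k.ops.cast(state, self.dtype))
```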
diff --git a/src/pidgan/metrics/k3/BinaryCrossentropy.py b/src/pidgan/metrics/k3/BinaryCrossentropy.py
new file mode 100644
index 0000000..9c49db3
--- /dev/null
+++ b/src/pidgan/metrics/k3/BinaryCrossentropy.py
@@ -0,0 +1,20 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class BinaryCrossentropy(BaseMetric):
+ def __init__(
+ self, name="bce", dtype=None, from_logits=False, label_smoothing=0.0
+ ) -> None:
+ super().__init__(name=name, dtype=dtype)
+ self._bce = k.metrics.BinaryCrossentropy(
+ name=name,
+ dtype=dtype,
+ from_logits=from_logits,
+ label_smoothing=label_smoothing,
+ )
+
+ def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+ state = self._bce(k.ops.ones_like(y_pred), y_pred, sample_weight=sample_weight)
+ self._metric_values.assign(state)
diff --git a/src/pidgan/metrics/k3/JSDivergence.py b/src/pidgan/metrics/k3/JSDivergence.py
new file mode 100644
index 0000000..4404e09
--- /dev/null
+++ b/src/pidgan/metrics/k3/JSDivergence.py
@@ -0,0 +1,21 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class JSDivergence(BaseMetric):
+ def __init__(self, name="js_div", dtype=None) -> None:
+ super().__init__(name=name, dtype=dtype)
+ self._kl_div = k.metrics.KLDivergence(name=name, dtype=dtype)
+
+ def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+ dtype = self._kl_div(y_true, y_pred).dtype
+ y_true = k.ops.cast(y_true, dtype)
+ y_pred = k.ops.cast(y_pred, dtype)
+
+ state = 0.5 * self._kl_div(
+ y_true, 0.5 * (y_true + y_pred), sample_weight=sample_weight
+ ) + 0.5 * self._kl_div(
+ y_pred, 0.5 * (y_true + y_pred), sample_weight=sample_weight
+ )
+ self._metric_values.assign(state)
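Editor's note: the metric above estimates the Jensen-Shannon divergence as the average of two KL terms against the mixture distribution; a NumPy cross-check with illustrative distributions:

```python
import numpy as np

p = np.array([0.7, 0.2, 0.1])
q = np.array([0.5, 0.3, 0.2])
m = 0.5 * (p + q)

kl = lambda a, b: np.sum(a * np.log(a / b))
print(0.5 * kl(p, m) + 0.5 * kl(q, m))   # ~0.022, bounded above by ln(2)
```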
diff --git a/src/pidgan/metrics/k3/KLDivergence.py b/src/pidgan/metrics/k3/KLDivergence.py
new file mode 100644
index 0000000..a4c62be
--- /dev/null
+++ b/src/pidgan/metrics/k3/KLDivergence.py
@@ -0,0 +1,13 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class KLDivergence(BaseMetric):
+ def __init__(self, name="kl_div", dtype=None) -> None:
+ super().__init__(name=name, dtype=dtype)
+ self._kl_div = k.metrics.KLDivergence(name=name, dtype=dtype)
+
+ def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+ state = self._kl_div(y_true, y_pred, sample_weight=sample_weight)
+ self._metric_values.assign(state)
diff --git a/src/pidgan/metrics/k3/MeanAbsoluteError.py b/src/pidgan/metrics/k3/MeanAbsoluteError.py
new file mode 100644
index 0000000..1ba7405
--- /dev/null
+++ b/src/pidgan/metrics/k3/MeanAbsoluteError.py
@@ -0,0 +1,13 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class MeanAbsoluteError(BaseMetric):
+ def __init__(self, name="mae", dtype=None) -> None:
+ super().__init__(name=name, dtype=dtype)
+ self._mae = k.metrics.MeanAbsoluteError(name=name, dtype=dtype)
+
+ def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+ state = self._mae(y_true, y_pred, sample_weight=sample_weight)
+ self._metric_values.assign(state)
diff --git a/src/pidgan/metrics/k3/MeanSquaredError.py b/src/pidgan/metrics/k3/MeanSquaredError.py
new file mode 100644
index 0000000..49c8d84
--- /dev/null
+++ b/src/pidgan/metrics/k3/MeanSquaredError.py
@@ -0,0 +1,13 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class MeanSquaredError(BaseMetric):
+ def __init__(self, name="mse", dtype=None) -> None:
+ super().__init__(name=name, dtype=dtype)
+ self._mse = k.metrics.MeanSquaredError(name=name, dtype=dtype)
+
+ def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+ state = self._mse(y_true, y_pred, sample_weight=sample_weight)
+ self._metric_values.assign(state)
diff --git a/src/pidgan/metrics/k3/RootMeanSquaredError.py b/src/pidgan/metrics/k3/RootMeanSquaredError.py
new file mode 100644
index 0000000..789877a
--- /dev/null
+++ b/src/pidgan/metrics/k3/RootMeanSquaredError.py
@@ -0,0 +1,13 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class RootMeanSquaredError(BaseMetric):
+ def __init__(self, name="rmse", dtype=None):
+ super().__init__(name=name, dtype=dtype)
+ self._rmse = k.metrics.RootMeanSquaredError(name=name, dtype=dtype)
+
+ def update_state(self, y_true, y_pred, sample_weight=None):
+ state = self._rmse(y_true, y_pred, sample_weight=sample_weight)
+ self._metric_values.assign(state)
diff --git a/src/pidgan/metrics/k3/WassersteinDistance.py b/src/pidgan/metrics/k3/WassersteinDistance.py
new file mode 100644
index 0000000..98dc95d
--- /dev/null
+++ b/src/pidgan/metrics/k3/WassersteinDistance.py
@@ -0,0 +1,17 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class WassersteinDistance(BaseMetric):
+ def __init__(self, name="wass_dist", dtype=None) -> None:
+ super().__init__(name=name, dtype=dtype)
+
+ def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+ if sample_weight is not None:
+ state = k.ops.sum(sample_weight * (y_pred - y_true))
+ state /= k.ops.sum(sample_weight)
+ else:
+ state = k.ops.mean(y_pred - y_true)
+ state = k.ops.cast(state, self.dtype)
+ print("debug:", self.dtype)
+ self._metric_values.assign(state)
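Editor's note: a tiny NumPy sketch of the (optionally weighted) mean-difference estimate implemented above, with illustrative critic outputs.

```python
import numpy as np

y_true = np.array([0.2, 0.4, 0.6])   # critic outputs on reference samples
y_pred = np.array([0.5, 0.1, 0.9])   # critic outputs on generated samples
w = np.array([1.0, 3.0, 1.0])

print(np.mean(y_pred - y_true))                    # unweighted estimate:  0.1
print(np.sum(w * (y_pred - y_true)) / np.sum(w))   # weighted estimate:   -0.06
```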
diff --git a/src/pidgan/metrics/k3/__init__.py b/src/pidgan/metrics/k3/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pidgan/optimization/callbacks/HopaasPruner.py b/src/pidgan/optimization/callbacks/HopaasPruner.py
index 87639e9..e82b3f4 100644
--- a/src/pidgan/optimization/callbacks/HopaasPruner.py
+++ b/src/pidgan/optimization/callbacks/HopaasPruner.py
@@ -1,7 +1,7 @@
-from tensorflow import keras
+import keras as k
-class HopaasPruner(keras.callbacks.Callback):
+class HopaasPruner(k.callbacks.Callback):
def __init__(
self, trial, loss_name, report_frequency=1, enable_pruning=True
) -> None:
diff --git a/src/pidgan/players/classifiers/AuxClassifier.py b/src/pidgan/players/classifiers/AuxClassifier.py
index 647582d..7983845 100644
--- a/src/pidgan/players/classifiers/AuxClassifier.py
+++ b/src/pidgan/players/classifiers/AuxClassifier.py
@@ -35,9 +35,21 @@ def __init__(
# Kernel regularizer
self._hidden_kernel_reg = mlp_hidden_kernel_regularizer
+ def _get_input_dim(self, input_shape) -> int:
+ if isinstance(input_shape, (list, tuple)):
+ in_shape_1, in_shape_2 = input_shape
+ if isinstance(in_shape_2, int):
+ in_dim = in_shape_2
+ else:
+ in_dim = in_shape_1[-1] + in_shape_2[-1]
+ else:
+ in_dim = input_shape[-1] # after concatenate action
+ in_dim += len(self._aux_features)
+ return in_dim
+
def hidden_feature(self, x, return_hidden_idx=False):
raise NotImplementedError(
- "Only the `discriminators` family has the "
+ "Only the pidgan's Discriminators has the "
"`hidden_feature()` method implemented."
)
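Editor's note: the `_get_input_dim` helpers added to the classifiers resolve the MLP input width from either an `(x, y)` pair of shapes or a single already-concatenated shape, plus one column per auxiliary feature; the standalone sketch below mirrors that logic with illustrative shapes.

```python
def get_input_dim(input_shape, n_aux_features=0):
    # Mirrors the helper added above; not the packaged implementation.
    if isinstance(input_shape, (list, tuple)):
        in_shape_1, in_shape_2 = input_shape
        if isinstance(in_shape_2, int):
            in_dim = in_shape_2
        else:
            in_dim = in_shape_1[-1] + in_shape_2[-1]
    else:
        in_dim = input_shape[-1]
    return in_dim + n_aux_features

print(get_input_dim(((None, 4), (None, 6)), n_aux_features=2))  # (x, y) pair -> 12
print(get_input_dim((None, 10), n_aux_features=2))              # single (batch, features) shape -> 12
```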
diff --git a/src/pidgan/players/classifiers/AuxMultiClassifier.py b/src/pidgan/players/classifiers/AuxMultiClassifier.py
index 0e4bade..d07ea76 100644
--- a/src/pidgan/players/classifiers/AuxMultiClassifier.py
+++ b/src/pidgan/players/classifiers/AuxMultiClassifier.py
@@ -1,4 +1,4 @@
-from tensorflow import keras
+import keras as k
from pidgan.players.discriminators import AuxDiscriminator
@@ -37,15 +37,25 @@ def __init__(
# Kernel regularizer
self._hidden_kernel_reg = mlp_hidden_kernel_regularizer
+ def _get_input_dim(self, input_shape) -> int:
+ if isinstance(input_shape, (list, tuple)):
+ in_shape_1, in_shape_2 = input_shape
+ if isinstance(in_shape_2, int):
+ in_dim = in_shape_2
+ else:
+ in_dim = in_shape_1[-1] + in_shape_2[-1]
+ else:
+ in_dim = input_shape[-1] # after concatenate action
+ in_dim += len(self._aux_features)
+ return in_dim
+
def _define_arch(self) -> None:
super()._define_arch()
- self._out.append(
- keras.layers.Softmax(name="softmax_out" if self.name else None)
- )
+ self._out.append(k.layers.Softmax(name="softmax_out" if self.name else None))
def hidden_feature(self, x, return_hidden_idx=False):
raise NotImplementedError(
- "Only the `discriminators` family has the "
+ "Only the pidgan's Discriminators has the "
"`hidden_feature()` method implemented."
)
diff --git a/src/pidgan/players/classifiers/Classifier.py b/src/pidgan/players/classifiers/Classifier.py
index 6531d90..5ca2ba6 100644
--- a/src/pidgan/players/classifiers/Classifier.py
+++ b/src/pidgan/players/classifiers/Classifier.py
@@ -31,9 +31,19 @@ def __init__(
# Kernel regularizer
self._hidden_kernel_reg = mlp_hidden_kernel_regularizer
+ def _get_input_dim(self, input_shape) -> int:
+ if isinstance(input_shape, (list, tuple)):
+ in_shape_1, in_shape_2 = input_shape
+ if isinstance(in_shape_2, int):
+ return in_shape_2
+ else:
+ return in_shape_1[-1] + in_shape_2[-1]
+ else:
+ return input_shape[-1] # after concatenate action
+
def hidden_feature(self, x, return_hidden_idx=False):
raise NotImplementedError(
- "Only the `discriminators` family has the "
+ "Only the pidgan's Discriminators has the "
"`hidden_feature()` method implemented."
)
diff --git a/src/pidgan/players/classifiers/MultiClassifier.py b/src/pidgan/players/classifiers/MultiClassifier.py
index e3f5812..d3418fe 100644
--- a/src/pidgan/players/classifiers/MultiClassifier.py
+++ b/src/pidgan/players/classifiers/MultiClassifier.py
@@ -1,4 +1,4 @@
-from tensorflow import keras
+import keras as k
from pidgan.players.discriminators import Discriminator
@@ -33,14 +33,23 @@ def __init__(
# Kernel regularizer
self._hidden_kernel_reg = mlp_hidden_kernel_regularizer
- def _define_arch(self) -> keras.Sequential:
- model = super()._define_arch()
- model.add(keras.layers.Softmax(name="softmax_out" if self.name else None))
- return model
+ def _get_input_dim(self, input_shape) -> int:
+ if isinstance(input_shape, (list, tuple)):
+ in_shape_1, in_shape_2 = input_shape
+ if isinstance(in_shape_2, int):
+ return in_shape_2
+ else:
+ return in_shape_1[-1] + in_shape_2[-1]
+ else:
+ return input_shape[-1] # after concatenate action
+
+ def build(self, input_shape) -> None:
+ super().build(input_shape=input_shape)
+ self._model.add(k.layers.Softmax(name="softmax_out" if self.name else None))
def hidden_feature(self, x, return_hidden_idx=False):
raise NotImplementedError(
- "Only the `discriminators` family has the "
+ "Only the pidgan's Discriminators has the "
"`hidden_feature()` method implemented."
)
diff --git a/src/pidgan/players/classifiers/ResClassifier.py b/src/pidgan/players/classifiers/ResClassifier.py
index 01344a7..66d6807 100644
--- a/src/pidgan/players/classifiers/ResClassifier.py
+++ b/src/pidgan/players/classifiers/ResClassifier.py
@@ -31,9 +31,19 @@ def __init__(
# Kernel regularizer
self._hidden_kernel_reg = mlp_hidden_kernel_regularizer
+ def _get_input_dim(self, input_shape) -> int:
+ if isinstance(input_shape, (list, tuple)):
+ in_shape_1, in_shape_2 = input_shape
+ if isinstance(in_shape_2, int):
+ return in_shape_2
+ else:
+ return in_shape_1[-1] + in_shape_2[-1]
+ else:
+ return input_shape[-1] # after concatenate action
+
def hidden_feature(self, x, return_hidden_idx=False):
raise NotImplementedError(
- "Only the `discriminators` family has the "
+ "Only the pidgan's Discriminators has the "
"`hidden_feature()` method implemented."
)
diff --git a/src/pidgan/players/classifiers/ResMultiClassifier.py b/src/pidgan/players/classifiers/ResMultiClassifier.py
index 0860319..f9b38c6 100644
--- a/src/pidgan/players/classifiers/ResMultiClassifier.py
+++ b/src/pidgan/players/classifiers/ResMultiClassifier.py
@@ -1,4 +1,4 @@
-from tensorflow import keras
+import keras as k
from pidgan.players.discriminators import ResDiscriminator
@@ -33,15 +33,23 @@ def __init__(
# Kernel regularizer
self._hidden_kernel_reg = mlp_hidden_kernel_regularizer
+ def _get_input_dim(self, input_shape) -> int:
+ if isinstance(input_shape, (list, tuple)):
+ in_shape_1, in_shape_2 = input_shape
+ if isinstance(in_shape_2, int):
+ return in_shape_2
+ else:
+ return in_shape_1[-1] + in_shape_2[-1]
+ else:
+ return input_shape[-1] # after concatenate action
+
def _define_arch(self) -> None:
super()._define_arch()
- self._out.append(
- keras.layers.Softmax(name="softmax_out" if self.name else None)
- )
+ self._out.append(k.layers.Softmax(name="softmax_out" if self.name else None))
def hidden_feature(self, x, return_hidden_idx=False):
raise NotImplementedError(
- "Only the `discriminators` family has the "
+ "Only the pidgan's Discriminators has the "
"`hidden_feature()` method implemented."
)
diff --git a/src/pidgan/players/discriminators/__init__.py b/src/pidgan/players/discriminators/__init__.py
index 608a344..3f80ef8 100644
--- a/src/pidgan/players/discriminators/__init__.py
+++ b/src/pidgan/players/discriminators/__init__.py
@@ -1,3 +1,12 @@
-from .AuxDiscriminator import AuxDiscriminator
-from .Discriminator import Discriminator
-from .ResDiscriminator import ResDiscriminator
+import keras as k
+
+v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+
+if v_major == 3 and v_minor >= 0:
+ from .k3.AuxDiscriminator import AuxDiscriminator
+ from .k3.Discriminator import Discriminator
+ from .k3.ResDiscriminator import ResDiscriminator
+else:
+ from .k2.AuxDiscriminator import AuxDiscriminator
+ from .k2.Discriminator import Discriminator
+ from .k2.ResDiscriminator import ResDiscriminator
diff --git a/src/pidgan/players/discriminators/AuxDiscriminator.py b/src/pidgan/players/discriminators/k2/AuxDiscriminator.py
similarity index 92%
rename from src/pidgan/players/discriminators/AuxDiscriminator.py
rename to src/pidgan/players/discriminators/k2/AuxDiscriminator.py
index 6ff7add..d12da02 100644
--- a/src/pidgan/players/discriminators/AuxDiscriminator.py
+++ b/src/pidgan/players/discriminators/k2/AuxDiscriminator.py
@@ -1,6 +1,6 @@
import tensorflow as tf
-from pidgan.players.discriminators.ResDiscriminator import ResDiscriminator
+from pidgan.players.discriminators.k2.ResDiscriminator import ResDiscriminator
class AuxDiscriminator(ResDiscriminator):
@@ -59,6 +59,11 @@ def __init__(
)
self._aux_features.append(aux_feat)
+ def _get_input_dim(self, input_shape) -> int:
+ in_dim = super()._get_input_dim(input_shape)
+ in_dim += len(self._aux_features)
+ return in_dim
+
def _prepare_input(self, x) -> tf.Tensor:
in_feats = super()._prepare_input(x)
if isinstance(x, (list, tuple)):
diff --git a/src/pidgan/players/discriminators/Discriminator.py b/src/pidgan/players/discriminators/k2/Discriminator.py
similarity index 73%
rename from src/pidgan/players/discriminators/Discriminator.py
rename to src/pidgan/players/discriminators/k2/Discriminator.py
index 2e00ab7..5b669e4 100644
--- a/src/pidgan/players/discriminators/Discriminator.py
+++ b/src/pidgan/players/discriminators/k2/Discriminator.py
@@ -1,10 +1,11 @@
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
LEAKY_ALPHA = 0.1
-class Discriminator(keras.Model):
+class Discriminator(k.Model):
def __init__(
self,
output_dim,
@@ -16,9 +17,12 @@ def __init__(
dtype=None,
) -> None:
super().__init__(name=name, dtype=dtype)
+
+ self._model = None
+ self._model_is_built = False
+
self._hidden_activation_func = None
self._hidden_kernel_reg = None
- self._model = None
# Output dimension
assert output_dim >= 1
@@ -60,13 +64,22 @@ def __init__(
# Output activation
self._output_activation = output_activation
- def _define_arch(self) -> keras.Sequential:
- model = keras.Sequential(name=f"{self.name}_seq" if self.name else None)
+ def _get_input_dim(self, input_shape) -> int:
+ if isinstance(input_shape, (list, tuple)):
+ in_shape_1, in_shape_2 = input_shape
+ return in_shape_1[-1] + in_shape_2[-1]
+ else:
+ return input_shape[-1] # after concat action
+
+ def build(self, input_shape) -> None:
+ in_dim = self._get_input_dim(input_shape)
+ seq = k.Sequential(name=f"{self.name}_seq" if self.name else None)
+ seq.add(k.layers.InputLayer(input_shape=(in_dim,)))
for i, (units, rate) in enumerate(
zip(self._mlp_hidden_units, self._mlp_dropout_rates)
):
- model.add(
- keras.layers.Dense(
+ seq.add(
+ k.layers.Dense(
units=units,
activation=self._hidden_activation_func,
kernel_initializer="glorot_uniform",
@@ -77,18 +90,16 @@ def _define_arch(self) -> keras.Sequential:
)
)
if self._hidden_activation_func is None:
- model.add(
- keras.layers.LeakyReLU(
+ seq.add(
+ k.layers.LeakyReLU(
alpha=LEAKY_ALPHA, name=f"leaky_relu_{i}" if self.name else None
)
)
- model.add(
- keras.layers.Dropout(
- rate=rate, name=f"dropout_{i}" if self.name else None
- )
+ seq.add(
+ k.layers.Dropout(rate=rate, name=f"dropout_{i}" if self.name else None)
)
- model.add(
- keras.layers.Dense(
+ seq.add(
+ k.layers.Dense(
units=self._output_dim,
activation=self._output_activation,
kernel_initializer="glorot_uniform",
@@ -97,13 +108,8 @@ def _define_arch(self) -> keras.Sequential:
dtype=self.dtype,
)
)
- return model
-
- def _build_model(self, x) -> None:
- if self._model is None:
- self._model = self._define_arch()
- else:
- pass
+ self._model = seq
+ self._model_is_built = True
def _prepare_input(self, x) -> tf.Tensor:
if isinstance(x, (list, tuple)):
@@ -112,7 +118,8 @@ def _prepare_input(self, x) -> tf.Tensor:
def call(self, x) -> tf.Tensor:
x = self._prepare_input(x)
- self._build_model(x)
+ if not self._model_is_built:
+ self.build(input_shape=x.shape)
out = self._model(x)
return out
@@ -121,6 +128,8 @@ def summary(self, **kwargs) -> None:
def hidden_feature(self, x, return_hidden_idx=False):
x = self._prepare_input(x)
+ if not self._model_is_built:
+ self.build(input_shape=x.shape)
if self._hidden_activation_func is None:
multiple = 3 # dense + leaky_relu + dropout
else:
@@ -156,5 +165,18 @@ def output_activation(self):
return self._output_activation
@property
- def export_model(self) -> keras.Sequential:
+ def plain_keras(self) -> k.Sequential:
return self._model
+
+ @property
+ def export_model(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("default")
+ warnings.warn(
+ "The `export_model` attribute is deprecated and will be removed "
+ "in a future release. Consider to replace it with the new (and "
+ "equivalent) `plain_keras` attribute.",
+ category=DeprecationWarning,
+ stacklevel=1,
+ )
+ return self.plain_keras
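Editor's note: migration sketch for the renamed accessor; new code should read `plain_keras`, while `export_model` keeps working but warns (usage illustrative, hyperparameters arbitrary).

```python
import warnings

from pidgan.players.discriminators import Discriminator

disc = Discriminator(output_dim=1, num_hidden_layers=3, mlp_hidden_units=64)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = disc.export_model        # deprecated alias, emits a DeprecationWarning
print(len(caught))               # 1

_ = disc.plain_keras             # preferred accessor (None until the model is built)
```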
diff --git a/src/pidgan/players/discriminators/ResDiscriminator.py b/src/pidgan/players/discriminators/k2/ResDiscriminator.py
similarity index 70%
rename from src/pidgan/players/discriminators/ResDiscriminator.py
rename to src/pidgan/players/discriminators/k2/ResDiscriminator.py
index c195755..6d1d82e 100644
--- a/src/pidgan/players/discriminators/ResDiscriminator.py
+++ b/src/pidgan/players/discriminators/k2/ResDiscriminator.py
@@ -1,5 +1,6 @@
-from tensorflow import keras
-from pidgan.players.discriminators.Discriminator import Discriminator
+import keras as k
+
+from pidgan.players.discriminators.k2.Discriminator import Discriminator
LEAKY_ALPHA = 0.1
@@ -16,10 +17,13 @@ def __init__(
dtype=None,
) -> None:
super(Discriminator, self).__init__(name=name, dtype=dtype)
+
+ self._model = None
+ self._model_is_built = False
+
self._hidden_activation_func = None
self._hidden_kernel_reg = None
self._enable_res_blocks = True
- self._model = None
# Output dimension
assert output_dim >= 1
@@ -46,9 +50,9 @@ def __init__(
def _define_arch(self) -> None:
self._hidden_layers = list()
for i in range(self._num_hidden_layers):
- seq = list()
- seq.append(
- keras.layers.Dense(
+ res_block = list()
+ res_block.append(
+ k.layers.Dense(
units=self._mlp_hidden_units,
activation=self._hidden_activation_func,
kernel_initializer="glorot_uniform",
@@ -59,27 +63,27 @@ def _define_arch(self) -> None:
)
)
if self._hidden_activation_func is None:
- seq.append(
- keras.layers.LeakyReLU(
+ res_block.append(
+ k.layers.LeakyReLU(
alpha=LEAKY_ALPHA, name=f"leaky_relu_{i}" if self.name else None
)
)
- seq.append(
- keras.layers.Dropout(
+ res_block.append(
+ k.layers.Dropout(
rate=self._mlp_dropout_rates,
name=f"dropout_{i}" if self.name else None,
)
)
- self._hidden_layers.append(seq)
+ self._hidden_layers.append(res_block)
self._add_layers = list()
for i in range(self._num_hidden_layers - 1):
self._add_layers.append(
- keras.layers.Add(name=f"add_{i}-{i+1}" if self.name else None)
+ k.layers.Add(name=f"add_{i}-{i+1}" if self.name else None)
)
self._out = [
- keras.layers.Dense(
+ k.layers.Dense(
units=self._output_dim,
activation=self._output_activation,
kernel_initializer="glorot_uniform",
@@ -89,34 +93,37 @@ def _define_arch(self) -> None:
)
]
- def _build_model(self, x) -> None:
- if self._model is None:
- self._define_arch()
- inputs = keras.layers.Input(shape=x.shape[1:])
- x_ = inputs
- for layer in self._hidden_layers[0]:
- x_ = layer(x_)
- for i in range(1, self._num_hidden_layers):
- h_ = x_
- for layer in self._hidden_layers[i]:
- h_ = layer(h_)
- if self._enable_res_blocks:
- x_ = self._add_layers[i - 1]([x_, h_])
- else:
- x_ = h_
- outputs = x_
- for layer in self._out:
- outputs = layer(outputs)
- self._model = keras.Model(
- inputs=inputs,
- outputs=outputs,
- name=f"{self.name}_func" if self.name else None,
- )
- else:
- pass
+ def build(self, input_shape) -> None:
+ in_dim = self._get_input_dim(input_shape)
+ inputs = k.layers.Input(shape=(in_dim,))
+
+ self._define_arch()
+ x_ = inputs
+ for layer in self._hidden_layers[0]:
+ x_ = layer(x_)
+ for i in range(1, self._num_hidden_layers):
+ h_ = x_
+ for layer in self._hidden_layers[i]:
+ h_ = layer(h_)
+ if self._enable_res_blocks:
+ x_ = self._add_layers[i - 1]([x_, h_])
+ else:
+ x_ = h_
+ outputs = x_
+ for layer in self._out:
+ outputs = layer(outputs)
+
+ self._model = k.Model(
+ inputs=inputs,
+ outputs=outputs,
+ name=f"{self.name}_func" if self.name else None,
+ )
+ self._model_is_built = True
def hidden_feature(self, x, return_hidden_idx=False):
x = self._prepare_input(x)
+ if not self._model_is_built:
+ self.build(input_shape=x.shape)
for layer in self._hidden_layers[0]:
x = layer(x)
hidden_idx = int((self._num_hidden_layers + 1) / 2.0)
@@ -143,5 +150,5 @@ def mlp_dropout_rates(self) -> float:
return self._mlp_dropout_rates
@property
- def export_model(self) -> keras.Model:
+ def plain_keras(self) -> k.Model:
return self._model
diff --git a/src/pidgan/players/discriminators/k2/__init__.py b/src/pidgan/players/discriminators/k2/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pidgan/players/discriminators/k3/AuxDiscriminator.py b/src/pidgan/players/discriminators/k3/AuxDiscriminator.py
new file mode 100644
index 0000000..109fa4b
--- /dev/null
+++ b/src/pidgan/players/discriminators/k3/AuxDiscriminator.py
@@ -0,0 +1,91 @@
+import keras as k
+from pidgan.players.discriminators.k3.ResDiscriminator import ResDiscriminator
+
+
+class AuxDiscriminator(ResDiscriminator):
+ def __init__(
+ self,
+ output_dim,
+ aux_features,
+ num_hidden_layers=5,
+ mlp_hidden_units=128,
+ mlp_dropout_rates=0,
+ enable_residual_blocks=False,
+ output_activation="sigmoid",
+ name=None,
+ dtype=None,
+ ) -> None:
+ super().__init__(
+ output_dim=output_dim,
+ num_hidden_layers=num_hidden_layers,
+ mlp_hidden_units=mlp_hidden_units,
+ mlp_dropout_rates=mlp_dropout_rates,
+ output_activation=output_activation,
+ name=name,
+ dtype=dtype,
+ )
+
+ # Residual blocks
+ assert isinstance(enable_residual_blocks, bool)
+ self._enable_res_blocks = enable_residual_blocks
+
+ # Auxiliary features
+ self._aux_features = list()
+ if isinstance(aux_features, str):
+ aux_features = [aux_features]
+
+ self._aux_indices = list()
+ self._aux_operators = list()
+ for aux_feat in aux_features:
+ assert isinstance(aux_feat, str)
+ if "+" in aux_feat:
+ self._aux_operators.append(k.ops.add)
+ self._aux_indices.append([int(i) for i in aux_feat.split("+")])
+ elif "-" in aux_feat:
+ self._aux_operators.append(k.ops.subtract)
+ self._aux_indices.append([int(i) for i in aux_feat.split("-")])
+ elif "*" in aux_feat:
+ self._aux_operators.append(k.ops.multiply)
+ self._aux_indices.append([int(i) for i in aux_feat.split("*")])
+ elif "/" in aux_feat:
+ self._aux_operators.append(k.ops.divide)
+ self._aux_indices.append([int(i) for i in aux_feat.split("/")])
+ else:
+ raise ValueError(
+ f"Operator for auxiliary features not supported. "
+ f"Operators should be selected in ['+', '-', '*', '/'], "
+ f"instead '{aux_feat}' passed."
+ )
+ self._aux_features.append(aux_feat)
+
+ def _get_input_dim(self, input_shape) -> int:
+ in_dim = super()._get_input_dim(input_shape)
+ in_dim += len(self._aux_features)
+ return in_dim
+
+ def _prepare_input(self, x):
+ in_feats = super()._prepare_input(x)
+ if isinstance(x, (list, tuple)):
+ _, y = x
+ else:
+ y = x
+ aux_feats = list()
+ for aux_idx, aux_op in zip(self._aux_indices, self._aux_operators):
+ aux_feats.append(aux_op(y[:, aux_idx[0]], y[:, aux_idx[1]])[:, None])
+ self._aux_feats = k.ops.concatenate(aux_feats, axis=-1)
+ return k.ops.concatenate([in_feats, self._aux_feats], axis=-1)
+
+ def call(self, x, return_aux_features=False):
+ out = super().call(x)
+ if return_aux_features:
+ return out, self._aux_feats
+ else:
+ return out
+
+ @property
+ def aux_features(self) -> list:
+ return self._aux_features
+
+ @property
+ def enable_residual_blocks(self) -> bool:
+ return self._enable_res_blocks
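Editor's note: each auxiliary-feature string parsed above combines two input columns with one of `+`, `-`, `*`, `/`; the NumPy sketch below shows the extra columns produced for illustrative data.

```python
import numpy as np

y = np.array([[2.0, 3.0, 4.0, 8.0],
              [1.0, 5.0, 2.0, 6.0]])

# aux_features = ["0+1", "3/2"]  ->  y[:, 0] + y[:, 1]  and  y[:, 3] / y[:, 2]
aux_cols = np.stack([y[:, 0] + y[:, 1], y[:, 3] / y[:, 2]], axis=-1)
print(aux_cols)   # [[5. 2.], [6. 3.]]
```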
diff --git a/src/pidgan/players/discriminators/k3/Discriminator.py b/src/pidgan/players/discriminators/k3/Discriminator.py
new file mode 100644
index 0000000..9da2edc
--- /dev/null
+++ b/src/pidgan/players/discriminators/k3/Discriminator.py
@@ -0,0 +1,182 @@
+import warnings
+import keras as k
+
+LEAKY_NEG_SLOPE = 0.1
+
+
+class Discriminator(k.Model):
+ def __init__(
+ self,
+ output_dim,
+ num_hidden_layers=5,
+ mlp_hidden_units=128,
+ mlp_dropout_rates=0.0,
+ output_activation="sigmoid",
+ name=None,
+ dtype=None,
+ ) -> None:
+ super().__init__(name=name, dtype=dtype)
+
+ self._model = None
+ self._model_is_built = False
+
+ self._hidden_activation_func = None
+ self._hidden_kernel_reg = None
+
+ # Output dimension
+ assert output_dim >= 1
+ self._output_dim = int(output_dim)
+
+ # Number of hidden layers
+ assert isinstance(num_hidden_layers, (int, float))
+ assert num_hidden_layers >= 1
+ self._num_hidden_layers = int(num_hidden_layers)
+
+ # Multilayer perceptron hidden units
+ if isinstance(mlp_hidden_units, (int, float)):
+ assert mlp_hidden_units >= 1
+ self._mlp_hidden_units = [int(mlp_hidden_units)] * self._num_hidden_layers
+ else:
+ mlp_hidden_units = list(mlp_hidden_units)
+ assert len(mlp_hidden_units) == self._num_hidden_layers
+ self._mlp_hidden_units = list()
+ for units in mlp_hidden_units:
+ assert isinstance(units, (int, float))
+ assert units >= 1
+ self._mlp_hidden_units.append(int(units))
+
+ # Dropout rate
+ if isinstance(mlp_dropout_rates, (int, float)):
+ assert mlp_dropout_rates >= 0.0 and mlp_dropout_rates < 1.0
+ self._mlp_dropout_rates = [
+ float(mlp_dropout_rates)
+ ] * self._num_hidden_layers
+ else:
+ mlp_dropout_rates = list(mlp_dropout_rates)
+ assert len(mlp_dropout_rates) == self._num_hidden_layers
+ self._mlp_dropout_rates = list()
+ for rate in mlp_dropout_rates:
+ assert isinstance(rate, (int, float))
+ assert rate >= 0.0 and rate < 1.0
+ self._mlp_dropout_rates.append(float(rate))
+
+ # Output activation
+ self._output_activation = output_activation
+
+ def _get_input_dim(self, input_shape) -> int:
+ if isinstance(input_shape, (list, tuple)):
+ in_shape_1, in_shape_2 = input_shape
+ return in_shape_1[-1] + in_shape_2[-1]
+ else:
+ return input_shape[-1] # after concatenate action
+
+ def build(self, input_shape) -> None:
+ in_dim = self._get_input_dim(input_shape)
+ seq = k.Sequential(name=f"{self.name}_seq" if self.name else None)
+ seq.add(k.layers.InputLayer(shape=(in_dim,)))
+ for i, (units, rate) in enumerate(
+ zip(self._mlp_hidden_units, self._mlp_dropout_rates)
+ ):
+ seq.add(
+ k.layers.Dense(
+ units=units,
+ activation=self._hidden_activation_func,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=self._hidden_kernel_reg,
+ name=f"dense_{i}" if self.name else None,
+ dtype=self.dtype,
+ )
+ )
+ if self._hidden_activation_func is None:
+ seq.add(
+ k.layers.LeakyReLU(
+ negative_slope=LEAKY_NEG_SLOPE,
+ name=f"leaky_relu_{i}" if self.name else None,
+ )
+ )
+ seq.add(
+ k.layers.Dropout(rate=rate, name=f"dropout_{i}" if self.name else None)
+ )
+ seq.add(
+ k.layers.Dense(
+ units=self._output_dim,
+ activation=self._output_activation,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ name="dense_out" if self.name else None,
+ dtype=self.dtype,
+ )
+ )
+ self._model = seq
+ self._model_is_built = True
+
+ def _prepare_input(self, x):
+ if isinstance(x, (list, tuple)):
+ x = k.ops.concatenate(x, axis=-1)
+ return x
+
+ def call(self, x):
+ x = self._prepare_input(x)
+ if not self._model_is_built:
+ self.build(input_shape=x.shape)
+ out = self._model(x)
+ return out
+
+ def summary(self, **kwargs) -> None:
+ self._model.summary(**kwargs)
+
+ def hidden_feature(self, x, return_hidden_idx=False):
+ x = self._prepare_input(x)
+ if not self._model_is_built:
+ self.build(input_shape=x.shape)
+ if self._hidden_activation_func is None:
+ multiple = 3 # dense + leaky_relu + dropout
+ else:
+ multiple = 2 # dense + dropout
+ hidden_idx = int((self._num_hidden_layers + 1) / 2.0)
+ if hidden_idx < 1:
+ hidden_idx += 1
+ for layer in self._model.layers[: multiple * hidden_idx]:
+ x = layer(x)
+ if return_hidden_idx:
+ return x, hidden_idx
+ else:
+ return x
+
+ @property
+ def output_dim(self) -> int:
+ return self._output_dim
+
+ @property
+ def num_hidden_layers(self) -> int:
+ return self._num_hidden_layers
+
+ @property
+ def mlp_hidden_units(self) -> list:
+ return self._mlp_hidden_units
+
+ @property
+ def mlp_dropout_rates(self) -> list:
+ return self._mlp_dropout_rates
+
+ @property
+ def output_activation(self):
+ return self._output_activation
+
+ @property
+ def plain_keras(self) -> k.Sequential:
+ return self._model
+
+ @property
+ def export_model(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("default")
+ warnings.warn(
+ "The `export_model` attribute is deprecated and will be removed "
+ "in a future release. Consider to replace it with the new (and "
+ "equivalent) `plain_keras` attribute.",
+ category=DeprecationWarning,
+ stacklevel=1,
+ )
+ return self.plain_keras
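
A minimal usage sketch of the deferred-build discriminator above (not part of the patch). It assumes the discriminators package re-exports the Keras 3 class the same way the generators `__init__` below does, and that the constructor keeps the defaults visible in this diff.

```python
import keras as k
from pidgan.players.discriminators import Discriminator  # assumed re-export

# conditioning features and candidate outputs, concatenated internally
x = k.random.normal(shape=(64, 4))
y = k.random.normal(shape=(64, 8))

disc = Discriminator(
    output_dim=1,
    num_hidden_layers=3,
    mlp_hidden_units=64,
    mlp_dropout_rates=0.1,
)

out = disc((x, y))                  # first call triggers build() on the concatenated width
feat = disc.hidden_feature((x, y))  # activations from the middle hidden block
print(out.shape, feat.shape)

keras_model = disc.plain_keras      # plain k.Sequential; export_model is deprecated
```
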
diff --git a/src/pidgan/players/discriminators/k3/ResDiscriminator.py b/src/pidgan/players/discriminators/k3/ResDiscriminator.py
new file mode 100644
index 0000000..2bb32a8
--- /dev/null
+++ b/src/pidgan/players/discriminators/k3/ResDiscriminator.py
@@ -0,0 +1,154 @@
+import keras as k
+
+from pidgan.players.discriminators.k3.Discriminator import Discriminator
+
+LEAKY_NEG_SLOPE = 0.1
+
+
+class ResDiscriminator(Discriminator):
+ def __init__(
+ self,
+ output_dim,
+ num_hidden_layers=5,
+ mlp_hidden_units=128,
+ mlp_dropout_rates=0.0,
+ output_activation="sigmoid",
+ name=None,
+ dtype=None,
+ ) -> None:
+ super(Discriminator, self).__init__(name=name, dtype=dtype)
+
+ self._model = None
+ self._model_is_built = False
+
+ self._hidden_activation_func = None
+ self._hidden_kernel_reg = None
+ self._enable_res_blocks = True
+
+ # Output dimension
+ assert output_dim >= 1
+ self._output_dim = int(output_dim)
+
+ # Number of hidden layers
+ assert isinstance(num_hidden_layers, (int, float))
+ assert num_hidden_layers >= 1
+ self._num_hidden_layers = int(num_hidden_layers)
+
+ # Multilayer perceptron hidden units
+ assert isinstance(mlp_hidden_units, (int, float))
+ assert mlp_hidden_units >= 1
+ self._mlp_hidden_units = int(mlp_hidden_units)
+
+ # Dropout rate
+ assert isinstance(mlp_dropout_rates, (int, float))
+ assert mlp_dropout_rates >= 0.0 and mlp_dropout_rates < 1.0
+ self._mlp_dropout_rates = float(mlp_dropout_rates)
+
+ # Output activation
+ self._output_activation = output_activation
+
+ def _define_arch(self) -> None:
+ self._hidden_layers = list()
+ for i in range(self._num_hidden_layers):
+ res_block = list()
+ res_block.append(
+ k.layers.Dense(
+ units=self._mlp_hidden_units,
+ activation=self._hidden_activation_func,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=self._hidden_kernel_reg,
+ name=f"dense_{i}" if self.name else None,
+ dtype=self.dtype,
+ )
+ )
+ if self._hidden_activation_func is None:
+ res_block.append(
+ k.layers.LeakyReLU(
+ negative_slope=LEAKY_NEG_SLOPE,
+ name=f"leaky_relu_{i}" if self.name else None,
+ )
+ )
+ res_block.append(
+ k.layers.Dropout(
+ rate=self._mlp_dropout_rates,
+ name=f"dropout_{i}" if self.name else None,
+ )
+ )
+ self._hidden_layers.append(res_block)
+
+ self._add_layers = list()
+ for i in range(self._num_hidden_layers - 1):
+ self._add_layers.append(
+ k.layers.Add(name=f"add_{i}-{i+1}" if self.name else None)
+ )
+
+ self._out = [
+ k.layers.Dense(
+ units=self._output_dim,
+ activation=self._output_activation,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ name="dense_out" if self.name else None,
+ dtype=self.dtype,
+ )
+ ]
+
+ def build(self, input_shape) -> None:
+ in_dim = self._get_input_dim(input_shape)
+ inputs = k.layers.Input(shape=(in_dim,))
+
+ self._define_arch()
+ x_ = inputs
+ for layer in self._hidden_layers[0]:
+ x_ = layer(x_)
+ for i in range(1, self._num_hidden_layers):
+ h_ = x_
+ for layer in self._hidden_layers[i]:
+ h_ = layer(h_)
+ if self._enable_res_blocks:
+ x_ = self._add_layers[i - 1]([x_, h_])
+ else:
+ x_ = h_
+ outputs = x_
+ for layer in self._out:
+ outputs = layer(outputs)
+ self._model = k.Model(
+ inputs=inputs,
+ outputs=outputs,
+ name=f"{self.name}_func" if self.name else None,
+ )
+ self._model_is_built = True
+
+ def hidden_feature(self, x, return_hidden_idx=False):
+ x = self._prepare_input(x)
+ if not self._model_is_built:
+ self.build(input_shape=x.shape)
+ for layer in self._hidden_layers[0]:
+ x = layer(x)
+ hidden_idx = int((self._num_hidden_layers + 1) / 2.0)
+ if hidden_idx > 1:
+ for i in range(1, hidden_idx):
+ h = x
+ for layer in self._hidden_layers[i]:
+ h = layer(h)
+ if self._enable_res_blocks:
+ x = self._add_layers[i - 1]([x, h])
+ else:
+ x = h
+ if return_hidden_idx:
+ return x, hidden_idx
+ else:
+ return x
+
+ @property
+ def mlp_hidden_units(self) -> int:
+ return self._mlp_hidden_units
+
+ @property
+ def mlp_dropout_rates(self) -> float:
+ return self._mlp_dropout_rates
+
+ @property
+ def plain_keras(self) -> k.Model:
+ return self._model
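
The residual variant above accepts only a scalar `mlp_hidden_units`, since every skip-connection `Add` layer needs both branches to share the same width. A short, hedged usage sketch (import path assumed, as before):

```python
import keras as k
from pidgan.players.discriminators import ResDiscriminator  # assumed re-export

x = k.random.normal(shape=(64, 4))
y = k.random.normal(shape=(64, 8))

res_disc = ResDiscriminator(output_dim=1, num_hidden_layers=4, mlp_hidden_units=128)

out = res_disc((x, y))   # build() wires dense -> leaky_relu -> dropout residual blocks
feat, idx = res_disc.hidden_feature((x, y), return_hidden_idx=True)
print(out.shape, feat.shape, idx)   # feat is the output of the idx-th residual block
```
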
diff --git a/src/pidgan/players/discriminators/k3/__init__.py b/src/pidgan/players/discriminators/k3/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pidgan/players/generators/__init__.py b/src/pidgan/players/generators/__init__.py
index 4f5837a..08210a3 100644
--- a/src/pidgan/players/generators/__init__.py
+++ b/src/pidgan/players/generators/__init__.py
@@ -1,2 +1,10 @@
-from .Generator import Generator
-from .ResGenerator import ResGenerator
+import keras as k
+
+v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+
+if v_major == 3 and v_minor >= 0:
+ from .k3.Generator import Generator
+ from .k3.ResGenerator import ResGenerator
+else:
+ from .k2.Generator import Generator
+ from .k2.ResGenerator import ResGenerator
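
A standalone sketch of the version gate above, handy for checking which implementation a given environment will pick up (only the first two components of `keras.__version__` matter here):

```python
import keras as k

major, minor = (int(v) for v in k.__version__.split(".")[:2])
impl = "k3" if (major, minor) >= (3, 0) else "k2"
print(f"Keras {k.__version__} -> pidgan loads its {impl} players")
```
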
diff --git a/src/pidgan/players/generators/Generator.py b/src/pidgan/players/generators/k2/Generator.py
similarity index 72%
rename from src/pidgan/players/generators/Generator.py
rename to src/pidgan/players/generators/k2/Generator.py
index d100bf8..648451e 100644
--- a/src/pidgan/players/generators/Generator.py
+++ b/src/pidgan/players/generators/k2/Generator.py
@@ -1,10 +1,11 @@
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
LEAKY_ALPHA = 0.1
-class Generator(keras.Model):
+class Generator(k.Model):
def __init__(
self,
output_dim,
@@ -17,8 +18,10 @@ def __init__(
dtype=None,
) -> None:
super().__init__(name=name, dtype=dtype)
- self._hidden_activation_func = None
+
self._model = None
+ self._model_is_built = False
+ self._hidden_activation_func = None
# Output dimension
assert output_dim >= 1
@@ -64,13 +67,18 @@ def __init__(
# Output activation
self._output_activation = output_activation
- def _define_arch(self) -> keras.Sequential:
- model = keras.Sequential(name=f"{self.name}_seq" if self.name else None)
+ def _get_input_dim(self, input_shape) -> int:
+ return input_shape[-1] + self._latent_dim
+
+ def build(self, input_shape) -> None:
+ in_dim = self._get_input_dim(input_shape)
+ seq = k.Sequential(name=f"{self.name}_seq" if self.name else None)
+ seq.add(k.layers.InputLayer(input_shape=(in_dim,)))
for i, (units, rate) in enumerate(
zip(self._mlp_hidden_units, self._mlp_dropout_rates)
):
- model.add(
- keras.layers.Dense(
+ seq.add(
+ k.layers.Dense(
units=units,
activation=self._hidden_activation_func,
kernel_initializer="glorot_uniform",
@@ -80,18 +88,16 @@ def _define_arch(self) -> keras.Sequential:
)
)
if self._hidden_activation_func is None:
- model.add(
- keras.layers.LeakyReLU(
+ seq.add(
+ k.layers.LeakyReLU(
alpha=LEAKY_ALPHA, name=f"leaky_relu_{i}" if self.name else None
)
)
- model.add(
- keras.layers.Dropout(
- rate=rate, name=f"dropout_{i}" if self.name else None
- )
+ seq.add(
+ k.layers.Dropout(rate=rate, name=f"dropout_{i}" if self.name else None)
)
- model.add(
- keras.layers.Dense(
+ seq.add(
+ k.layers.Dense(
units=self._output_dim,
activation=self._output_activation,
kernel_initializer="glorot_uniform",
@@ -100,13 +106,8 @@ def _define_arch(self) -> keras.Sequential:
dtype=self.dtype,
)
)
- return model
-
- def _build_model(self, x) -> None:
- if self._model is None:
- self._model = self._define_arch()
- else:
- pass
+ self._model = seq
+ self._model_is_built = True
def _prepare_input(self, x, seed=None) -> tuple:
latent_sample = tf.random.normal(
@@ -120,19 +121,21 @@ def _prepare_input(self, x, seed=None) -> tuple:
return x, latent_sample
def call(self, x) -> tf.Tensor:
- x, _ = self._prepare_input(x, seed=None)
- self._build_model(x)
- out = self._model(x)
+ if not self._model_is_built:
+ self.build(input_shape=x.shape)
+ in_, _ = self._prepare_input(x, seed=None)
+ out = self._model(in_)
return out
def summary(self, **kwargs) -> None:
self._model.summary(**kwargs)
def generate(self, x, seed=None, return_latent_sample=False) -> tf.Tensor:
+ if not self._model_is_built:
+ self.build(input_shape=x.shape)
tf.random.set_seed(seed=seed)
- x, latent_sample = self._prepare_input(x, seed=seed)
- self._build_model(x)
- out = self._model(x)
+ in_, latent_sample = self._prepare_input(x, seed=seed)
+ out = self._model(in_)
if return_latent_sample:
return out, latent_sample
else:
@@ -163,5 +166,18 @@ def output_activation(self):
return self._output_activation
@property
- def export_model(self) -> keras.Sequential:
+ def plain_keras(self) -> k.Sequential:
return self._model
+
+ @property
+ def export_model(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("default")
+ warnings.warn(
+ "The `export_model` attribute is deprecated and will be removed "
+ "in a future release. Consider to replace it with the new (and "
+ "equivalent) `plain_keras` attribute.",
+ category=DeprecationWarning,
+ stacklevel=1,
+ )
+ return self.plain_keras
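
Illustrative only: driving the Keras 2 generator above with the TensorFlow backend. The import path assumes the version dispatch in `generators/__init__.py` resolves to the k2 implementation; constructor defaults follow this diff.

```python
import tensorflow as tf
from pidgan.players.generators import Generator

gen = Generator(
    output_dim=8,
    latent_dim=16,
    num_hidden_layers=3,
    mlp_hidden_units=64,
    mlp_dropout_rates=0.0,
)

x = tf.random.normal(shape=(32, 4))   # conditioning features
out = gen(x)                          # first call builds the underlying Sequential
sample, z = gen.generate(x, seed=42, return_latent_sample=True)
print(out.shape, sample.shape, z.shape)   # (32, 8) (32, 8) (32, 16)
```
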
diff --git a/src/pidgan/players/generators/ResGenerator.py b/src/pidgan/players/generators/k2/ResGenerator.py
similarity index 67%
rename from src/pidgan/players/generators/ResGenerator.py
rename to src/pidgan/players/generators/k2/ResGenerator.py
index eb30faf..c6f417f 100644
--- a/src/pidgan/players/generators/ResGenerator.py
+++ b/src/pidgan/players/generators/k2/ResGenerator.py
@@ -1,5 +1,6 @@
-from tensorflow import keras
-from pidgan.players.generators.Generator import Generator
+import keras as k
+
+from pidgan.players.generators.k2.Generator import Generator
LEAKY_ALPHA = 0.1
@@ -17,9 +18,12 @@ def __init__(
dtype=None,
) -> None:
super(Generator, self).__init__(name=name, dtype=dtype)
+
+ self._model = None
+ self._model_is_built = False
+
self._hidden_activation_func = None
self._enable_res_blocks = True
- self._model = None
# Output dimension
assert output_dim >= 1
@@ -50,9 +54,9 @@ def __init__(
def _define_arch(self) -> None:
self._hidden_layers = list()
for i in range(self._num_hidden_layers):
- seq = list()
- seq.append(
- keras.layers.Dense(
+ res_block = list()
+ res_block.append(
+ k.layers.Dense(
units=self._mlp_hidden_units,
activation=self._hidden_activation_func,
kernel_initializer="glorot_uniform",
@@ -62,26 +66,26 @@ def _define_arch(self) -> None:
)
)
if self._hidden_activation_func is None:
- seq.append(
- keras.layers.LeakyReLU(
+ res_block.append(
+ k.layers.LeakyReLU(
alpha=LEAKY_ALPHA, name=f"leaky_relu_{i}" if self.name else None
)
)
- seq.append(
- keras.layers.Dropout(
+ res_block.append(
+ k.layers.Dropout(
rate=self._mlp_dropout_rates,
name=f"dropout_{i}" if self.name else None,
)
)
- self._hidden_layers.append(seq)
+ self._hidden_layers.append(res_block)
self._add_layers = list()
for i in range(self._num_hidden_layers - 1):
self._add_layers.append(
- keras.layers.Add(name=f"add_{i}-{i+1}" if self.name else None)
+ k.layers.Add(name=f"add_{i}-{i+1}" if self.name else None)
)
- self._out = keras.layers.Dense(
+ self._out = k.layers.Dense(
units=self._output_dim,
activation=self._output_activation,
kernel_initializer="glorot_uniform",
@@ -90,29 +94,28 @@ def _define_arch(self) -> None:
dtype=self.dtype,
)
- def _build_model(self, x) -> None:
- if self._model is None:
- self._define_arch()
- inputs = keras.layers.Input(shape=x.shape[1:])
- x_ = inputs
- for layer in self._hidden_layers[0]:
- x_ = layer(x_)
- for i in range(1, self._num_hidden_layers):
- h_ = x_
- for layer in self._hidden_layers[i]:
- h_ = layer(h_)
- if self._enable_res_blocks:
- x_ = self._add_layers[i - 1]([x_, h_])
- else:
- x_ = h_
- outputs = self._out(x_)
- self._model = keras.Model(
- inputs=inputs,
- outputs=outputs,
- name=f"{self.name}_func" if self.name else None,
- )
- else:
- pass
+ def build(self, input_shape) -> None:
+ in_dim = self._get_input_dim(input_shape)
+ self._define_arch()
+ inputs = k.layers.Input(shape=(in_dim,))
+ x_ = inputs
+ for layer in self._hidden_layers[0]:
+ x_ = layer(x_)
+ for i in range(1, self._num_hidden_layers):
+ h_ = x_
+ for layer in self._hidden_layers[i]:
+ h_ = layer(h_)
+ if self._enable_res_blocks:
+ x_ = self._add_layers[i - 1]([x_, h_])
+ else:
+ x_ = h_
+ outputs = self._out(x_)
+ self._model = k.Model(
+ inputs=inputs,
+ outputs=outputs,
+ name=f"{self.name}_func" if self.name else None,
+ )
+ self._model_is_built = True
@property
def mlp_hidden_units(self) -> int:
@@ -123,5 +126,5 @@ def mlp_dropout_rates(self) -> float:
return self._mlp_dropout_rates
@property
- def export_model(self) -> keras.Model:
+ def plain_keras(self) -> k.Model:
return self._model
diff --git a/src/pidgan/players/generators/k2/__init__.py b/src/pidgan/players/generators/k2/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pidgan/players/generators/k3/Generator.py b/src/pidgan/players/generators/k3/Generator.py
new file mode 100644
index 0000000..036fa7f
--- /dev/null
+++ b/src/pidgan/players/generators/k3/Generator.py
@@ -0,0 +1,183 @@
+import warnings
+import keras as k
+
+LEAKY_NEG_SLOPE = 0.1
+
+
+class Generator(k.Model):
+ def __init__(
+ self,
+ output_dim,
+ latent_dim,
+ num_hidden_layers=5,
+ mlp_hidden_units=128,
+ mlp_dropout_rates=0.0,
+ output_activation=None,
+ name=None,
+ dtype=None,
+ ) -> None:
+ super().__init__(name=name, dtype=dtype)
+
+ self._model = None
+ self._model_is_built = False
+ self._hidden_activation_func = None
+
+ # Output dimension
+ assert output_dim >= 1
+ self._output_dim = int(output_dim)
+
+ # Latent space dimension
+ assert latent_dim >= 1
+ self._latent_dim = int(latent_dim)
+
+ # Number of hidden layers
+ assert isinstance(num_hidden_layers, (int, float))
+ assert num_hidden_layers >= 1
+ self._num_hidden_layers = int(num_hidden_layers)
+
+ # Multilayer perceptron hidden units
+ if isinstance(mlp_hidden_units, (int, float)):
+ assert mlp_hidden_units >= 1
+ self._mlp_hidden_units = [int(mlp_hidden_units)] * self._num_hidden_layers
+ else:
+ mlp_hidden_units = list(mlp_hidden_units)
+ assert len(mlp_hidden_units) == self._num_hidden_layers
+ self._mlp_hidden_units = list()
+ for units in mlp_hidden_units:
+ assert isinstance(units, (int, float))
+ assert units >= 1
+ self._mlp_hidden_units.append(int(units))
+
+ # Dropout rate
+ if isinstance(mlp_dropout_rates, (int, float)):
+ assert mlp_dropout_rates >= 0.0 and mlp_dropout_rates < 1.0
+ self._mlp_dropout_rates = [
+ float(mlp_dropout_rates)
+ ] * self._num_hidden_layers
+ else:
+ mlp_dropout_rates = list(mlp_dropout_rates)
+ assert len(mlp_dropout_rates) == self._num_hidden_layers
+ self._mlp_dropout_rates = list()
+ for rate in mlp_dropout_rates:
+ assert isinstance(rate, (int, float))
+ assert rate >= 0.0 and rate < 1.0
+ self._mlp_dropout_rates.append(float(rate))
+
+ # Output activation
+ self._output_activation = output_activation
+
+ def _get_input_dim(self, input_shape) -> int:
+ return input_shape[-1] + self._latent_dim
+
+ def build(self, input_shape) -> None:
+ in_dim = self._get_input_dim(input_shape)
+ seq = k.Sequential(name=f"{self.name}_seq" if self.name else None)
+ seq.add(k.layers.InputLayer(shape=(in_dim,)))
+ for i, (units, rate) in enumerate(
+ zip(self._mlp_hidden_units, self._mlp_dropout_rates)
+ ):
+ seq.add(
+ k.layers.Dense(
+ units=units,
+ activation=self._hidden_activation_func,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ name=f"dense_{i}" if self.name else None,
+ dtype=self.dtype,
+ )
+ )
+ if self._hidden_activation_func is None:
+ seq.add(
+ k.layers.LeakyReLU(
+ negative_slope=LEAKY_NEG_SLOPE,
+ name=f"leaky_relu_{i}" if self.name else None,
+ )
+ )
+ seq.add(
+ k.layers.Dropout(rate=rate, name=f"dropout_{i}" if self.name else None)
+ )
+ seq.add(
+ k.layers.Dense(
+ units=self._output_dim,
+ activation=self._output_activation,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ name="dense_out" if self.name else None,
+ dtype=self.dtype,
+ )
+ )
+ self._model = seq
+ self._model_is_built = True
+
+ def _prepare_input(self, x, seed=None) -> tuple:
+ latent_sample = k.random.normal(
+ shape=(k.ops.shape(x)[0], self._latent_dim),
+ mean=0.0,
+ stddev=1.0,
+ dtype=self.dtype,
+ seed=seed,
+ )
+ x = k.ops.concatenate([x, latent_sample], axis=-1)
+ return x, latent_sample
+
+ def call(self, x):
+ if not self._model_is_built:
+ self.build(input_shape=x.shape)
+ in_, _ = self._prepare_input(x, seed=None)
+ out = self._model(in_)
+ return out
+
+ def summary(self, **kwargs) -> None:
+ self._model.summary(**kwargs)
+
+ def generate(self, x, seed=None, return_latent_sample=False):
+ if not self._model_is_built:
+ self.build(input_shape=x.shape)
+ seed_gen = k.random.SeedGenerator(seed=seed)
+ in_, latent_sample = self._prepare_input(x, seed=seed_gen)
+ out = self._model(in_)
+ if return_latent_sample:
+ return out, latent_sample
+ else:
+ return out
+
+ @property
+ def output_dim(self) -> int:
+ return self._output_dim
+
+ @property
+ def latent_dim(self) -> int:
+ return self._latent_dim
+
+ @property
+ def num_hidden_layers(self) -> int:
+ return self._num_hidden_layers
+
+ @property
+ def mlp_hidden_units(self) -> list:
+ return self._mlp_hidden_units
+
+ @property
+ def mlp_dropout_rates(self) -> list:
+ return self._mlp_dropout_rates
+
+ @property
+ def output_activation(self):
+ return self._output_activation
+
+ @property
+ def plain_keras(self) -> k.Sequential:
+ return self._model
+
+ @property
+ def export_model(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("default")
+ warnings.warn(
+ "The `export_model` attribute is deprecated and will be removed "
+ "in a future release. Consider to replace it with the new (and "
+ "equivalent) `plain_keras` attribute.",
+ category=DeprecationWarning,
+ stacklevel=1,
+ )
+ return self.plain_keras
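
The `export_model` property above now only forwards to `plain_keras` while emitting a DeprecationWarning. A hedged sketch of how downstream code can surface the warning and migrate (the import path assumes Keras >= 3, so the dispatch picks this class):

```python
import warnings
import keras as k
from pidgan.players.generators import Generator

gen = Generator(output_dim=8, latent_dim=16)
_ = gen(k.random.normal(shape=(4, 3)))   # build the underlying Sequential

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = gen.export_model            # deprecated alias, still returns the model
print([w.category.__name__ for w in caught])   # expected: ['DeprecationWarning']

model = gen.plain_keras                  # preferred accessor going forward
model.summary()
```
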
diff --git a/src/pidgan/players/generators/k3/ResGenerator.py b/src/pidgan/players/generators/k3/ResGenerator.py
new file mode 100644
index 0000000..b772fd4
--- /dev/null
+++ b/src/pidgan/players/generators/k3/ResGenerator.py
@@ -0,0 +1,131 @@
+import keras as k
+
+from pidgan.players.generators.k3.Generator import Generator
+
+LEAKY_NEG_SLOPE = 0.1
+
+
+class ResGenerator(Generator):
+ def __init__(
+ self,
+ output_dim,
+ latent_dim,
+ num_hidden_layers=5,
+ mlp_hidden_units=128,
+ mlp_dropout_rates=0.0,
+ output_activation=None,
+ name=None,
+ dtype=None,
+ ) -> None:
+ super(Generator, self).__init__(name=name, dtype=dtype)
+
+ self._model = None
+ self._model_is_built = False
+
+ self._hidden_activation_func = None
+ self._enable_res_blocks = True
+
+ # Output dimension
+ assert output_dim >= 1
+ self._output_dim = int(output_dim)
+
+ # Latent space dimension
+ assert latent_dim >= 1
+ self._latent_dim = int(latent_dim)
+
+ # Number of hidden layers
+ assert isinstance(num_hidden_layers, (int, float))
+ assert num_hidden_layers >= 1
+ self._num_hidden_layers = int(num_hidden_layers)
+
+ # Multilayer perceptron hidden units
+ assert isinstance(mlp_hidden_units, (int, float))
+ assert mlp_hidden_units >= 1
+ self._mlp_hidden_units = int(mlp_hidden_units)
+
+ # Dropout rate
+ assert isinstance(mlp_dropout_rates, (int, float))
+ assert mlp_dropout_rates >= 0.0 and mlp_dropout_rates < 1.0
+ self._mlp_dropout_rates = float(mlp_dropout_rates)
+
+ # Output activation
+ self._output_activation = output_activation
+
+ def _define_arch(self) -> None:
+ self._hidden_layers = list()
+ for i in range(self._num_hidden_layers):
+ res_block = list()
+ res_block.append(
+ k.layers.Dense(
+ units=self._mlp_hidden_units,
+ activation=self._hidden_activation_func,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ name=f"dense_{i}" if self.name else None,
+ dtype=self.dtype,
+ )
+ )
+ if self._hidden_activation_func is None:
+ res_block.append(
+ k.layers.LeakyReLU(
+ negative_slope=LEAKY_NEG_SLOPE,
+ name=f"leaky_relu_{i}" if self.name else None,
+ )
+ )
+ res_block.append(
+ k.layers.Dropout(
+ rate=self._mlp_dropout_rates,
+ name=f"dropout_{i}" if self.name else None,
+ )
+ )
+ self._hidden_layers.append(res_block)
+
+ self._add_layers = list()
+ for i in range(self._num_hidden_layers - 1):
+ self._add_layers.append(
+ k.layers.Add(name=f"add_{i}-{i+1}" if self.name else None)
+ )
+
+ self._out = k.layers.Dense(
+ units=self._output_dim,
+ activation=self._output_activation,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ name="dense_out" if self.name else None,
+ dtype=self.dtype,
+ )
+
+ def build(self, input_shape) -> None:
+ in_dim = self._get_input_dim(input_shape)
+ self._define_arch()
+ inputs = k.layers.Input(shape=(in_dim,))
+ x_ = inputs
+ for layer in self._hidden_layers[0]:
+ x_ = layer(x_)
+ for i in range(1, self._num_hidden_layers):
+ h_ = x_
+ for layer in self._hidden_layers[i]:
+ h_ = layer(h_)
+ if self._enable_res_blocks:
+ x_ = self._add_layers[i - 1]([x_, h_])
+ else:
+ x_ = h_
+ outputs = self._out(x_)
+ self._model = k.Model(
+ inputs=inputs,
+ outputs=outputs,
+ name=f"{self.name}_func" if self.name else None,
+ )
+ self._model_is_built = True
+
+ @property
+ def mlp_hidden_units(self) -> int:
+ return self._mlp_hidden_units
+
+ @property
+ def mlp_dropout_rates(self) -> float:
+ return self._mlp_dropout_rates
+
+ @property
+ def plain_keras(self) -> k.Model:
+ return self._model
diff --git a/src/pidgan/players/generators/k3/__init__.py b/src/pidgan/players/generators/k3/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pidgan/utils/checks/checkMetrics.py b/src/pidgan/utils/checks/checkMetrics.py
index f780a9d..584b50d 100644
--- a/src/pidgan/utils/checks/checkMetrics.py
+++ b/src/pidgan/utils/checks/checkMetrics.py
@@ -6,7 +6,7 @@
from pidgan.metrics import MeanSquaredError as MSE
from pidgan.metrics import RootMeanSquaredError as RMSE
from pidgan.metrics import WassersteinDistance as Wass_dist
-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics import BaseMetric
METRIC_SHORTCUTS = [
"accuracy",
@@ -39,11 +39,11 @@ def checkMetrics(metrics): # TODO: add Union[list, None]
for metric in metrics:
if isinstance(metric, str):
if metric in METRIC_SHORTCUTS:
- for str_metric, calo_metric in zip(
+ for str_metric, pid_metric in zip(
METRIC_SHORTCUTS, PIDGAN_METRICS
):
if metric == str_metric:
- checked_metrics.append(calo_metric)
+ checked_metrics.append(pid_metric)
else:
raise ValueError(
f"`metrics` elements should be selected in "
@@ -53,12 +53,12 @@ def checkMetrics(metrics): # TODO: add Union[list, None]
checked_metrics.append(metric)
else:
raise TypeError(
- f"`metrics` elements should be a pidgan's "
- f"`BaseMetric`, instead {type(metric)} passed"
+ f"`metrics` elements should inherit from pidgan's "
+ f"BaseMetric, instead {type(metric)} passed"
)
return checked_metrics
else:
raise TypeError(
f"`metrics` should be a list of strings or pidgan's "
- f"`BaseMetric`s, instead {type(metrics)} passed"
+ f"metrics, instead {type(metrics)} passed"
)
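
A quick usage sketch for the helper patched above. Shortcut strings and pidgan metric instances can be mixed in the same list; unknown shortcuts raise a ValueError (module path taken directly from this diff):

```python
from pidgan.metrics import BinaryCrossentropy as BCE
from pidgan.utils.checks.checkMetrics import checkMetrics

checked = checkMetrics(["bce", BCE()])
print([m.__class__.__name__ for m in checked])

try:
    checkMetrics(["not-a-metric"])   # not in METRIC_SHORTCUTS
except ValueError as err:
    print(err)
```
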
diff --git a/src/pidgan/utils/checks/checkOptimizer.py b/src/pidgan/utils/checks/checkOptimizer.py
index f719cf2..e5a39c7 100644
--- a/src/pidgan/utils/checks/checkOptimizer.py
+++ b/src/pidgan/utils/checks/checkOptimizer.py
@@ -1,14 +1,14 @@
-from tensorflow import keras
+import keras as k
OPT_SHORTCUTS = ["sgd", "rmsprop", "adam"]
TF_OPTIMIZERS = [
- keras.optimizers.SGD(),
- keras.optimizers.RMSprop(),
- keras.optimizers.Adam(),
+ k.optimizers.SGD(),
+ k.optimizers.RMSprop(),
+ k.optimizers.Adam(),
]
-def checkOptimizer(optimizer) -> keras.optimizers.Optimizer:
+def checkOptimizer(optimizer) -> k.optimizers.Optimizer:
if isinstance(optimizer, str):
if optimizer in OPT_SHORTCUTS:
for opt, tf_opt in zip(OPT_SHORTCUTS, TF_OPTIMIZERS):
@@ -19,10 +19,10 @@ def checkOptimizer(optimizer) -> keras.optimizers.Optimizer:
f"`optimizer` should be selected in {OPT_SHORTCUTS}, "
f"instead '{optimizer}' passed"
)
- elif isinstance(optimizer, keras.optimizers.Optimizer):
+ elif isinstance(optimizer, k.optimizers.Optimizer):
return optimizer
else:
raise TypeError(
- f"`optimizer` should be a TensorFlow `Optimizer`, "
+ f"`optimizer` should be a Keras' Optimizer, "
f"instead {type(optimizer)} passed"
)
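
Usage sketch for `checkOptimizer` after the Keras migration. Note that the shortcut path returns the module-level optimizer instances defined in `TF_OPTIMIZERS`, while optimizer objects are passed through unchanged:

```python
import keras as k
from pidgan.utils.checks.checkOptimizer import checkOptimizer

opt = checkOptimizer("rmsprop")   # shortcut -> the shared k.optimizers.RMSprop() instance
assert isinstance(opt, k.optimizers.Optimizer)

custom = checkOptimizer(k.optimizers.Adam(learning_rate=1e-4))   # returned as-is
print(type(opt).__name__, type(custom).__name__)
```
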
diff --git a/src/pidgan/utils/reports/HPSingleton.py b/src/pidgan/utils/reports/HPSingleton.py
index 959f468..568590a 100644
--- a/src/pidgan/utils/reports/HPSingleton.py
+++ b/src/pidgan/utils/reports/HPSingleton.py
@@ -12,7 +12,7 @@ def update(self, **kwargs) -> None:
for key in kwargs.keys():
if key in self._used_keys:
raise KeyError(
- f"The hyperparameter {key} was already used and is now read-only"
+ f"The hyperparameter '{key}' was already used and is now read-only"
)
self._hparams.update(kwargs)
@@ -29,7 +29,9 @@ def clean(self) -> None:
def __del__(self) -> None:
for key in self._hparams.keys():
if key not in self._used_keys:
- print(f"[WARNING] The hyperparameter {key} was defined but never used")
+ print(
+ f"[WARNING] The hyperparameter '{key}' was defined but never used"
+ )
print(self._used_keys)
def __str__(self) -> str:
diff --git a/src/pidgan/utils/reports/getSummaryHTML.py b/src/pidgan/utils/reports/getSummaryHTML.py
index 94d96a7..88be5dd 100644
--- a/src/pidgan/utils/reports/getSummaryHTML.py
+++ b/src/pidgan/utils/reports/getSummaryHTML.py
@@ -11,16 +11,14 @@ def getSummaryHTML(model) -> tuple:
for layer in model.layers:
layer_type = f"{layer.name} ({layer.__class__.__name__}) | \n"
try:
- output_shape = f"{layer.get_output_at(0).get_shape()} | \n"
- except RuntimeError:
+ output_shape = f"{layer.output.shape} | \n"
+ except AttributeError:
output_shape = "None | \n" # print "None" in case of errors
num_params = f"{layer.count_params()} | \n"
rows.append("\n" + layer_type + output_shape + num_params + "
\n")
- train_params += int(
- np.sum([np.prod(v.get_shape()) for v in layer.trainable_weights])
- )
+ train_params += int(np.sum([np.prod(v.shape) for v in layer.trainable_weights]))
nontrain_params += int(
- np.sum([np.prod(v.get_shape()) for v in layer.non_trainable_weights])
+ np.sum([np.prod(v.shape) for v in layer.non_trainable_weights])
)
rows_html = "".join([f"{r}" for r in rows])
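
The change above swaps the TensorFlow-specific `get_shape()` calls for the plain `.shape` attribute that Keras 3 variables expose. A self-contained sketch of the same parameter-counting idiom:

```python
import numpy as np
import keras as k

layer = k.layers.Dense(4)
layer.build(input_shape=(None, 3))   # kernel (3, 4) + bias (4,)

n_train = int(np.sum([np.prod(v.shape) for v in layer.trainable_weights]))
n_fixed = int(np.sum([np.prod(v.shape) for v in layer.non_trainable_weights]))
print(n_train, n_fixed)   # 16 0
```
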
diff --git a/src/pidgan/version.py b/src/pidgan/version.py
index ae73625..d3ec452 100644
--- a/src/pidgan/version.py
+++ b/src/pidgan/version.py
@@ -1 +1 @@
-__version__ = "0.1.3"
+__version__ = "0.2.0"
diff --git a/tests/algorithms/test_BceGAN.py b/tests/algorithms/test_BceGAN.py
index 35d42ac..9001fc7 100644
--- a/tests/algorithms/test_BceGAN.py
+++ b/tests/algorithms/test_BceGAN.py
@@ -1,12 +1,18 @@
+import os
import pytest
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
from pidgan.players.classifiers import AuxClassifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
+from pidgan.metrics import BinaryCrossentropy as BCE
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
x = tf.random.normal(shape=(CHUNK_SIZE, 4))
y = tf.random.normal(shape=(CHUNK_SIZE, 8))
@@ -86,68 +92,98 @@ def test_model_use(referee):
feature_matching_penalty=0.0,
referee=referee,
)
- outputs = model(x, y)
+ out = model(x, y)
if referee is not None:
- g_output, d_outputs, r_outputs = outputs
+ g_out, d_out, r_out = out
else:
- g_output, d_outputs = outputs
+ g_out, d_out = out
model.summary()
test_g_shape = [y.shape[0]]
test_g_shape.append(model.generator.output_dim)
- assert g_output.shape == tuple(test_g_shape)
+ assert g_out.shape == tuple(test_g_shape)
test_d_shape = [y.shape[0]]
test_d_shape.append(model.discriminator.output_dim)
- d_output_gen, d_output_ref = d_outputs
- assert d_output_gen.shape == tuple(test_d_shape)
- assert d_output_ref.shape == tuple(test_d_shape)
+ d_out_gen, d_out_ref = d_out
+ assert d_out_gen.shape == tuple(test_d_shape)
+ assert d_out_ref.shape == tuple(test_d_shape)
if referee is not None:
test_r_shape = [y.shape[0]]
test_r_shape.append(model.referee.output_dim)
- r_output_gen, r_output_ref = r_outputs
- assert r_output_gen.shape == tuple(test_r_shape)
- assert r_output_ref.shape == tuple(test_r_shape)
+ r_out_gen, r_out_ref = r_out
+ assert r_out_gen.shape == tuple(test_r_shape)
+ assert r_out_ref.shape == tuple(test_r_shape)
-@pytest.mark.parametrize("metrics", [["bce"], None])
-def test_model_compilation(model, metrics):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=metrics,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
- )
+@pytest.mark.parametrize("build_first", [True, False])
+@pytest.mark.parametrize("metrics", [["bce"], [BCE()], None])
+def test_model_compilation(model, build_first, metrics):
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings(record=True) as w:
+ if build_first:
+ warnings.simplefilter("always")
+ else:
+ warnings.simplefilter("ignore")
+
+ model.compile(
+ metrics=metrics,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if build_first and metrics is not None:
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "`compile()`" in str(w[-1].message)
+
assert isinstance(model.metrics, list)
- assert isinstance(model.generator_optimizer, keras.optimizers.Optimizer)
- assert isinstance(model.discriminator_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.generator_optimizer, k.optimizers.Optimizer)
+ assert isinstance(model.discriminator_optimizer, k.optimizers.Optimizer)
assert isinstance(model.generator_upds_per_batch, int)
assert isinstance(model.discriminator_upds_per_batch, int)
- assert isinstance(model.referee_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.referee_optimizer, k.optimizers.Optimizer)
assert isinstance(model.referee_upds_per_batch, int)
+ if not build_first:
+ model(x, y) # to build the model
+ if metrics is None:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+ else:
+ assert len(model.metrics) == 4 # losses + bce
+ else:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+
@pytest.mark.parametrize("referee", [ref, None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_train(referee, sample_weight):
+@pytest.mark.parametrize("build_first", [True, False])
+def test_model_train(referee, sample_weight, build_first):
from pidgan.algorithms import BceGAN
if sample_weight is not None:
slices = (x, y, w)
else:
slices = (x, y)
- dataset = (
+ train_ds = (
tf.data.Dataset.from_tensor_slices(slices)
- .batch(batch_size=512, drop_remainder=True)
- .cache()
- .prefetch(tf.data.AUTOTUNE)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
+ )
+ val_ds = (
+ tf.data.Dataset.from_tensor_slices(slices)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
)
model = BceGAN(
@@ -159,28 +195,60 @@ def test_model_train(referee, sample_weight):
feature_matching_penalty=1.0,
referee=referee,
)
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=None,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ model.compile(
+ metrics=["bce"],
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if not build_first:
+ model(x, y) # to build the model
+
+ train = model.fit(
+ train_ds.take(BATCH_SIZE),
+ epochs=2,
+ validation_data=val_ds.take(BATCH_SIZE),
)
- model.fit(dataset, epochs=2)
+ states = train.history.keys()
+ if not build_first:
+ if referee is not None:
+ assert len(states) == 8 # 2x (g_loss + d_loss + r_loss + bce)
+ else:
+ assert len(states) == 6 # 2x (g_loss + d_loss + bce)
+ else:
+ if referee is not None:
+ assert len(states) == 6 # 2x (g_loss + d_loss + r_loss)
+ else:
+ assert len(states) == 4 # 2x (g_loss + d_loss)
+
+ for s in states:
+ for entry in train.history[s]:
+ print(train.history)
+ print(f"{s}: {entry}")
+ assert isinstance(entry, (int, float))
+
+@pytest.mark.parametrize("metrics", [["bce"], [BCE()], None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_eval(model, sample_weight):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
+def test_model_eval(model, metrics, sample_weight):
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
model.compile(
- metrics=None,
+ metrics=metrics,
generator_optimizer=g_opt,
discriminator_optimizer=d_opt,
generator_upds_per_batch=1,
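
A condensed, hedged version of the training flow these tests exercise. The BceGAN keyword names for the players (`generator=...`, `discriminator=...`) are assumed from the fixtures, which this diff only shows in part; the rest follows the calls visible above:

```python
import keras as k
import tensorflow as tf
from pidgan.algorithms import BceGAN
from pidgan.players.generators import ResGenerator
from pidgan.players.discriminators import AuxDiscriminator

x = tf.random.normal(shape=(1000, 4))
y = tf.random.normal(shape=(1000, 8))
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(100)

gen = ResGenerator(output_dim=8, latent_dim=16)
disc = AuxDiscriminator(output_dim=1, aux_features=["0 + 1", "2 - 3"])

gan = BceGAN(generator=gen, discriminator=disc)   # keyword names assumed
gan.compile(
    metrics=["bce"],                              # compiling before building keeps the metric
    generator_optimizer=k.optimizers.RMSprop(learning_rate=0.001),
    discriminator_optimizer=k.optimizers.RMSprop(learning_rate=0.001),
)
gan(x, y)                                         # build the players, as in the tests above
history = gan.fit(ds, epochs=1)
print(history.history.keys())                     # e.g. g_loss, d_loss, bce
```
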
diff --git a/tests/algorithms/test_BceGAN_ALP.py b/tests/algorithms/test_BceGAN_ALP.py
index d8ad527..c81ca80 100644
--- a/tests/algorithms/test_BceGAN_ALP.py
+++ b/tests/algorithms/test_BceGAN_ALP.py
@@ -1,12 +1,18 @@
+import os
import pytest
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
from pidgan.players.classifiers import AuxClassifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
+from pidgan.metrics import BinaryCrossentropy as BCE
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
x = tf.random.normal(shape=(CHUNK_SIZE, 4))
y = tf.random.normal(shape=(CHUNK_SIZE, 8))
@@ -83,71 +89,101 @@ def test_model_use(referee):
feature_matching_penalty=0.0,
referee=referee,
)
- outputs = model(x, y)
+ out = model(x, y)
if referee is not None:
- g_output, d_outputs, r_outputs = outputs
+ g_out, d_out, r_out = out
else:
- g_output, d_outputs = outputs
+ g_out, d_out = out
model.summary()
test_g_shape = [y.shape[0]]
test_g_shape.append(model.generator.output_dim)
- assert g_output.shape == tuple(test_g_shape)
+ assert g_out.shape == tuple(test_g_shape)
test_d_shape = [y.shape[0]]
test_d_shape.append(model.discriminator.output_dim)
- d_output_gen, d_output_ref = d_outputs
- assert d_output_gen.shape == tuple(test_d_shape)
- assert d_output_ref.shape == tuple(test_d_shape)
+ d_out_gen, d_out_ref = d_out
+ assert d_out_gen.shape == tuple(test_d_shape)
+ assert d_out_ref.shape == tuple(test_d_shape)
if referee is not None:
test_r_shape = [y.shape[0]]
test_r_shape.append(model.referee.output_dim)
- r_output_gen, r_output_ref = r_outputs
- assert r_output_gen.shape == tuple(test_r_shape)
- assert r_output_ref.shape == tuple(test_r_shape)
+ r_out_gen, r_out_ref = r_out
+ assert r_out_gen.shape == tuple(test_r_shape)
+ assert r_out_ref.shape == tuple(test_r_shape)
-@pytest.mark.parametrize("metrics", [["bce"], None])
-def test_model_compilation(model, metrics):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=metrics,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- virtual_adv_direction_upds=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
- )
+@pytest.mark.parametrize("build_first", [True, False])
+@pytest.mark.parametrize("metrics", [["bce"], [BCE()], None])
+def test_model_compilation(model, build_first, metrics):
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings(record=True) as w:
+ if build_first:
+ warnings.simplefilter("always")
+ else:
+ warnings.simplefilter("ignore")
+
+ model.compile(
+ metrics=metrics,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ virtual_adv_direction_upds=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if build_first and metrics is not None:
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "`compile()`" in str(w[-1].message)
+
assert isinstance(model.metrics, list)
- assert isinstance(model.generator_optimizer, keras.optimizers.Optimizer)
- assert isinstance(model.discriminator_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.generator_optimizer, k.optimizers.Optimizer)
+ assert isinstance(model.discriminator_optimizer, k.optimizers.Optimizer)
assert isinstance(model.generator_upds_per_batch, int)
assert isinstance(model.discriminator_upds_per_batch, int)
assert isinstance(model.virtual_adv_direction_upds, int)
- assert isinstance(model.referee_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.referee_optimizer, k.optimizers.Optimizer)
assert isinstance(model.referee_upds_per_batch, int)
+ if not build_first:
+ model(x, y) # to build the model
+ if metrics is None:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+ else:
+ assert len(model.metrics) == 4 # losses + bce
+ else:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+
@pytest.mark.parametrize("referee", [ref, None])
@pytest.mark.parametrize("sample_weight", [w, None])
@pytest.mark.parametrize("lipschitz_penalty_strategy", ["two-sided", "one-sided"])
-def test_model_train(referee, sample_weight, lipschitz_penalty_strategy):
+@pytest.mark.parametrize("build_first", [True, False])
+def test_model_train(referee, sample_weight, lipschitz_penalty_strategy, build_first):
from pidgan.algorithms import BceGAN_ALP
if sample_weight is not None:
slices = (x, y, w)
else:
slices = (x, y)
- dataset = (
+ train_ds = (
tf.data.Dataset.from_tensor_slices(slices)
- .batch(batch_size=512, drop_remainder=True)
- .cache()
- .prefetch(tf.data.AUTOTUNE)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
+ )
+ val_ds = (
+ tf.data.Dataset.from_tensor_slices(slices)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
)
model = BceGAN_ALP(
@@ -158,29 +194,61 @@ def test_model_train(referee, sample_weight, lipschitz_penalty_strategy):
feature_matching_penalty=1.0,
referee=referee,
)
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=None,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- virtual_adv_direction_upds=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ model.compile(
+ metrics=["bce"],
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ virtual_adv_direction_upds=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if not build_first:
+ model(x, y) # to build the model
+
+ train = model.fit(
+ train_ds.take(BATCH_SIZE),
+ epochs=2,
+ validation_data=val_ds.take(BATCH_SIZE),
)
- model.fit(dataset, epochs=2)
+ states = train.history.keys()
+ if not build_first:
+ if referee is not None:
+ assert len(states) == 8 # 2x (g_loss + d_loss + r_loss + bce)
+ else:
+ assert len(states) == 6 # 2x (g_loss + d_loss + bce)
+ else:
+ if referee is not None:
+ assert len(states) == 6 # 2x (g_loss + d_loss + r_loss)
+ else:
+ assert len(states) == 4 # 2x (g_loss + d_loss)
+
+ for s in states:
+ for entry in train.history[s]:
+ print(train.history)
+ print(f"{s}: {entry}")
+ assert isinstance(entry, (int, float))
+
+@pytest.mark.parametrize("metrics", [["bce"], [BCE()], None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_eval(model, sample_weight):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
+def test_model_eval(model, metrics, sample_weight):
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
model.compile(
- metrics=None,
+ metrics=metrics,
generator_optimizer=g_opt,
discriminator_optimizer=d_opt,
generator_upds_per_batch=1,
diff --git a/tests/algorithms/test_BceGAN_GP.py b/tests/algorithms/test_BceGAN_GP.py
index a0ee670..aab7367 100644
--- a/tests/algorithms/test_BceGAN_GP.py
+++ b/tests/algorithms/test_BceGAN_GP.py
@@ -1,12 +1,18 @@
+import os
import pytest
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
from pidgan.players.classifiers import AuxClassifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
+from pidgan.metrics import BinaryCrossentropy as BCE
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
x = tf.random.normal(shape=(CHUNK_SIZE, 4))
y = tf.random.normal(shape=(CHUNK_SIZE, 8))
@@ -83,69 +89,99 @@ def test_model_use(referee):
feature_matching_penalty=0.0,
referee=referee,
)
- outputs = model(x, y)
+ out = model(x, y)
if referee is not None:
- g_output, d_outputs, r_outputs = outputs
+ g_out, d_out, r_out = out
else:
- g_output, d_outputs = outputs
+ g_out, d_out = out
model.summary()
test_g_shape = [y.shape[0]]
test_g_shape.append(model.generator.output_dim)
- assert g_output.shape == tuple(test_g_shape)
+ assert g_out.shape == tuple(test_g_shape)
test_d_shape = [y.shape[0]]
test_d_shape.append(model.discriminator.output_dim)
- d_output_gen, d_output_ref = d_outputs
- assert d_output_gen.shape == tuple(test_d_shape)
- assert d_output_ref.shape == tuple(test_d_shape)
+ d_out_gen, d_out_ref = d_out
+ assert d_out_gen.shape == tuple(test_d_shape)
+ assert d_out_ref.shape == tuple(test_d_shape)
if referee is not None:
test_r_shape = [y.shape[0]]
test_r_shape.append(model.referee.output_dim)
- r_output_gen, r_output_ref = r_outputs
- assert r_output_gen.shape == tuple(test_r_shape)
- assert r_output_ref.shape == tuple(test_r_shape)
+ r_out_gen, r_out_ref = r_out
+ assert r_out_gen.shape == tuple(test_r_shape)
+ assert r_out_ref.shape == tuple(test_r_shape)
-@pytest.mark.parametrize("metrics", [["bce"], None])
-def test_model_compilation(model, metrics):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=metrics,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
- )
+@pytest.mark.parametrize("build_first", [True, False])
+@pytest.mark.parametrize("metrics", [["bce"], [BCE()], None])
+def test_model_compilation(model, build_first, metrics):
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings(record=True) as w:
+ if build_first:
+ warnings.simplefilter("always")
+ else:
+ warnings.simplefilter("ignore")
+
+ model.compile(
+ metrics=metrics,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if build_first and metrics is not None:
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "`compile()`" in str(w[-1].message)
+
assert isinstance(model.metrics, list)
- assert isinstance(model.generator_optimizer, keras.optimizers.Optimizer)
- assert isinstance(model.discriminator_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.generator_optimizer, k.optimizers.Optimizer)
+ assert isinstance(model.discriminator_optimizer, k.optimizers.Optimizer)
assert isinstance(model.generator_upds_per_batch, int)
assert isinstance(model.discriminator_upds_per_batch, int)
- assert isinstance(model.referee_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.referee_optimizer, k.optimizers.Optimizer)
assert isinstance(model.referee_upds_per_batch, int)
+ if not build_first:
+ model(x, y) # to build the model
+ if metrics is None:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+ else:
+ assert len(model.metrics) == 4 # losses + bce
+ else:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+
@pytest.mark.parametrize("referee", [ref, None])
@pytest.mark.parametrize("sample_weight", [w, None])
@pytest.mark.parametrize("lipschitz_penalty_strategy", ["two-sided", "one-sided"])
-def test_model_train(referee, sample_weight, lipschitz_penalty_strategy):
+@pytest.mark.parametrize("build_first", [True, False])
+def test_model_train(referee, sample_weight, lipschitz_penalty_strategy, build_first):
from pidgan.algorithms import BceGAN_GP
if sample_weight is not None:
slices = (x, y, w)
else:
slices = (x, y)
- dataset = (
+ train_ds = (
tf.data.Dataset.from_tensor_slices(slices)
- .batch(batch_size=512, drop_remainder=True)
- .cache()
- .prefetch(tf.data.AUTOTUNE)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
+ )
+ val_ds = (
+ tf.data.Dataset.from_tensor_slices(slices)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
)
model = BceGAN_GP(
@@ -156,28 +192,60 @@ def test_model_train(referee, sample_weight, lipschitz_penalty_strategy):
feature_matching_penalty=1.0,
referee=referee,
)
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=None,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ model.compile(
+ metrics=["bce"],
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if not build_first:
+ model(x, y) # to build the model
+
+ train = model.fit(
+ train_ds.take(BATCH_SIZE),
+ epochs=2,
+ validation_data=val_ds.take(BATCH_SIZE),
)
- model.fit(dataset, epochs=2)
+ states = train.history.keys()
+ if not build_first:
+ if referee is not None:
+ assert len(states) == 8 # 2x (g_loss + d_loss + r_loss + bce)
+ else:
+ assert len(states) == 6 # 2x (g_loss + d_loss + bce)
+ else:
+ if referee is not None:
+ assert len(states) == 6 # 2x (g_loss + d_loss + r_loss)
+ else:
+ assert len(states) == 4 # 2x (g_loss + d_loss)
+
+ for s in states:
+ for entry in train.history[s]:
+ print(train.history)
+ print(f"{s}: {entry}")
+ assert isinstance(entry, (int, float))
+
+@pytest.mark.parametrize("metrics", [["bce"], [BCE()], None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_eval(model, sample_weight):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
+def test_model_eval(model, metrics, sample_weight):
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
model.compile(
- metrics=None,
+ metrics=metrics,
generator_optimizer=g_opt,
discriminator_optimizer=d_opt,
generator_upds_per_batch=1,
diff --git a/tests/algorithms/test_CramerGAN.py b/tests/algorithms/test_CramerGAN.py
index ba3f19f..3665a69 100644
--- a/tests/algorithms/test_CramerGAN.py
+++ b/tests/algorithms/test_CramerGAN.py
@@ -1,12 +1,18 @@
+import os
import pytest
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
from pidgan.players.classifiers import AuxClassifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
+from pidgan.metrics import WassersteinDistance as Wass_dist
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
x = tf.random.normal(shape=(CHUNK_SIZE, 4))
y = tf.random.normal(shape=(CHUNK_SIZE, 8))
@@ -22,7 +28,7 @@
)
disc = AuxDiscriminator(
- output_dim=1,
+ output_dim=32, # > 1 for CramerGAN system
aux_features=["0 + 1", "2 - 3"],
num_hidden_layers=4,
mlp_hidden_units=32,
@@ -84,11 +90,11 @@ def test_model_use(referee):
feature_matching_penalty=0.0,
referee=referee,
)
- outputs = model(x, y)
+ out = model(x, y)
if referee is not None:
- g_output, d_outputs, r_outputs = outputs
+ g_output, d_out, r_out = out
else:
- g_output, d_outputs = outputs
+ g_output, d_out = out
model.summary()
test_g_shape = [y.shape[0]]
@@ -97,56 +103,86 @@ def test_model_use(referee):
test_d_shape = [y.shape[0]]
test_d_shape.append(model.discriminator.output_dim)
- d_output_gen, d_output_ref = d_outputs
- assert d_output_gen.shape == tuple(test_d_shape)
- assert d_output_ref.shape == tuple(test_d_shape)
+ d_out_gen, d_out_ref = d_out
+ assert d_out_gen.shape == tuple(test_d_shape)
+ assert d_out_ref.shape == tuple(test_d_shape)
if referee is not None:
test_r_shape = [y.shape[0]]
test_r_shape.append(model.referee.output_dim)
- r_output_gen, r_output_ref = r_outputs
- assert r_output_gen.shape == tuple(test_r_shape)
- assert r_output_ref.shape == tuple(test_r_shape)
+ r_out_gen, r_out_ref = r_out
+ assert r_out_gen.shape == tuple(test_r_shape)
+ assert r_out_ref.shape == tuple(test_r_shape)
-@pytest.mark.parametrize("metrics", [["bce"], None])
-def test_model_compilation(model, metrics):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=metrics,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
- )
+@pytest.mark.parametrize("build_first", [True, False])
+@pytest.mark.parametrize("metrics", [["bce"], [Wass_dist()], None])
+def test_model_compilation(model, build_first, metrics):
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings(record=True) as w:
+ if build_first:
+ warnings.simplefilter("always")
+ else:
+ warnings.simplefilter("ignore")
+
+ model.compile(
+ metrics=metrics,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if build_first and metrics is not None:
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "`compile()`" in str(w[-1].message)
+
assert isinstance(model.metrics, list)
- assert isinstance(model.generator_optimizer, keras.optimizers.Optimizer)
- assert isinstance(model.discriminator_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.generator_optimizer, k.optimizers.Optimizer)
+ assert isinstance(model.discriminator_optimizer, k.optimizers.Optimizer)
assert isinstance(model.generator_upds_per_batch, int)
assert isinstance(model.discriminator_upds_per_batch, int)
- assert isinstance(model.referee_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.referee_optimizer, k.optimizers.Optimizer)
assert isinstance(model.referee_upds_per_batch, int)
+ if not build_first:
+ model(x, y) # to build the model
+ if metrics is None:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+ else:
+ assert len(model.metrics) == 4 # losses + bce
+ else:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+
@pytest.mark.parametrize("referee", [ref, None])
@pytest.mark.parametrize("sample_weight", [w, None])
@pytest.mark.parametrize("lipschitz_penalty_strategy", ["two-sided", "one-sided"])
-def test_model_train(referee, sample_weight, lipschitz_penalty_strategy):
+@pytest.mark.parametrize("build_first", [True, False])
+def test_model_train(referee, sample_weight, lipschitz_penalty_strategy, build_first):
from pidgan.algorithms import CramerGAN
if sample_weight is not None:
slices = (x, y, w)
else:
slices = (x, y)
- dataset = (
+ train_ds = (
tf.data.Dataset.from_tensor_slices(slices)
- .batch(batch_size=512, drop_remainder=True)
- .cache()
- .prefetch(tf.data.AUTOTUNE)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
+ )
+ val_ds = (
+ tf.data.Dataset.from_tensor_slices(slices)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
)
model = CramerGAN(
@@ -156,28 +192,60 @@ def test_model_train(referee, sample_weight, lipschitz_penalty_strategy):
lipschitz_penalty_strategy=lipschitz_penalty_strategy,
referee=referee,
)
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=None,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ model.compile(
+ metrics=["wass_dist"],
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if not build_first:
+ model(x, y) # to build the model
+
+ train = model.fit(
+ train_ds.take(BATCH_SIZE),
+ epochs=2,
+ validation_data=val_ds.take(BATCH_SIZE),
)
- model.fit(dataset, epochs=2)
+ states = train.history.keys()
+ if not build_first:
+ if referee is not None:
+ assert len(states) == 8 # 2x (g_loss + d_loss + r_loss + wass_dist)
+ else:
+ assert len(states) == 6 # 2x (g_loss + d_loss + wass_dist)
+ else:
+ if referee is not None:
+ assert len(states) == 6 # 2x (g_loss + d_loss + r_loss)
+ else:
+ assert len(states) == 4 # 2x (g_loss + d_loss)
+
+ for s in states:
+ for entry in train.history[s]:
+ print(train.history)
+ print(f"{s}: {entry}")
+ assert isinstance(entry, (int, float))
+
+@pytest.mark.parametrize("metrics", [["wass_dist"], [Wass_dist()], None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_eval(model, sample_weight):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
+def test_model_eval(model, metrics, sample_weight):
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
model.compile(
- metrics=None,
+ metrics=metrics,
generator_optimizer=g_opt,
discriminator_optimizer=d_opt,
generator_upds_per_batch=1,
diff --git a/tests/algorithms/test_GAN.py b/tests/algorithms/test_GAN.py
index 7fcde4c..b98bee8 100644
--- a/tests/algorithms/test_GAN.py
+++ b/tests/algorithms/test_GAN.py
@@ -1,12 +1,18 @@
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before keras is imported
import pytest
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
from pidgan.players.classifiers import AuxClassifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
+from pidgan.metrics import BinaryCrossentropy as BCE
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
x = tf.random.normal(shape=(CHUNK_SIZE, 4))
y = tf.random.normal(shape=(CHUNK_SIZE, 8))
@@ -83,69 +89,99 @@ def test_model_use(referee):
feature_matching_penalty=0.0,
referee=referee,
)
- outputs = model(x, y)
+ out = model(x, y)
if referee is not None:
- g_output, d_outputs, r_outputs = outputs
+ g_out, d_out, r_out = out
else:
- g_output, d_outputs = outputs
+ g_out, d_out = out
model.summary()
test_g_shape = [y.shape[0]]
test_g_shape.append(model.generator.output_dim)
- assert g_output.shape == tuple(test_g_shape)
+ assert g_out.shape == tuple(test_g_shape)
test_d_shape = [y.shape[0]]
test_d_shape.append(model.discriminator.output_dim)
- d_output_gen, d_output_ref = d_outputs
- assert d_output_gen.shape == tuple(test_d_shape)
- assert d_output_ref.shape == tuple(test_d_shape)
+ d_out_gen, d_out_ref = d_out
+ assert d_out_gen.shape == tuple(test_d_shape)
+ assert d_out_ref.shape == tuple(test_d_shape)
if referee is not None:
test_r_shape = [y.shape[0]]
test_r_shape.append(model.referee.output_dim)
- r_output_gen, r_output_ref = r_outputs
- assert r_output_gen.shape == tuple(test_r_shape)
- assert r_output_ref.shape == tuple(test_r_shape)
+ r_out_gen, r_out_ref = r_out
+ assert r_out_gen.shape == tuple(test_r_shape)
+ assert r_out_ref.shape == tuple(test_r_shape)
-@pytest.mark.parametrize("metrics", [["bce"], None])
-def test_model_compilation(model, metrics):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=metrics,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
- )
+@pytest.mark.parametrize("build_first", [True, False])
+@pytest.mark.parametrize("metrics", [["bce"], [BCE()], None])
+def test_model_compilation(model, build_first, metrics):
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
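+ # when the model is already built, compile() with metrics should emit a UserWarning (asserted below)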
+ with warnings.catch_warnings(record=True) as w:
+ if build_first:
+ warnings.simplefilter("always")
+ else:
+ warnings.simplefilter("ignore")
+
+ model.compile(
+ metrics=metrics,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if build_first and metrics is not None:
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "`compile()`" in str(w[-1].message)
+
assert isinstance(model.metrics, list)
- assert isinstance(model.generator_optimizer, keras.optimizers.Optimizer)
- assert isinstance(model.discriminator_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.generator_optimizer, k.optimizers.Optimizer)
+ assert isinstance(model.discriminator_optimizer, k.optimizers.Optimizer)
assert isinstance(model.generator_upds_per_batch, int)
assert isinstance(model.discriminator_upds_per_batch, int)
- assert isinstance(model.referee_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.referee_optimizer, k.optimizers.Optimizer)
assert isinstance(model.referee_upds_per_batch, int)
+ if not build_first:
+ model(x, y) # to build the model
+ if metrics is None:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+ else:
+ assert len(model.metrics) == 4 # losses + bce
+ else:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+
@pytest.mark.parametrize("referee", [ref, None])
@pytest.mark.parametrize("sample_weight", [w, None])
@pytest.mark.parametrize("use_original_loss", [True, False])
-def test_model_train(referee, sample_weight, use_original_loss):
+@pytest.mark.parametrize("build_first", [True, False])
+def test_model_train(referee, sample_weight, use_original_loss, build_first):
from pidgan.algorithms import GAN
if sample_weight is not None:
slices = (x, y, w)
else:
slices = (x, y)
- dataset = (
+ train_ds = (
tf.data.Dataset.from_tensor_slices(slices)
- .batch(batch_size=512, drop_remainder=True)
- .cache()
- .prefetch(tf.data.AUTOTUNE)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
+ )
+ val_ds = (
+ tf.data.Dataset.from_tensor_slices(slices)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
)
model = GAN(
@@ -156,28 +192,60 @@ def test_model_train(referee, sample_weight, use_original_loss):
feature_matching_penalty=1.0,
referee=referee,
)
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=None,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ model.compile(
+ metrics=["bce"],
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if not build_first:
+ model(x, y) # to build the model
+
+ train = model.fit(
+ train_ds.take(BATCH_SIZE),
+ epochs=2,
+ validation_data=val_ds.take(BATCH_SIZE),
)
- model.fit(dataset, epochs=2)
+ states = train.history.keys()
+ if not build_first:
+ if referee is not None:
+ assert len(states) == 8 # 2x (g_loss + d_loss + r_loss + bce)
+ else:
+ assert len(states) == 6 # 2x (g_loss + d_loss + bce)
+ else:
+ if referee is not None:
+ assert len(states) == 6 # 2x (g_loss + d_loss + r_loss)
+ else:
+ assert len(states) == 4 # 2x (g_loss + d_loss)
+
+ for s in states:
+ for entry in train.history[s]:
+ print(train.history)
+ print(f"{s}: {entry}")
+ assert isinstance(entry, (int, float))
+
+@pytest.mark.parametrize("metrics", [["bce"], [BCE()], None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_eval(model, sample_weight):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
+def test_model_eval(model, metrics, sample_weight):
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
model.compile(
- metrics=None,
+ metrics=metrics,
generator_optimizer=g_opt,
discriminator_optimizer=d_opt,
generator_upds_per_batch=1,
diff --git a/tests/algorithms/test_LSGAN.py b/tests/algorithms/test_LSGAN.py
index 061e096..631b21e 100644
--- a/tests/algorithms/test_LSGAN.py
+++ b/tests/algorithms/test_LSGAN.py
@@ -1,12 +1,18 @@
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
import pytest
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
from pidgan.players.classifiers import AuxClassifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
+from pidgan.metrics import BinaryCrossentropy as BCE
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
x = tf.random.normal(shape=(CHUNK_SIZE, 4))
y = tf.random.normal(shape=(CHUNK_SIZE, 8))
@@ -83,69 +89,99 @@ def test_model_use(referee):
feature_matching_penalty=0.0,
referee=referee,
)
- outputs = model(x, y)
+ out = model(x, y)
if referee is not None:
- g_output, d_outputs, r_outputs = outputs
+ g_out, d_out, r_out = out
else:
- g_output, d_outputs = outputs
+ g_out, d_out = out
model.summary()
test_g_shape = [y.shape[0]]
test_g_shape.append(model.generator.output_dim)
- assert g_output.shape == tuple(test_g_shape)
+ assert g_out.shape == tuple(test_g_shape)
test_d_shape = [y.shape[0]]
test_d_shape.append(model.discriminator.output_dim)
- d_output_gen, d_output_ref = d_outputs
- assert d_output_gen.shape == tuple(test_d_shape)
- assert d_output_ref.shape == tuple(test_d_shape)
+ d_out_gen, d_out_ref = d_out
+ assert d_out_gen.shape == tuple(test_d_shape)
+ assert d_out_ref.shape == tuple(test_d_shape)
if referee is not None:
test_r_shape = [y.shape[0]]
test_r_shape.append(model.referee.output_dim)
- r_output_gen, r_output_ref = r_outputs
- assert r_output_gen.shape == tuple(test_r_shape)
- assert r_output_ref.shape == tuple(test_r_shape)
+ r_out_gen, r_out_ref = r_out
+ assert r_out_gen.shape == tuple(test_r_shape)
+ assert r_out_ref.shape == tuple(test_r_shape)
-@pytest.mark.parametrize("metrics", [["bce"], None])
-def test_model_compilation(model, metrics):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=metrics,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
- )
+@pytest.mark.parametrize("build_first", [True, False])
+@pytest.mark.parametrize("metrics", [["bce"], [BCE()], None])
+def test_model_compilation(model, build_first, metrics):
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings(record=True) as w:
+ if build_first:
+ warnings.simplefilter("always")
+ else:
+ warnings.simplefilter("ignore")
+
+ model.compile(
+ metrics=metrics,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if build_first and metrics is not None:
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "`compile()`" in str(w[-1].message)
+
assert isinstance(model.metrics, list)
- assert isinstance(model.generator_optimizer, keras.optimizers.Optimizer)
- assert isinstance(model.discriminator_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.generator_optimizer, k.optimizers.Optimizer)
+ assert isinstance(model.discriminator_optimizer, k.optimizers.Optimizer)
assert isinstance(model.generator_upds_per_batch, int)
assert isinstance(model.discriminator_upds_per_batch, int)
- assert isinstance(model.referee_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.referee_optimizer, k.optimizers.Optimizer)
assert isinstance(model.referee_upds_per_batch, int)
+ if not build_first:
+ model(x, y) # to build the model
+ if metrics is None:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+ else:
+ assert len(model.metrics) == 4 # losses + bce
+ else:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+
@pytest.mark.parametrize("referee", [ref, None])
@pytest.mark.parametrize("sample_weight", [w, None])
@pytest.mark.parametrize("minimize_pearson_chi2", [True, False])
-def test_model_train(referee, sample_weight, minimize_pearson_chi2):
+@pytest.mark.parametrize("build_first", [True, False])
+def test_model_train(referee, sample_weight, minimize_pearson_chi2, build_first):
from pidgan.algorithms import LSGAN
if sample_weight is not None:
slices = (x, y, w)
else:
slices = (x, y)
- dataset = (
+ train_ds = (
tf.data.Dataset.from_tensor_slices(slices)
- .batch(batch_size=512, drop_remainder=True)
- .cache()
- .prefetch(tf.data.AUTOTUNE)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
+ )
+ val_ds = (
+ tf.data.Dataset.from_tensor_slices(slices)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
)
model = LSGAN(
@@ -156,28 +192,60 @@ def test_model_train(referee, sample_weight, minimize_pearson_chi2):
feature_matching_penalty=1.0,
referee=referee,
)
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=None,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ model.compile(
+ metrics=["bce"],
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if not build_first:
+ model(x, y) # to build the model
+
+ train = model.fit(
+ train_ds.take(BATCH_SIZE),
+ epochs=2,
+ validation_data=val_ds.take(BATCH_SIZE),
)
- model.fit(dataset, epochs=2)
+ states = train.history.keys()
+ if not build_first:
+ if referee is not None:
+ assert len(states) == 8 # 2x (g_loss + d_loss + r_loss + bce)
+ else:
+ assert len(states) == 6 # 2x (g_loss + d_loss + bce)
+ else:
+ if referee is not None:
+ assert len(states) == 6 # 2x (g_loss + d_loss + r_loss)
+ else:
+ assert len(states) == 4 # 2x (g_loss + d_loss)
+
+ for s in states:
+ for entry in train.history[s]:
+ print(train.history)
+ print(f"{s}: {entry}")
+ assert isinstance(entry, (int, float))
+
+@pytest.mark.parametrize("metrics", [["bce"], [BCE()], None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_eval(model, sample_weight):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
+def test_model_eval(model, metrics, sample_weight):
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
model.compile(
- metrics=None,
+ metrics=metrics,
generator_optimizer=g_opt,
discriminator_optimizer=d_opt,
generator_upds_per_batch=1,
diff --git a/tests/algorithms/test_WGAN.py b/tests/algorithms/test_WGAN.py
index d83a03f..e8e8f70 100644
--- a/tests/algorithms/test_WGAN.py
+++ b/tests/algorithms/test_WGAN.py
@@ -1,12 +1,18 @@
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
import pytest
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
from pidgan.players.classifiers import AuxClassifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
+from pidgan.metrics import WassersteinDistance as Wass_dist
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
x = tf.random.normal(shape=(CHUNK_SIZE, 4))
y = tf.random.normal(shape=(CHUNK_SIZE, 8))
@@ -80,68 +86,98 @@ def test_model_use(referee):
feature_matching_penalty=0.0,
referee=referee,
)
- outputs = model(x, y)
+ out = model(x, y)
if referee is not None:
- g_output, d_outputs, r_outputs = outputs
+ g_out, d_out, r_out = out
else:
- g_output, d_outputs = outputs
+ g_out, d_out = out
model.summary()
test_g_shape = [y.shape[0]]
test_g_shape.append(model.generator.output_dim)
- assert g_output.shape == tuple(test_g_shape)
+ assert g_out.shape == tuple(test_g_shape)
test_d_shape = [y.shape[0]]
test_d_shape.append(model.discriminator.output_dim)
- d_output_gen, d_output_ref = d_outputs
- assert d_output_gen.shape == tuple(test_d_shape)
- assert d_output_ref.shape == tuple(test_d_shape)
+ d_out_gen, d_out_ref = d_out
+ assert d_out_gen.shape == tuple(test_d_shape)
+ assert d_out_ref.shape == tuple(test_d_shape)
if referee is not None:
test_r_shape = [y.shape[0]]
test_r_shape.append(model.referee.output_dim)
- r_output_gen, r_output_ref = r_outputs
- assert r_output_gen.shape == tuple(test_r_shape)
- assert r_output_ref.shape == tuple(test_r_shape)
-
+ r_out_gen, r_out_ref = r_out
+ assert r_out_gen.shape == tuple(test_r_shape)
+ assert r_out_ref.shape == tuple(test_r_shape)
+
+
+@pytest.mark.parametrize("build_first", [True, False])
+@pytest.mark.parametrize("metrics", [["wass_dist"], [Wass_dist()], None])
+def test_model_compilation(model, build_first, metrics):
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings(record=True) as w:
+ if build_first:
+ warnings.simplefilter("always")
+ else:
+ warnings.simplefilter("ignore")
+
+ model.compile(
+ metrics=metrics,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if build_first and metrics is not None:
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "`compile()`" in str(w[-1].message)
-@pytest.mark.parametrize("metrics", [["bce"], None])
-def test_model_compilation(model, metrics):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=metrics,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
- )
assert isinstance(model.metrics, list)
- assert isinstance(model.generator_optimizer, keras.optimizers.Optimizer)
- assert isinstance(model.discriminator_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.generator_optimizer, k.optimizers.Optimizer)
+ assert isinstance(model.discriminator_optimizer, k.optimizers.Optimizer)
assert isinstance(model.generator_upds_per_batch, int)
assert isinstance(model.discriminator_upds_per_batch, int)
- assert isinstance(model.referee_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.referee_optimizer, k.optimizers.Optimizer)
assert isinstance(model.referee_upds_per_batch, int)
+ if not build_first:
+ model(x, y) # to build the model
+ if metrics is None:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+ else:
+ assert len(model.metrics) == 4 # losses + wass_dist
+ else:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+
@pytest.mark.parametrize("referee", [ref, None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_train(referee, sample_weight):
+@pytest.mark.parametrize("build_first", [True, False])
+def test_model_train(referee, sample_weight, build_first):
from pidgan.algorithms import WGAN
if sample_weight is not None:
slices = (x, y, w)
else:
slices = (x, y)
- dataset = (
+ train_ds = (
tf.data.Dataset.from_tensor_slices(slices)
- .batch(batch_size=512, drop_remainder=True)
- .cache()
- .prefetch(tf.data.AUTOTUNE)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
+ )
+ val_ds = (
+ tf.data.Dataset.from_tensor_slices(slices)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
)
model = WGAN(
@@ -151,28 +187,60 @@ def test_model_train(referee, sample_weight):
feature_matching_penalty=1.0,
referee=referee,
)
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=None,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ model.compile(
+ metrics=["wass_dist"],
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if not build_first:
+ model(x, y) # to build the model
+
+ train = model.fit(
+ train_ds.take(BATCH_SIZE),
+ epochs=2,
+ validation_data=val_ds.take(BATCH_SIZE),
)
- model.fit(dataset, epochs=2)
+ states = train.history.keys()
+
+ if not build_first:
+ if referee is not None:
+ assert len(states) == 8 # 2x (g_loss + d_loss + r_loss + wass_dist)
+ else:
+ assert len(states) == 6 # 2x (g_loss + d_loss + wass_dist)
+ else:
+ if referee is not None:
+ assert len(states) == 6 # 2x (g_loss + d_loss + r_loss)
+ else:
+ assert len(states) == 4 # 2x (g_loss + d_loss)
+
+ for s in states:
+ for entry in train.history[s]:
+ print(train.history)
+ print(f"{s}: {entry}")
+ assert isinstance(entry, (int, float))
+
+@pytest.mark.parametrize("metrics", [["wass_dist"], [Wass_dist()], None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_eval(model, sample_weight):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
+def test_model_eval(model, metrics, sample_weight):
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
model.compile(
- metrics=None,
+ metrics=metrics,
generator_optimizer=g_opt,
discriminator_optimizer=d_opt,
generator_upds_per_batch=1,
diff --git a/tests/algorithms/test_WGAN_ALP.py b/tests/algorithms/test_WGAN_ALP.py
index 944bcc4..07811ba 100644
--- a/tests/algorithms/test_WGAN_ALP.py
+++ b/tests/algorithms/test_WGAN_ALP.py
@@ -1,12 +1,18 @@
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
import pytest
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
from pidgan.players.classifiers import AuxClassifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
+from pidgan.metrics import WassersteinDistance as Wass_dist
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
x = tf.random.normal(shape=(CHUNK_SIZE, 4))
y = tf.random.normal(shape=(CHUNK_SIZE, 8))
@@ -84,71 +90,101 @@ def test_model_use(referee):
feature_matching_penalty=0.0,
referee=referee,
)
- outputs = model(x, y)
+ out = model(x, y)
if referee is not None:
- g_output, d_outputs, r_outputs = outputs
+ g_out, d_out, r_out = out
else:
- g_output, d_outputs = outputs
+ g_out, d_out = out
model.summary()
test_g_shape = [y.shape[0]]
test_g_shape.append(model.generator.output_dim)
- assert g_output.shape == tuple(test_g_shape)
+ assert g_out.shape == tuple(test_g_shape)
test_d_shape = [y.shape[0]]
test_d_shape.append(model.discriminator.output_dim)
- d_output_gen, d_output_ref = d_outputs
- assert d_output_gen.shape == tuple(test_d_shape)
- assert d_output_ref.shape == tuple(test_d_shape)
+ d_out_gen, d_out_ref = d_out
+ assert d_out_gen.shape == tuple(test_d_shape)
+ assert d_out_ref.shape == tuple(test_d_shape)
if referee is not None:
test_r_shape = [y.shape[0]]
test_r_shape.append(model.referee.output_dim)
- r_output_gen, r_output_ref = r_outputs
- assert r_output_gen.shape == tuple(test_r_shape)
- assert r_output_ref.shape == tuple(test_r_shape)
+ r_out_gen, r_out_ref = r_out
+ assert r_out_gen.shape == tuple(test_r_shape)
+ assert r_out_ref.shape == tuple(test_r_shape)
-@pytest.mark.parametrize("metrics", [["bce"], None])
-def test_model_compilation(model, metrics):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=metrics,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- virtual_adv_direction_upds=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
- )
+@pytest.mark.parametrize("build_first", [True, False])
+@pytest.mark.parametrize("metrics", [["wass_dist"], [Wass_dist()], None])
+def test_model_compilation(model, build_first, metrics):
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings(record=True) as w:
+ if build_first:
+ warnings.simplefilter("always")
+ else:
+ warnings.simplefilter("ignore")
+
+ model.compile(
+ metrics=metrics,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ virtual_adv_direction_upds=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if build_first and metrics is not None:
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "`compile()`" in str(w[-1].message)
+
assert isinstance(model.metrics, list)
- assert isinstance(model.generator_optimizer, keras.optimizers.Optimizer)
- assert isinstance(model.discriminator_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.generator_optimizer, k.optimizers.Optimizer)
+ assert isinstance(model.discriminator_optimizer, k.optimizers.Optimizer)
assert isinstance(model.generator_upds_per_batch, int)
assert isinstance(model.discriminator_upds_per_batch, int)
assert isinstance(model.virtual_adv_direction_upds, int)
- assert isinstance(model.referee_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.referee_optimizer, k.optimizers.Optimizer)
assert isinstance(model.referee_upds_per_batch, int)
+ if not build_first:
+ model(x, y) # to build the model
+ if metrics is None:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+ else:
+ assert len(model.metrics) == 4 # losses + wass_dist
+ else:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+
@pytest.mark.parametrize("referee", [ref, None])
@pytest.mark.parametrize("sample_weight", [w, None])
@pytest.mark.parametrize("lipschitz_penalty_strategy", ["two-sided", "one-sided"])
-def test_model_train(referee, sample_weight, lipschitz_penalty_strategy):
+@pytest.mark.parametrize("build_first", [True, False])
+def test_model_train(referee, sample_weight, lipschitz_penalty_strategy, build_first):
from pidgan.algorithms import WGAN_ALP
if sample_weight is not None:
slices = (x, y, w)
else:
slices = (x, y)
- dataset = (
+ train_ds = (
tf.data.Dataset.from_tensor_slices(slices)
- .batch(batch_size=512, drop_remainder=True)
- .cache()
- .prefetch(tf.data.AUTOTUNE)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
+ )
+ val_ds = (
+ tf.data.Dataset.from_tensor_slices(slices)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
)
model = WGAN_ALP(
@@ -159,29 +195,61 @@ def test_model_train(referee, sample_weight, lipschitz_penalty_strategy):
feature_matching_penalty=1.0,
referee=referee,
)
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=None,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- virtual_adv_direction_upds=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ model.compile(
+ metrics=["wass_dist"],
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ virtual_adv_direction_upds=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if not build_first:
+ model(x, y) # to build the model
+
+ train = model.fit(
+ train_ds.take(BATCH_SIZE),
+ epochs=2,
+ validation_data=val_ds.take(BATCH_SIZE),
)
- model.fit(dataset, epochs=2)
+ states = train.history.keys()
+ if not build_first:
+ if referee is not None:
+ assert len(states) == 8 # 2x (g_loss + d_loss + r_loss + wass_dist)
+ else:
+ assert len(states) == 6 # 2x (g_loss + d_loss + wass_dist)
+ else:
+ if referee is not None:
+ assert len(states) == 6 # 2x (g_loss + d_loss + r_loss)
+ else:
+ assert len(states) == 4 # 2x (g_loss + d_loss)
+
+ for s in states:
+ for entry in train.history[s]:
+ print(train.history)
+ print(f"{s}: {entry}")
+ assert isinstance(entry, (int, float))
+
+@pytest.mark.parametrize("metrics", [["wass_dist"], [Wass_dist()], None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_eval(model, sample_weight):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
+def test_model_eval(model, metrics, sample_weight):
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
model.compile(
- metrics=None,
+ metrics=metrics,
generator_optimizer=g_opt,
discriminator_optimizer=d_opt,
generator_upds_per_batch=1,
diff --git a/tests/algorithms/test_WGAN_GP.py b/tests/algorithms/test_WGAN_GP.py
index 8481d85..5076ab5 100644
--- a/tests/algorithms/test_WGAN_GP.py
+++ b/tests/algorithms/test_WGAN_GP.py
@@ -1,12 +1,18 @@
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
import pytest
+import warnings
+import keras as k
import tensorflow as tf
-from tensorflow import keras
from pidgan.players.classifiers import AuxClassifier
from pidgan.players.discriminators import AuxDiscriminator
from pidgan.players.generators import ResGenerator
+from pidgan.metrics import WassersteinDistance as Wass_dist
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
x = tf.random.normal(shape=(CHUNK_SIZE, 4))
y = tf.random.normal(shape=(CHUNK_SIZE, 8))
@@ -84,69 +90,99 @@ def test_model_use(referee):
feature_matching_penalty=0.0,
referee=referee,
)
- outputs = model(x, y)
+ out = model(x, y)
if referee is not None:
- g_output, d_outputs, r_outputs = outputs
+ g_out, d_out, r_out = out
else:
- g_output, d_outputs = outputs
+ g_out, d_out = out
model.summary()
test_g_shape = [y.shape[0]]
test_g_shape.append(model.generator.output_dim)
- assert g_output.shape == tuple(test_g_shape)
+ assert g_out.shape == tuple(test_g_shape)
test_d_shape = [y.shape[0]]
test_d_shape.append(model.discriminator.output_dim)
- d_output_gen, d_output_ref = d_outputs
- assert d_output_gen.shape == tuple(test_d_shape)
- assert d_output_ref.shape == tuple(test_d_shape)
+ d_out_gen, d_out_ref = d_out
+ assert d_out_gen.shape == tuple(test_d_shape)
+ assert d_out_ref.shape == tuple(test_d_shape)
if referee is not None:
test_r_shape = [y.shape[0]]
test_r_shape.append(model.referee.output_dim)
- r_output_gen, r_output_ref = r_outputs
- assert r_output_gen.shape == tuple(test_r_shape)
- assert r_output_ref.shape == tuple(test_r_shape)
+ r_out_gen, r_out_ref = r_out
+ assert r_out_gen.shape == tuple(test_r_shape)
+ assert r_out_ref.shape == tuple(test_r_shape)
-@pytest.mark.parametrize("metrics", [["bce"], None])
-def test_model_compilation(model, metrics):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=metrics,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
- )
+@pytest.mark.parametrize("build_first", [True, False])
+@pytest.mark.parametrize("metrics", [["wass_dist"], [Wass_dist()], None])
+def test_model_compilation(model, build_first, metrics):
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings(record=True) as w:
+ if build_first:
+ warnings.simplefilter("always")
+ else:
+ warnings.simplefilter("ignore")
+
+ model.compile(
+ metrics=metrics,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if build_first and metrics is not None:
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "`compile()`" in str(w[-1].message)
+
assert isinstance(model.metrics, list)
- assert isinstance(model.generator_optimizer, keras.optimizers.Optimizer)
- assert isinstance(model.discriminator_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.generator_optimizer, k.optimizers.Optimizer)
+ assert isinstance(model.discriminator_optimizer, k.optimizers.Optimizer)
assert isinstance(model.generator_upds_per_batch, int)
assert isinstance(model.discriminator_upds_per_batch, int)
- assert isinstance(model.referee_optimizer, keras.optimizers.Optimizer)
+ assert isinstance(model.referee_optimizer, k.optimizers.Optimizer)
assert isinstance(model.referee_upds_per_batch, int)
+ if not build_first:
+ model(x, y) # to build the model
+ if metrics is None:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+ else:
+ assert len(model.metrics) == 4 # losses + wass_dist
+ else:
+ assert len(model.metrics) == 3 # g_loss, d_loss, r_loss
+
@pytest.mark.parametrize("referee", [ref, None])
@pytest.mark.parametrize("sample_weight", [w, None])
@pytest.mark.parametrize("lipschitz_penalty_strategy", ["two-sided", "one-sided"])
-def test_model_train(referee, sample_weight, lipschitz_penalty_strategy):
+@pytest.mark.parametrize("build_first", [True, False])
+def test_model_train(referee, sample_weight, lipschitz_penalty_strategy, build_first):
from pidgan.algorithms import WGAN_GP
if sample_weight is not None:
slices = (x, y, w)
else:
slices = (x, y)
- dataset = (
+ train_ds = (
tf.data.Dataset.from_tensor_slices(slices)
- .batch(batch_size=512, drop_remainder=True)
- .cache()
- .prefetch(tf.data.AUTOTUNE)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
+ )
+ val_ds = (
+ tf.data.Dataset.from_tensor_slices(slices)
+ .shuffle(buffer_size=int(0.1 * CHUNK_SIZE))
+ .batch(batch_size=BATCH_SIZE)
)
model = WGAN_GP(
@@ -157,28 +193,60 @@ def test_model_train(referee, sample_weight, lipschitz_penalty_strategy):
feature_matching_penalty=1.0,
referee=referee,
)
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- model.compile(
- metrics=None,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt,
- generator_upds_per_batch=1,
- discriminator_upds_per_batch=1,
- referee_optimizer=r_opt,
- referee_upds_per_batch=1,
+ if build_first:
+ model(x, y) # to build the model
+
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ model.compile(
+ metrics=["wass_dist"],
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt,
+ generator_upds_per_batch=1,
+ discriminator_upds_per_batch=1,
+ referee_optimizer=r_opt,
+ referee_upds_per_batch=1,
+ )
+ if not build_first:
+ model(x, y) # to build the model
+
+ train = model.fit(
+ train_ds.take(BATCH_SIZE),
+ epochs=2,
+ validation_data=val_ds.take(BATCH_SIZE),
)
- model.fit(dataset, epochs=2)
+ states = train.history.keys()
+ if not build_first:
+ if referee is not None:
+ assert len(states) == 8 # 2x (g_loss + d_loss + r_loss + wass_dist)
+ else:
+ assert len(states) == 6 # 2x (g_loss + d_loss + wass_dist)
+ else:
+ if referee is not None:
+ assert len(states) == 6 # 2x (g_loss + d_loss + r_loss)
+ else:
+ assert len(states) == 4 # 2x (g_loss + d_loss)
+
+ for s in states:
+ for entry in train.history[s]:
+ print(train.history)
+ print(f"{s}: {entry}")
+ assert isinstance(entry, (int, float))
+
+@pytest.mark.parametrize("metrics", [["wass_dist"], [Wass_dist()], None])
@pytest.mark.parametrize("sample_weight", [w, None])
-def test_model_eval(model, sample_weight):
- g_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- d_opt = keras.optimizers.RMSprop(learning_rate=0.001)
- r_opt = keras.optimizers.RMSprop(learning_rate=0.001)
+def test_model_eval(model, metrics, sample_weight):
+ g_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ d_opt = k.optimizers.RMSprop(learning_rate=0.001)
+ r_opt = k.optimizers.RMSprop(learning_rate=0.001)
model.compile(
- metrics=None,
+ metrics=metrics,
generator_optimizer=g_opt,
discriminator_optimizer=d_opt,
generator_upds_per_batch=1,
diff --git a/tests/callbacks/schedulers/test_LearnRateBaseScheduler.py b/tests/callbacks/schedulers/test_LearnRateBaseScheduler.py
index 8c26a73..4bc0be0 100644
--- a/tests/callbacks/schedulers/test_LearnRateBaseScheduler.py
+++ b/tests/callbacks/schedulers/test_LearnRateBaseScheduler.py
@@ -1,8 +1,11 @@
-import numpy as np
import pytest
-from tensorflow import keras
+import keras as k
+import numpy as np
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
+EPOCHS = 5
+LEARN_RATE = 0.001
X = np.c_[
np.random.uniform(-1, 1, size=CHUNK_SIZE),
@@ -11,22 +14,21 @@
]
Y = np.tanh(X[:, 0]) + 2 * X[:, 1] * X[:, 2]
-model = keras.Sequential()
-model.add(keras.layers.InputLayer(input_shape=(3,)))
+model = k.Sequential()
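+# Keras 3 renamed InputLayer's `input_shape` argument to `shape`; try the new API first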
+try:
+ model.add(k.layers.InputLayer(shape=(3,)))
+except ValueError:
+ model.add(k.layers.InputLayer(input_shape=(3,)))
for units in [16, 16, 16]:
- model.add(keras.layers.Dense(units, activation="relu"))
-model.add(keras.layers.Dense(1))
-
-adam = keras.optimizers.Adam(learning_rate=0.001)
-mse = keras.losses.MeanSquaredError()
+ model.add(k.layers.Dense(units, activation="relu"))
+model.add(k.layers.Dense(1))
@pytest.fixture
def scheduler():
- from pidgan.callbacks.schedulers.LearnRateBaseScheduler import (
- LearnRateBaseScheduler,
- )
+ from pidgan.callbacks.schedulers import LearnRateBaseScheduler
+ adam = k.optimizers.Adam(learning_rate=LEARN_RATE)
sched = LearnRateBaseScheduler(optimizer=adam, verbose=True, key="lr")
return sched
@@ -35,19 +37,17 @@ def scheduler():
def test_sched_configuration(scheduler):
- from pidgan.callbacks.schedulers.LearnRateBaseScheduler import (
- LearnRateBaseScheduler,
- )
+ from pidgan.callbacks.schedulers import LearnRateBaseScheduler
assert isinstance(scheduler, LearnRateBaseScheduler)
assert isinstance(scheduler.name, str)
- assert isinstance(scheduler.optimizer, keras.optimizers.Optimizer)
+ assert isinstance(scheduler.optimizer, k.optimizers.Optimizer)
assert isinstance(scheduler.verbose, bool)
assert isinstance(scheduler.key, str)
def test_sched_use(scheduler):
- model.compile(optimizer=adam, loss=mse)
- history = model.fit(X, Y, batch_size=512, epochs=10, callbacks=[scheduler])
- last_lr = float(f"{history.history['lr'][-1]:.3f}")
- assert last_lr == 0.001
+ model.compile(optimizer=scheduler.optimizer, loss=k.losses.MeanSquaredError())
+ train = model.fit(X, Y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=[scheduler])
+ last_lr = float(f"{train.history['lr'][-1]:.8f}")
+ assert last_lr == LEARN_RATE
diff --git a/tests/callbacks/schedulers/test_LearnRateCosineDecay.py b/tests/callbacks/schedulers/test_LearnRateCosineDecay.py
index f5ef6ef..d64008a 100644
--- a/tests/callbacks/schedulers/test_LearnRateCosineDecay.py
+++ b/tests/callbacks/schedulers/test_LearnRateCosineDecay.py
@@ -1,8 +1,13 @@
-import numpy as np
import pytest
-from tensorflow import keras
+import keras as k
+import numpy as np
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
+EPOCHS = 5
+LEARN_RATE = 0.001
+MIN_LEARN_RATE = 0.0005
+ALPHA = 0.1
X = np.c_[
np.random.uniform(-1, 1, size=CHUNK_SIZE),
@@ -11,25 +16,26 @@
]
Y = np.tanh(X[:, 0]) + 2 * X[:, 1] * X[:, 2]
-model = keras.Sequential()
-model.add(keras.layers.InputLayer(input_shape=(3,)))
+model = k.Sequential()
+try:
+ model.add(k.layers.InputLayer(shape=(3,)))
+except ValueError:
+ model.add(k.layers.InputLayer(input_shape=(3,)))
for units in [16, 16, 16]:
- model.add(keras.layers.Dense(units, activation="relu"))
-model.add(keras.layers.Dense(1))
-
-adam = keras.optimizers.Adam(learning_rate=0.001)
-mse = keras.losses.MeanSquaredError()
+ model.add(k.layers.Dense(units, activation="relu"))
+model.add(k.layers.Dense(1))
@pytest.fixture
def scheduler():
from pidgan.callbacks.schedulers import LearnRateCosineDecay
+ adam = k.optimizers.Adam(learning_rate=LEARN_RATE)
sched = LearnRateCosineDecay(
optimizer=adam,
- decay_steps=1000,
- alpha=0.95,
- min_learning_rate=0.001,
+ decay_steps=CHUNK_SIZE / BATCH_SIZE * EPOCHS,
+ alpha=ALPHA,
+ min_learning_rate=LEARN_RATE,
verbose=True,
key="lr",
)
@@ -44,7 +50,7 @@ def test_sched_configuration(scheduler):
assert isinstance(scheduler, LearnRateCosineDecay)
assert isinstance(scheduler.name, str)
- assert isinstance(scheduler.optimizer, keras.optimizers.Optimizer)
+ assert isinstance(scheduler.optimizer, k.optimizers.Optimizer)
assert isinstance(scheduler.decay_steps, int)
assert isinstance(scheduler.alpha, float)
assert isinstance(scheduler.min_learning_rate, float)
@@ -52,21 +58,22 @@ def test_sched_configuration(scheduler):
assert isinstance(scheduler.key, str)
-@pytest.mark.parametrize("min_learning_rate", [None, 0.0005])
+@pytest.mark.parametrize("min_learning_rate", [None, MIN_LEARN_RATE])
def test_sched_use(min_learning_rate):
from pidgan.callbacks.schedulers import LearnRateCosineDecay
+ adam = k.optimizers.Adam(learning_rate=LEARN_RATE)
scheduler = LearnRateCosineDecay(
optimizer=adam,
- decay_steps=1000,
- alpha=0.95,
+ decay_steps=CHUNK_SIZE / BATCH_SIZE * EPOCHS,
+ alpha=ALPHA,
min_learning_rate=min_learning_rate,
verbose=True,
)
- model.compile(optimizer=adam, loss=mse)
- history = model.fit(X, Y, batch_size=500, epochs=5, callbacks=[scheduler])
- last_lr = float(f"{history.history['lr'][-1]:.4f}")
+ model.compile(optimizer=adam, loss=k.losses.MeanSquaredError())
+ train = model.fit(X, Y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=[scheduler])
+ last_lr = float(f"{train.history['lr'][-1]:.8f}")
if min_learning_rate is not None:
- assert last_lr == 0.0005
+ assert last_lr == MIN_LEARN_RATE
else:
- assert last_lr == 0.0001
+ assert last_lr == ALPHA * LEARN_RATE
diff --git a/tests/callbacks/schedulers/test_LearnRateExpDecay.py b/tests/callbacks/schedulers/test_LearnRateExpDecay.py
index a734b35..5ee7c70 100644
--- a/tests/callbacks/schedulers/test_LearnRateExpDecay.py
+++ b/tests/callbacks/schedulers/test_LearnRateExpDecay.py
@@ -1,8 +1,13 @@
-import numpy as np
import pytest
-from tensorflow import keras
+import keras as k
+import numpy as np
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
+EPOCHS = 5
+LEARN_RATE = 0.001
+MIN_LEARN_RATE = 0.0005
+ALPHA = 0.1
X = np.c_[
np.random.uniform(-1, 1, size=CHUNK_SIZE),
@@ -11,26 +16,27 @@
]
Y = np.tanh(X[:, 0]) + 2 * X[:, 1] * X[:, 2]
-model = keras.Sequential()
-model.add(keras.layers.InputLayer(input_shape=(3,)))
+model = k.Sequential()
+try:
+ model.add(k.layers.InputLayer(shape=(3,)))
+except ValueError:
+ model.add(k.layers.InputLayer(input_shape=(3,)))
for units in [16, 16, 16]:
- model.add(keras.layers.Dense(units, activation="relu"))
-model.add(keras.layers.Dense(1))
-
-adam = keras.optimizers.Adam(learning_rate=0.001)
-mse = keras.losses.MeanSquaredError()
+ model.add(k.layers.Dense(units, activation="relu"))
+model.add(k.layers.Dense(1))
@pytest.fixture
def scheduler(staircase=False):
from pidgan.callbacks.schedulers import LearnRateExpDecay
+ adam = k.optimizers.Adam(learning_rate=LEARN_RATE)
sched = LearnRateExpDecay(
optimizer=adam,
- decay_rate=0.9,
- decay_steps=1000,
+ decay_rate=ALPHA,
+ decay_steps=CHUNK_SIZE / BATCH_SIZE * EPOCHS,
staircase=staircase,
- min_learning_rate=0.001,
+ min_learning_rate=LEARN_RATE,
verbose=False,
key="lr",
)
@@ -45,7 +51,7 @@ def test_sched_configuration(scheduler):
assert isinstance(scheduler, LearnRateExpDecay)
assert isinstance(scheduler.name, str)
- assert isinstance(scheduler.optimizer, keras.optimizers.Optimizer)
+ assert isinstance(scheduler.optimizer, k.optimizers.Optimizer)
assert isinstance(scheduler.decay_rate, float)
assert isinstance(scheduler.decay_steps, int)
assert isinstance(scheduler.staircase, bool)
@@ -55,22 +61,23 @@ def test_sched_configuration(scheduler):
@pytest.mark.parametrize("staircase", [False, True])
-@pytest.mark.parametrize("min_learning_rate", [None, 0.0005])
+@pytest.mark.parametrize("min_learning_rate", [None, MIN_LEARN_RATE])
def test_sched_use(staircase, min_learning_rate):
from pidgan.callbacks.schedulers import LearnRateExpDecay
+ adam = k.optimizers.Adam(learning_rate=LEARN_RATE)
sched = LearnRateExpDecay(
optimizer=adam,
- decay_rate=0.1,
- decay_steps=100,
+ decay_rate=ALPHA,
+ decay_steps=CHUNK_SIZE / BATCH_SIZE * EPOCHS,
staircase=staircase,
min_learning_rate=min_learning_rate,
verbose=True,
)
- model.compile(optimizer=adam, loss=mse)
- history = model.fit(X, Y, batch_size=500, epochs=5, callbacks=[sched])
- last_lr = float(f"{history.history['lr'][-1]:.4f}")
+ model.compile(optimizer=adam, loss=k.losses.MeanSquaredError())
+ train = model.fit(X, Y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=[sched])
+ last_lr = float(f"{train.history['lr'][-1]:.8f}")
if min_learning_rate is not None:
- assert last_lr == 0.0005
+ assert last_lr == MIN_LEARN_RATE
else:
- assert last_lr == 0.0001
+ assert last_lr == ALPHA * LEARN_RATE
diff --git a/tests/callbacks/schedulers/test_LearnRateInvTimeDecay.py b/tests/callbacks/schedulers/test_LearnRateInvTimeDecay.py
index 060f65e..3f224f5 100644
--- a/tests/callbacks/schedulers/test_LearnRateInvTimeDecay.py
+++ b/tests/callbacks/schedulers/test_LearnRateInvTimeDecay.py
@@ -1,8 +1,13 @@
-import numpy as np
import pytest
-from tensorflow import keras
+import keras as k
+import numpy as np
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
+EPOCHS = 5
+LEARN_RATE = 0.001
+MIN_LEARN_RATE = 0.0005
+ALPHA = 0.1
X = np.c_[
np.random.uniform(-1, 1, size=CHUNK_SIZE),
@@ -11,26 +16,27 @@
]
Y = np.tanh(X[:, 0]) + 2 * X[:, 1] * X[:, 2]
-model = keras.Sequential()
-model.add(keras.layers.InputLayer(input_shape=(3,)))
+model = k.Sequential()
+try:
+ model.add(k.layers.InputLayer(shape=(3,)))
+except ValueError:
+ model.add(k.layers.InputLayer(input_shape=(3,)))
for units in [16, 16, 16]:
- model.add(keras.layers.Dense(units, activation="relu"))
-model.add(keras.layers.Dense(1))
-
-adam = keras.optimizers.Adam(learning_rate=0.001)
-mse = keras.losses.MeanSquaredError()
+ model.add(k.layers.Dense(units, activation="relu"))
+model.add(k.layers.Dense(1))
@pytest.fixture
def scheduler(staircase=False):
from pidgan.callbacks.schedulers import LearnRateInvTimeDecay
+ adam = k.optimizers.Adam(learning_rate=LEARN_RATE)
sched = LearnRateInvTimeDecay(
optimizer=adam,
- decay_rate=0.9,
- decay_steps=1000,
+ decay_rate=1 / ALPHA - 1,
+ decay_steps=CHUNK_SIZE / BATCH_SIZE * EPOCHS,
staircase=staircase,
- min_learning_rate=0.001,
+ min_learning_rate=LEARN_RATE,
verbose=False,
key="lr",
)
@@ -45,7 +51,7 @@ def test_sched_configuration(scheduler):
assert isinstance(scheduler, LearnRateInvTimeDecay)
assert isinstance(scheduler.name, str)
- assert isinstance(scheduler.optimizer, keras.optimizers.Optimizer)
+ assert isinstance(scheduler.optimizer, k.optimizers.Optimizer)
assert isinstance(scheduler.decay_rate, float)
assert isinstance(scheduler.decay_steps, int)
assert isinstance(scheduler.staircase, bool)
@@ -55,22 +61,23 @@ def test_sched_configuration(scheduler):
@pytest.mark.parametrize("staircase", [False, True])
-@pytest.mark.parametrize("min_learning_rate", [None, 0.0005])
+@pytest.mark.parametrize("min_learning_rate", [None, MIN_LEARN_RATE])
def test_sched_use(staircase, min_learning_rate):
from pidgan.callbacks.schedulers import LearnRateInvTimeDecay
+ adam = k.optimizers.Adam(learning_rate=LEARN_RATE)
sched = LearnRateInvTimeDecay(
optimizer=adam,
- decay_rate=9,
- decay_steps=100,
+ decay_rate=1 / ALPHA - 1,
+ decay_steps=CHUNK_SIZE / BATCH_SIZE * EPOCHS,
staircase=staircase,
min_learning_rate=min_learning_rate,
verbose=True,
)
- model.compile(optimizer=adam, loss=mse)
- history = model.fit(X, Y, batch_size=500, epochs=5, callbacks=[sched])
- last_lr = float(f"{history.history['lr'][-1]:.4f}")
+ model.compile(optimizer=adam, loss=k.losses.MeanSquaredError())
+ train = model.fit(X, Y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=[sched])
+ last_lr = float(f"{train.history['lr'][-1]:.8f}")
if min_learning_rate is not None:
- assert last_lr == 0.0005
+ assert last_lr == MIN_LEARN_RATE
else:
- assert last_lr == 0.0001
+ assert last_lr == ALPHA * LEARN_RATE
diff --git a/tests/callbacks/schedulers/test_LearnRatePiecewiseConstDecay.py b/tests/callbacks/schedulers/test_LearnRatePiecewiseConstDecay.py
index 3f8395f..275bd51 100644
--- a/tests/callbacks/schedulers/test_LearnRatePiecewiseConstDecay.py
+++ b/tests/callbacks/schedulers/test_LearnRatePiecewiseConstDecay.py
@@ -1,8 +1,12 @@
-import numpy as np
import pytest
-from tensorflow import keras
+import keras as k
+import numpy as np
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
+EPOCHS = 5
+LEARN_RATE = 0.001
+ALPHA = 0.1
X = np.c_[
np.random.uniform(-1, 1, size=CHUNK_SIZE),
@@ -11,24 +15,25 @@
]
Y = np.tanh(X[:, 0]) + 2 * X[:, 1] * X[:, 2]
-model = keras.Sequential()
-model.add(keras.layers.InputLayer(input_shape=(3,)))
+model = k.Sequential()
+try:
+ model.add(k.layers.InputLayer(shape=(3,)))
+except ValueError:
+ model.add(k.layers.InputLayer(input_shape=(3,)))
for units in [16, 16, 16]:
- model.add(keras.layers.Dense(units, activation="relu"))
-model.add(keras.layers.Dense(1))
-
-adam = keras.optimizers.Adam(learning_rate=0.001)
-mse = keras.losses.MeanSquaredError()
+ model.add(k.layers.Dense(units, activation="relu"))
+model.add(k.layers.Dense(1))
@pytest.fixture
def scheduler():
from pidgan.callbacks.schedulers import LearnRatePiecewiseConstDecay
+ adam = k.optimizers.Adam(learning_rate=LEARN_RATE)
sched = LearnRatePiecewiseConstDecay(
optimizer=adam,
boundaries=[25, 50],
- values=[0.001, 0.0005, 0.0001],
+ values=[LEARN_RATE, 5.0 * ALPHA * LEARN_RATE, ALPHA * LEARN_RATE],
verbose=True,
key="lr",
)
@@ -43,7 +48,7 @@ def test_sched_configuration(scheduler):
assert isinstance(scheduler, LearnRatePiecewiseConstDecay)
assert isinstance(scheduler.name, str)
- assert isinstance(scheduler.optimizer, keras.optimizers.Optimizer)
+ assert isinstance(scheduler.optimizer, k.optimizers.Optimizer)
assert isinstance(scheduler.boundaries, list)
assert isinstance(scheduler.boundaries[0], int)
assert isinstance(scheduler.values, list)
@@ -54,7 +59,7 @@ def test_sched_configuration(scheduler):
def test_sched_use(scheduler):
- model.compile(optimizer=adam, loss=mse)
- history = model.fit(X, Y, batch_size=500, epochs=5, callbacks=[scheduler])
- last_lr = float(f"{history.history['lr'][-1]:.4f}")
- assert last_lr == 0.0001
+ model.compile(optimizer=scheduler.optimizer, loss=k.losses.MeanSquaredError())
+ train = model.fit(X, Y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=[scheduler])
+ last_lr = float(f"{train.history['lr'][-1]:.8f}")
+ assert last_lr == ALPHA * LEARN_RATE
diff --git a/tests/callbacks/schedulers/test_LearnRatePolynomialDecay.py b/tests/callbacks/schedulers/test_LearnRatePolynomialDecay.py
index f932327..aa1cdd9 100644
--- a/tests/callbacks/schedulers/test_LearnRatePolynomialDecay.py
+++ b/tests/callbacks/schedulers/test_LearnRatePolynomialDecay.py
@@ -1,8 +1,12 @@
-import numpy as np
import pytest
-from tensorflow import keras
+import keras as k
+import numpy as np
CHUNK_SIZE = int(1e4)
+BATCH_SIZE = 500
+EPOCHS = 5
+LEARN_RATE = 0.001
+ALPHA = 0.1
X = np.c_[
np.random.uniform(-1, 1, size=CHUNK_SIZE),
@@ -11,24 +15,25 @@
]
Y = np.tanh(X[:, 0]) + 2 * X[:, 1] * X[:, 2]
-model = keras.Sequential()
-model.add(keras.layers.InputLayer(input_shape=(3,)))
+model = k.Sequential()
+try:
+ model.add(k.layers.InputLayer(shape=(3,)))
+except ValueError:
+ model.add(k.layers.InputLayer(input_shape=(3,)))
for units in [16, 16, 16]:
- model.add(keras.layers.Dense(units, activation="relu"))
-model.add(keras.layers.Dense(1))
-
-adam = keras.optimizers.Adam(learning_rate=0.001)
-mse = keras.losses.MeanSquaredError()
+ model.add(k.layers.Dense(units, activation="relu"))
+model.add(k.layers.Dense(1))
@pytest.fixture
def scheduler(cycle=False):
from pidgan.callbacks.schedulers import LearnRatePolynomialDecay
+ adam = k.optimizers.Adam(learning_rate=LEARN_RATE)
sched = LearnRatePolynomialDecay(
optimizer=adam,
- decay_steps=1000,
- end_learning_rate=0.0001,
+ decay_steps=CHUNK_SIZE / BATCH_SIZE * EPOCHS,
+ end_learning_rate=ALPHA * LEARN_RATE,
power=1.0,
cycle=cycle,
verbose=False,
@@ -45,7 +50,7 @@ def test_sched_configuration(scheduler):
assert isinstance(scheduler, LearnRatePolynomialDecay)
assert isinstance(scheduler.name, str)
- assert isinstance(scheduler.optimizer, keras.optimizers.Optimizer)
+ assert isinstance(scheduler.optimizer, k.optimizers.Optimizer)
assert isinstance(scheduler.decay_steps, int)
assert isinstance(scheduler.end_learning_rate, float)
assert isinstance(scheduler.power, float)
@@ -58,15 +63,16 @@ def test_sched_configuration(scheduler):
def test_sched_use(cycle):
from pidgan.callbacks.schedulers import LearnRatePolynomialDecay
+ adam = k.optimizers.Adam(learning_rate=LEARN_RATE)
sched = LearnRatePolynomialDecay(
optimizer=adam,
- decay_steps=100,
- end_learning_rate=0.0001,
+ decay_steps=CHUNK_SIZE / BATCH_SIZE * EPOCHS,
+ end_learning_rate=ALPHA * LEARN_RATE,
power=1.0,
cycle=cycle,
verbose=True,
)
- model.compile(optimizer=adam, loss=mse)
- history = model.fit(X, Y, batch_size=500, epochs=5, callbacks=[sched])
- last_lr = float(f"{history.history['lr'][-1]:.4f}")
- assert last_lr == 0.0001
+ model.compile(optimizer=adam, loss=k.losses.MeanSquaredError())
+ train = model.fit(X, Y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=[sched])
+ last_lr = float(f"{train.history['lr'][-1]:.8f}")
+ assert last_lr == ALPHA * LEARN_RATE
diff --git a/tests/metrics/test_Accuracy_metric.py b/tests/metrics/test_Accuracy_metric.py
index d1d4f4f..632719f 100644
--- a/tests/metrics/test_Accuracy_metric.py
+++ b/tests/metrics/test_Accuracy_metric.py
@@ -26,13 +26,8 @@ def test_metric_configuration(metric):
assert isinstance(metric.name, str)
-def test_metric_use_no_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=None)
- res = metric.result().numpy()
- assert res
-
-
-def test_metric_use_with_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=weight)
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(metric, sample_weight):
+ metric.update_state(y_true, y_pred, sample_weight=sample_weight)
res = metric.result().numpy()
assert res
diff --git a/tests/metrics/test_BinaryCrossentropy_metric.py b/tests/metrics/test_BinaryCrossentropy_metric.py
index 0f4de1c..ca0880e 100644
--- a/tests/metrics/test_BinaryCrossentropy_metric.py
+++ b/tests/metrics/test_BinaryCrossentropy_metric.py
@@ -27,28 +27,15 @@ def test_metric_configuration(metric):
@pytest.mark.parametrize("from_logits", [False, True])
-def test_metric_use_no_weights(from_logits):
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(from_logits, sample_weight):
from pidgan.metrics import BinaryCrossentropy
metric = BinaryCrossentropy(from_logits=from_logits, label_smoothing=0.0)
if from_logits:
- metric.update_state(y_true, y_pred_logits, sample_weight=None)
+ metric.update_state(y_true, y_pred_logits, sample_weight=sample_weight)
res = metric.result().numpy()
else:
- metric.update_state(y_true, y_pred, sample_weight=None)
- res = metric.result().numpy()
- assert res
-
-
-@pytest.mark.parametrize("from_logits", [False, True])
-def test_metric_use_with_weights(from_logits):
- from pidgan.metrics import BinaryCrossentropy
-
- metric = BinaryCrossentropy(from_logits=from_logits, label_smoothing=0.0)
- if from_logits:
- metric.update_state(y_true, y_pred_logits, sample_weight=weight)
- res = metric.result().numpy()
- else:
- metric.update_state(y_true, y_pred, sample_weight=weight)
+ metric.update_state(y_true, y_pred, sample_weight=sample_weight)
res = metric.result().numpy()
assert res
diff --git a/tests/metrics/test_JSDivergence_metric.py b/tests/metrics/test_JSDivergence_metric.py
index c737811..f8f0343 100644
--- a/tests/metrics/test_JSDivergence_metric.py
+++ b/tests/metrics/test_JSDivergence_metric.py
@@ -26,13 +26,8 @@ def test_metric_configuration(metric):
assert isinstance(metric.name, str)
-def test_metric_use_no_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=None)
- res = metric.result().numpy()
- assert res
-
-
-def test_metric_use_with_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=weight)
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(metric, sample_weight):
+ metric.update_state(y_true, y_pred, sample_weight=sample_weight)
res = metric.result().numpy()
assert res
diff --git a/tests/metrics/test_KLDivergence_metric.py b/tests/metrics/test_KLDivergence_metric.py
index 48a4701..3c3a9ee 100644
--- a/tests/metrics/test_KLDivergence_metric.py
+++ b/tests/metrics/test_KLDivergence_metric.py
@@ -26,13 +26,8 @@ def test_metric_configuration(metric):
assert isinstance(metric.name, str)
-def test_metric_use_no_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=None)
- res = metric.result().numpy()
- assert res
-
-
-def test_metric_use_with_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=weight)
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(metric, sample_weight):
+ metric.update_state(y_true, y_pred, sample_weight=sample_weight)
res = metric.result().numpy()
assert res
diff --git a/tests/metrics/test_MeanAbsoluteError_metric.py b/tests/metrics/test_MeanAbsoluteError_metric.py
index 166fd7d..dc27e09 100644
--- a/tests/metrics/test_MeanAbsoluteError_metric.py
+++ b/tests/metrics/test_MeanAbsoluteError_metric.py
@@ -26,13 +26,8 @@ def test_metric_configuration(metric):
assert isinstance(metric.name, str)
-def test_metric_use_no_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=None)
- res = metric.result().numpy()
- assert res
-
-
-def test_metric_use_with_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=weight)
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(metric, sample_weight):
+ metric.update_state(y_true, y_pred, sample_weight=sample_weight)
res = metric.result().numpy()
assert res
diff --git a/tests/metrics/test_MeanSquaredError_metric.py b/tests/metrics/test_MeanSquaredError_metric.py
index 8673fb8..25c76fa 100644
--- a/tests/metrics/test_MeanSquaredError_metric.py
+++ b/tests/metrics/test_MeanSquaredError_metric.py
@@ -26,13 +26,8 @@ def test_metric_configuration(metric):
assert isinstance(metric.name, str)
-def test_metric_use_no_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=None)
- res = metric.result().numpy()
- assert res
-
-
-def test_metric_use_with_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=weight)
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(metric, sample_weight):
+ metric.update_state(y_true, y_pred, sample_weight=sample_weight)
res = metric.result().numpy()
assert res
diff --git a/tests/metrics/test_RootMeanSquaredError_metric.py b/tests/metrics/test_RootMeanSquaredError_metric.py
index 273d319..16e5794 100644
--- a/tests/metrics/test_RootMeanSquaredError_metric.py
+++ b/tests/metrics/test_RootMeanSquaredError_metric.py
@@ -26,13 +26,8 @@ def test_metric_configuration(metric):
assert isinstance(metric.name, str)
-def test_metric_use_no_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=None)
- res = metric.result().numpy()
- assert res
-
-
-def test_metric_use_with_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=weight)
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(metric, sample_weight):
+ metric.update_state(y_true, y_pred, sample_weight=sample_weight)
res = metric.result().numpy()
assert res
diff --git a/tests/metrics/test_WassersteinDistance_metric.py b/tests/metrics/test_WassersteinDistance_metric.py
index 67b101f..8c6c981 100644
--- a/tests/metrics/test_WassersteinDistance_metric.py
+++ b/tests/metrics/test_WassersteinDistance_metric.py
@@ -26,13 +26,8 @@ def test_metric_configuration(metric):
assert isinstance(metric.name, str)
-def test_metric_use_no_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=None)
- res = metric.result().numpy()
- assert res
-
-
-def test_metric_use_with_weights(metric):
- metric.update_state(y_true, y_pred, sample_weight=weight)
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(metric, sample_weight):
+ metric.update_state(y_true, y_pred, sample_weight=sample_weight)
res = metric.result().numpy()
assert res
diff --git a/tests/optimization/callbacks/test_HopaasPruner.py b/tests/optimization/callbacks/test_HopaasPruner.py
index f4c81e7..7cf3d2b 100644
--- a/tests/optimization/callbacks/test_HopaasPruner.py
+++ b/tests/optimization/callbacks/test_HopaasPruner.py
@@ -1,10 +1,10 @@
import os
import hopaas_client as hpc
+import keras as k
import numpy as np
import pytest
import yaml
-from tensorflow import keras
NUM_TRIALS = 1
CHUNK_SIZE = int(1e3)
@@ -24,11 +24,14 @@
]
Y = np.tanh(X[:, 0]) + 2 * X[:, 1] * X[:, 2]
-model = keras.Sequential()
-model.add(keras.layers.InputLayer(input_shape=(3,)))
+model = k.Sequential()
+try:
+ model.add(k.layers.InputLayer(shape=(3,)))
+except ValueError:
+ model.add(k.layers.InputLayer(input_shape=(3,)))
for units in [16, 16, 16]:
- model.add(keras.layers.Dense(units, activation="relu"))
-model.add(keras.layers.Dense(1))
+ model.add(k.layers.Dense(units, activation="relu"))
+model.add(k.layers.Dense(1))
@pytest.fixture
@@ -79,8 +82,8 @@ def test_callback_use(enable_pruning):
for _ in range(NUM_TRIALS):
with study.trial() as trial:
- adam = keras.optimizers.Adam(learning_rate=trial.learning_rate)
- mse = keras.losses.MeanSquaredError()
+ adam = k.optimizers.Adam(learning_rate=trial.learning_rate)
+ mse = k.losses.MeanSquaredError()
report = HopaasPruner(
trial=trial,
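
The `try`/`except ValueError` around `InputLayer` (also applied in the scheduler tests above) exists because Keras 3 renamed the constructor argument to `shape`, while tf.keras 2.x only understands `input_shape` and rejects the new keyword with a `ValueError`. A stand-alone sketch of the fallback under those assumptions:

```python
# Keras 3 takes InputLayer(shape=...); tf.keras 2.x only accepts
# input_shape=... and raises ValueError on the new keyword, hence the
# fallback used throughout these tests.
import keras as k

model = k.Sequential()
try:
    model.add(k.layers.InputLayer(shape=(3,)))        # Keras >= 3
except ValueError:
    model.add(k.layers.InputLayer(input_shape=(3,)))  # tf.keras 2.x
for units in [16, 16, 16]:
    model.add(k.layers.Dense(units, activation="relu"))
model.add(k.layers.Dense(1))
```
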
diff --git a/tests/players/classifiers/test_AuxClassifier.py b/tests/players/classifiers/test_AuxClassifier.py
index 7fb6488..806fe9b 100644
--- a/tests/players/classifiers/test_AuxClassifier.py
+++ b/tests/players/classifiers/test_AuxClassifier.py
@@ -1,8 +1,8 @@
import os
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -63,11 +63,12 @@ def test_model_use(mlp_hidden_activation, enable_res_blocks, inputs):
mlp_dropout_rates=0.0,
enable_residual_blocks=enable_res_blocks,
)
+
out = model(inputs)
model.summary()
test_shape = [x.shape[0], 1]
assert out.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Model)
+ assert isinstance(model.plain_keras, k.Model)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@@ -83,26 +85,37 @@ def test_model_train(model, inputs, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, inputs, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
- model.evaluate(inputs, sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x=inputs, y=labels, sample_weight=sample_weight)
@pytest.mark.parametrize("inputs", [y, (x, y)])
def test_model_export(model, inputs):
out, aux = model(inputs, return_aux_features=True)
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
if isinstance(inputs, (list, tuple)):
in_reloaded = tf.concat((x, y, aux), axis=-1)
else:
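
Every `test_model_export` in this patch replaces `save_model(..., save_format="tf")` with a version gate: Keras 3 no longer supports the TensorFlow `save_format`, so the SavedModel is produced with `Model.export()` and reloaded through `TFSMLayer`, while Keras 2 keeps the legacy path. A hedged sketch of the branch, with the model and `export_dir` as placeholders:

```python
# Sketch of the export/reload gate: Keras 3 drops save_format="tf", so the
# SavedModel comes from Model.export() and is read back as an inference-only
# TFSMLayer; Keras 2 (tf.keras) keeps the old save_model/load_model path.
import keras as k

def save_and_reload(plain_keras_model, export_dir):
    v_major = int(k.__version__.split(".")[0])
    if v_major >= 3:
        plain_keras_model.export(export_dir)                    # TF SavedModel
        return k.layers.TFSMLayer(export_dir, call_endpoint="serve")
    k.models.save_model(plain_keras_model, export_dir, save_format="tf")
    return k.models.load_model(export_dir)
```

The reloaded `TFSMLayer` is callable but inference-only, which is why the export tests compare raw outputs on concatenated inputs instead of reusing the pidgan wrapper.
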
diff --git a/tests/players/classifiers/test_AuxMultiClassifier.py b/tests/players/classifiers/test_AuxMultiClassifier.py
index d8ed814..1131a9a 100644
--- a/tests/players/classifiers/test_AuxMultiClassifier.py
+++ b/tests/players/classifiers/test_AuxMultiClassifier.py
@@ -2,8 +2,8 @@
import numpy as np
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -66,11 +66,12 @@ def test_model_use(mlp_hidden_activation, enable_res_blocks, inputs):
mlp_dropout_rates=0.0,
enable_residual_blocks=enable_res_blocks,
)
+
out = model(inputs)
model.summary()
test_shape = [x.shape[0], model.num_multiclasses]
assert out.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Model)
+ assert isinstance(model.plain_keras, k.Model)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@@ -86,26 +87,37 @@ def test_model_train(model, inputs, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- cce = keras.losses.CategoricalCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=cce, metrics=["mse"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, inputs, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- cce = keras.losses.CategoricalCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=cce, metrics=["mse"])
- model.evaluate(inputs, sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x=inputs, y=labels, sample_weight=sample_weight)
@pytest.mark.parametrize("inputs", [y, (x, y)])
def test_model_export(model, inputs):
out, aux = model(inputs, return_aux_features=True)
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
if isinstance(inputs, (list, tuple)):
in_reloaded = tf.concat((x, y, aux), axis=-1)
else:
diff --git a/tests/players/classifiers/test_Classifier.py b/tests/players/classifiers/test_Classifier.py
index 3287ac7..8ba1968 100644
--- a/tests/players/classifiers/test_Classifier.py
+++ b/tests/players/classifiers/test_Classifier.py
@@ -1,8 +1,8 @@
import os
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -58,11 +58,12 @@ def test_model_use(mlp_hidden_units, mlp_hidden_activation, mlp_dropout_rates, i
mlp_hidden_activation=mlp_hidden_activation,
mlp_dropout_rates=mlp_dropout_rates,
)
+
out = model(inputs)
model.summary()
test_shape = [x.shape[0], 1]
assert out.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Sequential)
+ assert isinstance(model.plain_keras, k.Sequential)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@@ -78,26 +79,37 @@ def test_model_train(model, inputs, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, inputs, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
- model.evaluate(inputs, sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x=inputs, y=labels, sample_weight=sample_weight)
@pytest.mark.parametrize("inputs", [y, (x, y)])
def test_model_export(model, inputs):
out = model(inputs)
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
if isinstance(inputs, (list, tuple)):
in_reloaded = tf.concat((x, y), axis=-1)
else:
diff --git a/tests/players/classifiers/test_MultiClassifier.py b/tests/players/classifiers/test_MultiClassifier.py
index 1570584..c3c34e5 100644
--- a/tests/players/classifiers/test_MultiClassifier.py
+++ b/tests/players/classifiers/test_MultiClassifier.py
@@ -2,8 +2,8 @@
import numpy as np
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -62,11 +62,12 @@ def test_model_use(mlp_hidden_units, mlp_hidden_activation, mlp_dropout_rates, i
mlp_hidden_activation=mlp_hidden_activation,
mlp_dropout_rates=mlp_dropout_rates,
)
+
out = model(inputs)
model.summary()
test_shape = [x.shape[0], model.num_multiclasses]
assert out.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Sequential)
+ assert isinstance(model.plain_keras, k.Sequential)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@@ -82,26 +83,37 @@ def test_model_train(model, inputs, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- cce = keras.losses.CategoricalCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=cce, metrics=["mse"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, inputs, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- cce = keras.losses.CategoricalCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=cce, metrics=["mse"])
- model.evaluate(inputs, sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x=inputs, y=labels, sample_weight=sample_weight)
@pytest.mark.parametrize("inputs", [y, (x, y)])
def test_model_export(model, inputs):
out = model(inputs)
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
if isinstance(inputs, (list, tuple)):
in_reloaded = tf.concat((x, y), axis=-1)
else:
diff --git a/tests/players/classifiers/test_ResClassifier.py b/tests/players/classifiers/test_ResClassifier.py
index 444f9e1..a902a92 100644
--- a/tests/players/classifiers/test_ResClassifier.py
+++ b/tests/players/classifiers/test_ResClassifier.py
@@ -1,8 +1,8 @@
import os
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -56,11 +56,12 @@ def test_model_use(mlp_hidden_activation, inputs):
mlp_hidden_activation=mlp_hidden_activation,
mlp_dropout_rates=0.0,
)
+
out = model(inputs)
model.summary()
test_shape = [x.shape[0], 1]
assert out.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Model)
+ assert isinstance(model.plain_keras, k.Model)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@@ -76,26 +77,37 @@ def test_model_train(model, inputs, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, inputs, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
- model.evaluate(inputs, sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x=inputs, y=labels, sample_weight=sample_weight)
@pytest.mark.parametrize("inputs", [y, (x, y)])
def test_model_export(model, inputs):
out = model(inputs)
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
if isinstance(inputs, (list, tuple)):
in_reloaded = tf.concat((x, y), axis=-1)
else:
diff --git a/tests/players/classifiers/test_ResMultiClassifier.py b/tests/players/classifiers/test_ResMultiClassifier.py
index b02989e..6e508f1 100644
--- a/tests/players/classifiers/test_ResMultiClassifier.py
+++ b/tests/players/classifiers/test_ResMultiClassifier.py
@@ -2,8 +2,8 @@
import numpy as np
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -60,11 +60,12 @@ def test_model_use(mlp_hidden_activation, inputs):
mlp_hidden_activation=mlp_hidden_activation,
mlp_dropout_rates=0.0,
)
+
out = model(inputs)
model.summary()
test_shape = [x.shape[0], model.num_multiclasses]
assert out.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Model)
+ assert isinstance(model.plain_keras, k.Model)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@@ -80,26 +81,37 @@ def test_model_train(model, inputs, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- cce = keras.losses.CategoricalCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=cce, metrics=["mse"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("inputs", [y, (x, y)])
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, inputs, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- cce = keras.losses.CategoricalCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=cce, metrics=["mse"])
- model.evaluate(inputs, sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x=inputs, y=labels, sample_weight=sample_weight)
@pytest.mark.parametrize("inputs", [y, (x, y)])
def test_model_export(model, inputs):
out = model(inputs)
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
if isinstance(inputs, (list, tuple)):
in_reloaded = tf.concat((x, y), axis=-1)
else:
diff --git a/tests/players/discriminators/test_AuxDiscriminator.py b/tests/players/discriminators/test_AuxDiscriminator.py
index 9f7c2b3..aec996f 100644
--- a/tests/players/discriminators/test_AuxDiscriminator.py
+++ b/tests/players/discriminators/test_AuxDiscriminator.py
@@ -1,8 +1,8 @@
import os
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -63,6 +63,7 @@ def test_model_use(enable_res_blocks, output_activation):
enable_residual_blocks=enable_res_blocks,
output_activation=output_activation,
)
+
out = model((x, y))
model.summary()
test_shape = [x.shape[0]]
@@ -72,7 +73,7 @@ def test_model_use(enable_res_blocks, output_activation):
test_shape = [x.shape[0]]
test_shape.append(model.mlp_hidden_units)
assert hidden_feat.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Model)
+ assert isinstance(model.plain_keras, k.Model)
@pytest.mark.parametrize("sample_weight", [w, None])
@@ -87,24 +88,35 @@ def test_model_train(model, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
- model.evaluate((x, y), sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x=(x, y), y=labels, sample_weight=sample_weight)
def test_model_export(model):
out, aux = model((x, y), return_aux_features=True)
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
in_reloaded = tf.concat((x, y, aux), axis=-1)
out_reloaded = model_reloaded(in_reloaded)
comparison = out.numpy() == out_reloaded.numpy()
diff --git a/tests/players/discriminators/test_Discriminator.py b/tests/players/discriminators/test_Discriminator.py
index 82f7949..1b9fc42 100644
--- a/tests/players/discriminators/test_Discriminator.py
+++ b/tests/players/discriminators/test_Discriminator.py
@@ -1,8 +1,8 @@
import os
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -58,6 +58,7 @@ def test_model_use(mlp_hidden_units, mlp_dropout_rates, output_activation):
mlp_dropout_rates=mlp_dropout_rates,
output_activation=output_activation,
)
+
out = model((x, y))
model.summary()
test_shape = [x.shape[0]]
@@ -67,7 +68,7 @@ def test_model_use(mlp_hidden_units, mlp_dropout_rates, output_activation):
test_shape = [x.shape[0]]
test_shape.append(model.mlp_hidden_units[hidden_idx])
assert hidden_feat.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Sequential)
+ assert isinstance(model.plain_keras, k.Sequential)
@pytest.mark.parametrize("sample_weight", [w, None])
@@ -82,24 +83,35 @@ def test_model_train(model, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
- model.evaluate((x, y), sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x=(x, y), y=labels, sample_weight=sample_weight)
def test_model_export(model):
out = model((x, y))
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
in_reloaded = tf.concat((x, y), axis=-1)
out_reloaded = model_reloaded(in_reloaded)
comparison = out.numpy() == out_reloaded.numpy()
diff --git a/tests/players/discriminators/test_ResDiscriminator.py b/tests/players/discriminators/test_ResDiscriminator.py
index 5d82c7d..f131a31 100644
--- a/tests/players/discriminators/test_ResDiscriminator.py
+++ b/tests/players/discriminators/test_ResDiscriminator.py
@@ -1,8 +1,8 @@
import os
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -56,6 +56,7 @@ def test_model_use(output_activation):
mlp_dropout_rates=0.0,
output_activation=output_activation,
)
+
out = model((x, y))
model.summary()
test_shape = [x.shape[0]]
@@ -65,7 +66,7 @@ def test_model_use(output_activation):
test_shape = [x.shape[0]]
test_shape.append(model.mlp_hidden_units)
assert hidden_feat.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Model)
+ assert isinstance(model.plain_keras, k.Model)
@pytest.mark.parametrize("sample_weight", [w, None])
@@ -80,24 +81,35 @@ def test_model_train(model, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- bce = keras.losses.BinaryCrossentropy(from_logits=False)
- model.compile(optimizer=adam, loss=bce, metrics=["mse"])
- model.evaluate((x, y), sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x=(x, y), y=labels, sample_weight=sample_weight)
def test_model_export(model):
out = model((x, y))
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
in_reloaded = tf.concat((x, y), axis=-1)
out_reloaded = model_reloaded(in_reloaded)
comparison = out.numpy() == out_reloaded.numpy()
diff --git a/tests/players/generators/test_Generator.py b/tests/players/generators/test_Generator.py
index 7b7c493..5380d16 100644
--- a/tests/players/generators/test_Generator.py
+++ b/tests/players/generators/test_Generator.py
@@ -1,8 +1,8 @@
import os
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -59,12 +59,13 @@ def test_model_use(mlp_hidden_units, mlp_dropout_rates, output_activation):
mlp_dropout_rates=mlp_dropout_rates,
output_activation=output_activation,
)
- output = model(x)
+
+ out = model(x)
model.summary()
test_shape = [x.shape[0]]
test_shape.append(model.output_dim)
- assert output.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Sequential)
+ assert out.shape == tuple(test_shape)
+ assert isinstance(model.plain_keras, k.Sequential)
@pytest.mark.parametrize("sample_weight", [w, None])
@@ -79,18 +80,22 @@ def test_model_train(model, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- mse = keras.losses.MeanSquaredError()
- model.compile(optimizer=adam, loss=mse, metrics=["mae"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- mse = keras.losses.MeanSquaredError()
- model.compile(optimizer=adam, loss=mse, metrics=["mae"])
- model.evaluate(x, sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x, y, sample_weight=sample_weight)
def test_model_generate(model):
@@ -106,8 +111,15 @@ def test_model_generate(model):
def test_model_export(model):
out, latent_sample = model.generate(x, return_latent_sample=True)
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
x_reloaded = tf.concat([x, latent_sample], axis=-1)
out_reloaded = model_reloaded(x_reloaded)
comparison = out.numpy() == out_reloaded.numpy()
diff --git a/tests/players/generators/test_ResGenerator.py b/tests/players/generators/test_ResGenerator.py
index 8c56eb6..4c86823 100644
--- a/tests/players/generators/test_ResGenerator.py
+++ b/tests/players/generators/test_ResGenerator.py
@@ -1,8 +1,8 @@
import os
import pytest
+import keras as k
import tensorflow as tf
-from tensorflow import keras
CHUNK_SIZE = int(1e4)
BATCH_SIZE = 500
@@ -57,12 +57,13 @@ def test_model_use(output_activation):
mlp_dropout_rates=0.0,
output_activation=output_activation,
)
- output = model(x)
+
+ out = model(x)
model.summary()
test_shape = [x.shape[0]]
test_shape.append(model.output_dim)
- assert output.shape == tuple(test_shape)
- assert isinstance(model.export_model, keras.Model)
+ assert out.shape == tuple(test_shape)
+ assert isinstance(model.plain_keras, k.Model)
@pytest.mark.parametrize("sample_weight", [w, None])
@@ -77,18 +78,22 @@ def test_model_train(model, sample_weight):
.cache()
.prefetch(tf.data.AUTOTUNE)
)
- adam = keras.optimizers.Adam(learning_rate=0.001)
- mse = keras.losses.MeanSquaredError()
- model.compile(optimizer=adam, loss=mse, metrics=["mae"])
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
model.fit(dataset, epochs=2)
@pytest.mark.parametrize("sample_weight", [w, None])
def test_model_eval(model, sample_weight):
- adam = keras.optimizers.Adam(learning_rate=0.001)
- mse = keras.losses.MeanSquaredError()
- model.compile(optimizer=adam, loss=mse, metrics=["mae"])
- model.evaluate(x, sample_weight=sample_weight)
+ model.compile(
+ optimizer=k.optimizers.Adam(learning_rate=0.001),
+ loss=k.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ model.evaluate(x, y, sample_weight=sample_weight)
def test_model_generate(model):
@@ -104,8 +109,15 @@ def test_model_generate(model):
def test_model_export(model):
out, latent_sample = model.generate(x, return_latent_sample=True)
- keras.models.save_model(model.export_model, export_dir, save_format="tf")
- model_reloaded = keras.models.load_model(export_dir)
+
+ v_major, v_minor, _ = [int(v) for v in k.__version__.split(".")]
+ if v_major == 3 and v_minor >= 0:
+ model.plain_keras.export(export_dir)
+ model_reloaded = k.layers.TFSMLayer(export_dir, call_endpoint="serve")
+ else:
+ k.models.save_model(model.plain_keras, export_dir, save_format="tf")
+ model_reloaded = k.models.load_model(export_dir)
+
x_reloaded = tf.concat([x, latent_sample], axis=-1)
out_reloaded = model_reloaded(x_reloaded)
comparison = out.numpy() == out_reloaded.numpy()
diff --git a/tests/utils/checks/test_checkMetrics.py b/tests/utils/checks/test_checkMetrics.py
index 7ddf282..ca3a315 100644
--- a/tests/utils/checks/test_checkMetrics.py
+++ b/tests/utils/checks/test_checkMetrics.py
@@ -1,6 +1,6 @@
import pytest
-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics import BaseMetric
from pidgan.utils.checks.checkMetrics import METRIC_SHORTCUTS, PIDGAN_METRICS
diff --git a/tests/utils/checks/test_checkOptimizer.py b/tests/utils/checks/test_checkOptimizer.py
index 3a6352c..2b5e7f7 100644
--- a/tests/utils/checks/test_checkOptimizer.py
+++ b/tests/utils/checks/test_checkOptimizer.py
@@ -1,5 +1,5 @@
import pytest
-from tensorflow import keras
+import keras as k
from pidgan.utils.checks.checkOptimizer import OPT_SHORTCUTS, TF_OPTIMIZERS
@@ -11,7 +11,7 @@ def test_checker_use_strings(optimizer):
from pidgan.utils.checks import checkOptimizer
res = checkOptimizer(optimizer)
- assert isinstance(res, keras.optimizers.Optimizer)
+ assert isinstance(res, k.optimizers.Optimizer)
@pytest.mark.parametrize("optimizer", TF_OPTIMIZERS)
@@ -19,4 +19,4 @@ def test_checker_use_classes(optimizer):
from pidgan.utils.checks import checkOptimizer
res = checkOptimizer(optimizer)
- assert isinstance(res, keras.optimizers.Optimizer)
+ assert isinstance(res, k.optimizers.Optimizer)
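
For completeness, a hedged usage sketch of the checker exercised above; `"adam"` being a valid `OPT_SHORTCUTS` entry is an assumption, the only contract shown by these tests being that `checkOptimizer` returns a `keras.optimizers.Optimizer` for both string shortcuts and optimizer instances:

```python
# Hedged usage sketch; "adam" as a shortcut is assumed, the tested contract
# is only the returned type.
import keras as k
from pidgan.utils.checks import checkOptimizer

opt_from_str = checkOptimizer("adam")
opt_from_obj = checkOptimizer(k.optimizers.Adam(learning_rate=1e-3))
assert isinstance(opt_from_str, k.optimizers.Optimizer)
assert isinstance(opt_from_obj, k.optimizers.Optimizer)
```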