Skip to content

Commit

Permalink
add wgan
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Jul 11, 2024
1 parent dafc344 commit 2b138cb
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ TSGM implements several generative models for synthetic time series data.
| ------------- | ------------- | ------------- | ------------- |
| Structural Time Series model | [tsgm.models.sts.STS](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.sts.STS) | Data-driven | Great for modeling time series when prior knowledge is available (e.g., trend or seasonality). |
| GAN | [tsgm.models.cgan.GAN](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.cgan.GAN) | Data-driven | A generic implementation of GAN for time series generation. It can be customized with architectures for generators and discriminators. |
| WaveGAN | [tsgm.models.cgan.GAN](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.cgan.GAN) | Data-driven | WaveGAN is the model for audio synthesis proposed in [Adversarial Audio Synthesis](https://arxiv.org/abs/1802.04208). To use WaveGAN, set `use_wgan=True` when initializing the GAN class and use the `zoo["wavegan"]` architecture from the model zoo. |
| ConditionalGAN | [tsgm.models.cgan.ConditionalGAN](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.cgan.ConditionalGAN) | Data-driven | A generic implementation of conditional GAN. It supports scalar conditioning as well as temporal one. |
| BetaVAE | [tsgm.models.cvae.BetaVAE](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.cvae.BetaVAE) | Data-driven | A generic implementation of Beta VAE for TS. The loss function is customized to work well with multi-dimensional time series. |
| cBetaVAE | [tsgm.models.cvae.cBetaVAE](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.cvae.cBetaVAE) | Data-driven | Conditional version of BetaVAE. It supports temporal a scalar condiotioning.|
Expand Down
32 changes: 28 additions & 4 deletions tests/test_cgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def test_dp_compiler():
learning_rate=learning_rate
)


g_optimizer = tf_privacy.DPKerasAdamOptimizer(
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
Expand All @@ -259,6 +258,31 @@ def test_dp_compiler():
assert generated_samples.shape == (10, 64, 1)


def test_temporal_cgan_multiple_features():
# TODO
pass
def test_wavegan():
latent_dim = 2
output_dim = 1
feature_dim = 1
seq_len = 64
batch_size = 48

dataset = _gen_dataset(seq_len, feature_dim, batch_size)
architecture = tsgm.models.architectures.zoo["wavegan"](
seq_len=seq_len, feat_dim=feature_dim,
latent_dim=latent_dim, output_dim=output_dim)
discriminator, generator = architecture.discriminator, architecture.generator
gan = tsgm.models.cgan.GAN(
discriminator=discriminator, generator=generator, latent_dim=latent_dim, use_wgan=True
)
gan.compile(
d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
loss_fn=keras.losses.BinaryCrossentropy(),
)

gan.fit(dataset, epochs=1)

assert gan.generator is not None
assert gan.discriminator is not None
# Check generation
generated_samples = gan.generate(10)
assert generated_samples.shape == (10, seq_len, 1)
94 changes: 94 additions & 0 deletions tsgm/models/architectures/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,99 @@ def _build_generator(self, output_activation: str) -> keras.Model:
return generator


class WaveGANArchitecture(BaseGANArchitecture):
"""
WaveGAN architecture, from https://arxiv.org/abs/1802.04208
Inherits from BaseGANArchitecture.
"""
arch_type = "gan:raw"

def __init__(self, seq_len: int, feat_dim: int = 64, latent_dim: int = 32, output_dim: int = 1, kernel_size: int = 32, phase_rad: int = 2, use_batchnorm: bool = False):
"""
Initializes the WaveGANArchitecture.
:param seq_len: Length of input sequences.
:type seq_len: int
:param feat_dim: Dimensionality of input features.
:type feat_dim: int
:param latent_dim: Dimensionality of the latent space.
:type latent_dim: int
:param output_dim: Dimensionality of the output.
:type output_dim: int
:param kernel_size: Sizes of convolutions
:type kernel_size: int, optional
:param phase_rad: Phase shuffle radius for wavegan (default is 2)
:type phase_rad: int, optional
:param use_batchnorm: Whether to use batchnorm (default is False)
:type use_batchnorm: bool, optional
"""
self.seq_len = seq_len
self.feat_dim = feat_dim
self.latent_dim = latent_dim
self.kernel_size = kernel_size
self.phase_rad = phase_rad
self.output_dim = output_dim
self.use_batchnorm = use_batchnorm

self._discriminator = self._build_discriminator()
self._generator = self._build_generator()

def _apply_phaseshuffle(self, x, rad):
'''
Based on
https://github.com/chrisdonahue/wavegan/
'''
if rad <= 0 or x.shape[1] <= 1:
return x

b, x_len, nch = x.get_shape().as_list()

phase = tf.random.uniform([], minval=-rad, maxval=rad + 1, dtype=tf.int32)
pad_l, pad_r = tf.maximum(phase, 0), tf.maximum(-phase, 0)
phase_start = pad_r
x = tf.pad(x, [[0, 0], [pad_l, pad_r], [0, 0]], mode="reflect")

x = x[:, phase_start:phase_start + x_len]
x.set_shape([b, x_len, nch])

return x

def _conv_transpose_block(self, inputs, channels, strides=4):
x = layers.Conv1DTranspose(channels, self.kernel_size, strides=strides, padding='same', use_bias=False)(inputs)
x = layers.BatchNormalization()(x) if self.use_batchnorm else x
x = layers.LeakyReLU()(x)
return x

def _build_generator(self):
inputs = layers.Input((self.latent_dim,))
x = layers.Dense(16 * 1024, use_bias=False)(inputs)
x = layers.BatchNormalization()(x) if self.use_batchnorm else x
x = layers.LeakyReLU()(x)
x = layers.Reshape((16, 1024))(x)

for conv_size in [512, 256, 128, 64]:
x = self._conv_transpose_block(x, conv_size)

x = layers.Conv1DTranspose(1, self.kernel_size, strides=4, padding='same', use_bias=False, activation='tanh')(x)
pool_and_stride = math.ceil((x.shape[1] + 1) / (self.seq_len + 1))
x = layers.AveragePooling1D(pool_size=pool_and_stride, strides=pool_and_stride)(x)
return keras.Model(inputs, x)

def _build_discriminator(self):
inputs = layers.Input((self.seq_len, self.feat_dim))
for conv_size in [64, 128, 256, 512]:
x = layers.Conv1D(conv_size, self.kernel_size, strides=4, padding='same')(inputs)
x = layers.BatchNormalization()(x) if self.use_batchnorm else x
x = layers.LeakyReLU()(x)
x = self._apply_phaseshuffle(x, self.phase_rad)

x = layers.Flatten()(x)
x = layers.Dense(1)(x)

return keras.Model(inputs, x)


class Zoo(dict):
"""
A collection of architectures represented. It behaves like supports Python `dict` API.
Expand Down Expand Up @@ -901,6 +994,7 @@ def summary(self) -> None:
"t-cgan_c4": tcGAN_Conv4Architecture,
"cgan_lstm_n": cGAN_LSTMnArchitecture,
"cgan_lstm_3": cGAN_LSTMConv3Architecture,
"wavegan": WaveGANArchitecture,

# Downstream models
"clf_cn": ConvnArchitecture,
Expand Down
60 changes: 53 additions & 7 deletions tsgm/models/cgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,56 @@ class GAN(keras.Model):
"""
GAN implementation for unlabeled time series.
"""
def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_dim: int) -> None:
def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_dim: int, use_wgan: bool = False) -> None:
"""
:param discriminator: A discriminator model which takes a time series as input and check
whether the image is real or fake.
whether the sample is real or fake.
:type discriminator: keras.Model
:param generator: Takes as input a random noise vector of `latent_dim` length and returns
a simulated time-series.
:type generator: keras.Model
:param latent_dim: The size of the noise vector.
:type latent_dim: int
:param use_wgan: Use Wasserstein GAN with gradien penalty
:type use_wgan: bool
"""
super(GAN, self).__init__()
self.discriminator = discriminator
self.generator = generator
self.latent_dim = latent_dim
self._seq_len = self.generator.output_shape[1]
self.use_wgan = use_wgan
self.gp_weight = 10.0

self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

def wgan_discriminator_loss(self, real_sample, fake_sample):
real_loss = tf.reduce_mean(real_sample)
fake_loss = tf.reduce_mean(fake_sample)
return fake_loss - real_loss

# Define the loss functions to be used for generator
def wgan_generator_loss(self, fake_sample):
return -tf.reduce_mean(fake_sample)

def gradient_penalty(self, batch_size, real_samples, fake_samples):
# get the interpolated samples
alpha = tf.random.normal([batch_size, 1, 1], 0.0, 1.0)
diff = fake_samples - real_samples
interpolated = real_samples + alpha * diff
with tf.GradientTape() as gp_tape:
gp_tape.watch(interpolated)
# 1. Get the discriminator output for this interpolated sample.
pred = self.discriminator(interpolated, training=True)

# 2. Calculate the gradients w.r.t to this interpolated sample.
grads = gp_tape.gradient(pred, [interpolated])[0]
# 3. Calcuate the norm of the gradients
norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2]))
gp = tf.reduce_mean((norm - 1.0) ** 2)
return gp

@property
def metrics(self) -> T.List:
"""
Expand Down Expand Up @@ -94,7 +124,6 @@ def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
"""
real_data = data
batch_size = tf.shape(real_data)[0]

# Generate ts
random_vector = self._get_random_vector_labels(batch_size)
fake_data = self.generator(random_vector)
Expand All @@ -111,7 +140,19 @@ def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
)
with tf.GradientTape() as tape:
predictions = self.discriminator(combined_data)
d_loss = self.loss_fn(desc_labels, predictions)
if self.use_wgan:
fake_logits = self.discriminator(fake_data, training=True)
# Get the logits for the real samples
real_logits = self.discriminator(real_data, training=True)

# Calculate the discriminator loss using the fake and real sample logits
d_cost = self.wgan_discriminator_loss(real_logits, fake_logits)
# Calculate the gradient penalty
gp = self.gradient_penalty(batch_size, real_data, fake_data)
# Add the gradient penalty to the original discriminator loss
d_loss = d_cost + gp * self.gp_weight
else:
d_loss = self.loss_fn(desc_labels, predictions)
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
self.d_optimizer.apply_gradients(
zip(grads, self.discriminator.trainable_weights)
Expand All @@ -126,7 +167,11 @@ def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
with tf.GradientTape() as tape:
fake_data = self.generator(random_vector)
predictions = self.discriminator(fake_data)
g_loss = self.loss_fn(misleading_labels, predictions)
if self.use_wgan:
# uses logits
g_loss = self.wgan_generator_loss(predictions)
else:
g_loss = self.loss_fn(misleading_labels, predictions)

grads = tape.gradient(g_loss, self.generator.trainable_weights)
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
Expand Down Expand Up @@ -167,10 +212,10 @@ class ConditionalGAN(keras.Model):
"""
Conditional GAN implementation for labeled and temporally labeled time series.
"""
def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_dim: int, temporal=False) -> None:
def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_dim: int, temporal=False, use_wgan=False) -> None:
"""
:param discriminator: A discriminator model which takes a time series as input and check
whether the image is real or fake.
whether the sample is real or fake.
:type discriminator: keras.Model
:param generator: Takes as input a random noise vector of `latent_dim` length and return
a simulated time-series.
Expand Down Expand Up @@ -312,6 +357,7 @@ def train_step(self, data: T.Tuple) -> T.Dict[str, float]:
fake_data = tf.concat([fake_samples, rep_labels], -1)
predictions = self.discriminator(fake_data)
g_loss = self.loss_fn(misleading_labels, predictions)

if self.dp:
# For DP optimizers from `tensorflow.privacy`
self.g_optimizer.minimize(g_loss, self.generator.trainable_weights, tape=tape)
Expand Down

0 comments on commit 2b138cb

Please sign in to comment.