chore: resolve ydataai#348 and modify partial indentation

T0217 committed Aug 14, 2024
1 parent 251facc commit 79a8c82
Showing 2 changed files with 40 additions and 35 deletions.
50 changes: 27 additions & 23 deletions src/ydata_synthetic/synthesizers/base.py
@@ -6,7 +6,7 @@
 import pandas as pd
 import tqdm
 
-from numpy import array, vstack, ndarray
+from numpy import array, ndarray, concatenate
 from numpy.random import normal
 from pandas.api.types import is_float_dtype, is_integer_dtype
 from pandas import DataFrame
@@ -30,16 +30,16 @@
 from ydata_synthetic.synthesizers.saving_keras import make_keras_picklable
 
 _model_parameters = ['batch_size', 'lr', 'betas', 'layers_dim', 'noise_dim',
-                     'n_cols', 'seq_len', 'condition', 'n_critic', 'n_features',
-                     'tau_gs', 'generator_dims', 'critic_dims', 'l2_scale',
+                     'n_cols', 'seq_len', 'condition', 'n_critic', 'n_features',
+                     'tau_gs', 'generator_dims', 'critic_dims', 'l2_scale',
                      'latent_dim', 'gp_lambda', 'pac', 'gamma', 'tanh']
 _model_parameters_df = [128, 1e-4, (None, None), 128, 264,
-                        None, None, None, 1, None, 0.2, [256, 256],
+                        None, None, None, 1, None, 0.2, [256, 256],
                         [256, 256], 1e-6, 128, 10.0, 10, 1, False]
 
-_train_parameters = ['cache_prefix', 'label_dim', 'epochs', 'sample_interval',
-                     'labels', 'n_clusters', 'epsilon', 'log_frequency',
-                     'measurement_cols', 'sequence_length', 'number_sequences',
+_train_parameters = ['cache_prefix', 'label_dim', 'epochs', 'sample_interval',
+                     'labels', 'n_clusters', 'epsilon', 'log_frequency',
+                     'measurement_cols', 'sequence_length', 'number_sequences',
                      'sample_length', 'rounds']
 
 ModelParameters = namedtuple('ModelParameters', _model_parameters, defaults=_model_parameters_df)
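For readers skimming the diff: `ModelParameters` is a namedtuple whose `defaults=_model_parameters_df` pairs each field above with a positional default, so callers override only the fields they care about. A minimal sketch of that behavior, using the default values listed in this hunk:

```python
from ydata_synthetic.synthesizers.base import ModelParameters

params = ModelParameters()                     # every field defaulted
custom = ModelParameters(noise_dim=32, pac=1)  # override selected fields only
print(params.batch_size, params.lr)            # 128 0.0001
print(custom.noise_dim, custom.pac)            # 32 1
```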
@@ -74,8 +74,8 @@ def fit(self, data: Union[DataFrame, array],
         """
         ...
     @abstractmethod
-    def sample(self, n_samples:int) -> pd.DataFrame:
-        assert n_samples>0, "Please insert a value bigger than 0 for n_samples parameter."
+    def sample(self, n_samples: int) -> pd.DataFrame:
+        assert n_samples > 0, "Please insert a value bigger than 0 for n_samples parameter."
         ...
 
     @classmethod
@@ -107,7 +107,7 @@ def __init__(
         except (ValueError, RuntimeError):
             # Invalid device or cannot modify virtual devices once initialized.
             pass
-        #Validate the provided model parameters
+        # Validate the provided model parameters
         if model_parameters.betas is not None:
             assert len(model_parameters.betas) == 2, "Please provide the betas information as a tuple."
 
@@ -128,7 +128,7 @@ def __init__(
         self.pac = model_parameters.pac
 
         self.use_tanh = model_parameters.tanh
-        self.processor=None
+        self.processor = None
         if self.__MODEL__ in RegularModels.__members__ or \
            self.__MODEL__ == CTGANDataProcessor.SUPPORTED_MODEL:
             self.tau = model_parameters.tau_gs
@@ -140,12 +140,12 @@ def __call__(self, inputs, **kwargs):
     # pylint: disable=C0103
     def _set_lr(self, lr):
         if isinstance(lr, float):
-            self.g_lr=lr
-            self.d_lr=lr
-        elif isinstance(lr,(list, tuple)):
-            assert len(lr)==2, "Please provide a two values array for the learning rates or a float."
-            self.g_lr=lr[0]
-            self.d_lr=lr[1]
+            self.g_lr = lr
+            self.d_lr = lr
+        elif isinstance(lr, (list, tuple)):
+            assert len(lr) == 2, "Please provide a two values array for the learning rates or a float."
+            self.g_lr = lr[0]
+            self.d_lr = lr[1]
 
     def define_gan(self):
         """Define the trainable model components.
@@ -187,7 +187,7 @@ def fit(self,
         elif self.__MODEL__ == CTGANDataProcessor.SUPPORTED_MODEL:
             n_clusters = train_arguments.n_clusters
             epsilon = train_arguments.epsilon
-            self.processor = CTGANDataProcessor(n_clusters=n_clusters, epsilon=epsilon,
+            self.processor = CTGANDataProcessor(n_clusters=n_clusters, epsilon=epsilon,
                                                 num_cols=num_cols, cat_cols=cat_cols).fit(data)
         elif self.__MODEL__ == DoppelGANgerProcessor.SUPPORTED_MODEL:
             measurement_cols = train_arguments.measurement_cols
@@ -211,13 +211,17 @@ def sample(self, n_samples: int):
         Returns:
             synth_sample (pandas.DataFrame): generated synthetic samples.
         """
-        steps = n_samples // self.batch_size + 1
+        steps = n_samples // self.batch_size + \
+            (1 if n_samples % self.batch_size != 0 else 0)
         data = []
         for _ in tqdm.trange(steps, desc='Synthetic data generation'):
             z = random.uniform([self.batch_size, self.noise_dim], dtype=tf.dtypes.float32)
             records = self.generator(z, training=False).numpy()
             data.append(records)
-        return self.processor.inverse_transform(array(vstack(data)))
+
+        data = concatenate(data, axis=0)
+        data = data[:n_samples]
+        return self.processor.inverse_transform(data)
 
     def save(self, path):
         """
@@ -226,8 +230,8 @@ def save(self, path):
         Args:
             path (str): Path to write the synthesizer as a pickle object.
         """
-        #Save only the generator?
-        if self.__MODEL__=='WGAN' or self.__MODEL__=='WGAN_GP' or self.__MODEL__=='CWGAN_GP':
+        # Save only the generator?
+        if self.__MODEL__ == 'WGAN' or self.__MODEL__ == 'WGAN_GP' or self.__MODEL__ == 'CWGAN_GP':
             del self.critic
         make_keras_picklable()
         dump(self, path)
@@ -309,7 +313,7 @@ def sample(self, condition: DataFrame) -> ndarray:
         Returns:
             sample (pandas.DataFrame): A dataframe with the generated synthetic records.
         """
-        ##Validate here if the cond_vector=label_dim
+        # Validate here if the cond_vector=label_dim
         condition = condition.reset_index(drop=True)
         n_samples = len(condition)
         z_dist = random.uniform(shape=(n_samples, self.noise_dim))
25 changes: 13 additions & 12 deletions src/ydata_synthetic/synthesizers/regular/ctgan/model.py
@@ -10,7 +10,7 @@
 import tensorflow_probability as tfp
 from ydata_synthetic.synthesizers.regular.ctgan.utils \
     import ConditionalLoss, RealDataSampler, ConditionalSampler
-
+
 from ydata_synthetic.synthesizers.loss import gradient_penalty, Mode as ModeGP
 from ydata_synthetic.synthesizers.base import BaseGANModel, ModelParameters, TrainParameters
 from ydata_synthetic.preprocessing.regular.ctgan_processor import CTGANDataProcessor
@@ -79,12 +79,12 @@ def _generator_activation(data):
                 logits = data[:, col_md.start_idx+1:col_md.end_idx]
                 data_transformed.append(_gumbel_softmax(logits, tau=tau))
             return data, tf.concat(data_transformed, axis=1)
 
         x = Dense(data_dim, kernel_initializer="random_uniform",
-                  bias_initializer="random_uniform",
+                  bias_initializer="random_uniform",
                   activation=_generator_activation)(x)
         return Model(inputs=input, outputs=x)
 
     @staticmethod
     def _create_critic_model(input_dim, critic_dims, pac):
         """
@@ -130,11 +130,11 @@ def fit(self, data: DataFrame, train_arguments: TrainParameters, num_cols: list[
 
         self._real_data_sampler = RealDataSampler(train_data, metadata)
         self._conditional_sampler = ConditionalSampler(train_data, metadata, train_arguments.log_frequency)
-
+
         gen_input_dim = self.latent_dim + self._conditional_sampler.output_dimensions
         self._generator_model = self._create_generator_model(
             gen_input_dim, self.generator_dims, data_dim, metadata, self.tau)
-
+
         crt_input_dim = data_dim + self._conditional_sampler.output_dimensions
         self._critic_model = self._create_critic_model(crt_input_dim, self.critic_dims, self.pac)
 
@@ -173,7 +173,7 @@ def fit(self, data: DataFrame, train_arguments: TrainParameters, num_cols: list[
             mask = tf.convert_to_tensor(mask)
             fake_z = tf.concat([fake_z, cond], axis=1)
             generator_loss = self._train_generator_step(fake_z, cond, mask, metadata)
-
+
             print(f"Epoch: {epoch} | critic_loss: {critic_loss} | generator_loss: {generator_loss}")
 
     def _train_critic_step(self, real, fake):
@@ -255,7 +255,8 @@ def sample(self, n_samples: int):
         if n_samples <= 0:
             raise ValueError("Invalid number of samples.")
 
-        steps = n_samples // self.batch_size + 1
+        steps = n_samples // self.batch_size + \
+            (1 if n_samples % self.batch_size != 0 else 0)
         data = []
         for _ in tf.range(steps):
             fake_z = tf.random.normal([self.batch_size, self.latent_dim])
@@ -270,7 +271,7 @@ def sample(self, n_samples: int):
         data = np.concatenate(data, 0)
         data = data[:n_samples]
         return self.processor.inverse_transform(data)
-
+
     def save(self, path):
         """
         Save the CTGAN model in a pickle file.
@@ -314,9 +315,9 @@ def load(class_dict):
         new_instance.processor.__dict__ = class_dict["processor"]
 
         new_instance._generator_model = new_instance._create_generator_model(
-            class_dict["gen_input_dim"], class_dict["generator_dims"],
+            class_dict["gen_input_dim"], class_dict["generator_dims"],
             class_dict["data_dim"], class_dict["metadata"], class_dict["tau"])
 
         new_instance._generator_model.build((class_dict["batch_size"], class_dict["gen_input_dim"]))
         new_instance._generator_model.set_weights(class_dict['generator_model_weights'])
-        return new_instance
+        return new_instance
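`load` rebuilds the generator architecture, `build`s it with the stored input shape, and then restores the pickled weights. The same round-trip pattern in isolation (a minimal sketch with hypothetical layer sizes, not the CTGAN generator itself):

```python
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

def make_model() -> Sequential:
    # Stand-in architecture; the real code calls _create_generator_model(...).
    return Sequential([Dense(16, activation="relu"), Dense(4)])

source = make_model()
source.build((None, 8))         # analogous to build((batch_size, gen_input_dim))
weights = source.get_weights()  # what ends up in the pickled class_dict

restored = make_model()
restored.build((None, 8))
restored.set_weights(weights)   # mirrors set_weights(class_dict[...])
assert all(np.array_equal(a, b) for a, b in zip(weights, restored.get_weights()))
```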
