diff --git a/README.md b/README.md index 543c1302..0d8843b2 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,12 @@ Join us on [![Discord](https://img.shields.io/badge/Discord-7289DA?style=for-the # YData Synthetic A package to generate synthetic tabular and time-series data leveraging the state of the art generative models. + +## 🎊 We have **big news**: v1.0.0 is here +> We have exciting news for you. The new version of `ydata-synthetic` include new and exciting features: + > - A conditional architecture for tabular data: CTGAN, which will make the process of synthetic data generation easier and with higher quality! + > - A new streamlit app that delivers the synthetic data generation experience with a UI interface + ## Synthetic data ### What is synthetic data? Synthetic data is artificially generated data that is not collected from real world events. It replicates the statistical components of real data without containing any identifiable information, ensuring individuals' privacy. @@ -27,19 +33,54 @@ This repository contains material related with Generative Adversarial Networks f It consists a set of different GANs architectures developed using Tensorflow 2.0. Several example Jupyter Notebooks and Python scripts are included, to show how to use the different architectures. ## Quickstart - The source code is currently hosted on GitHub at: https://github.com/ydataai/ydata-synthetic Binary installers for the latest released version are available at the [Python Package Index (PyPI).](https://pypi.org/project/ydata-synthetic/) -``` +```commandline pip install ydata-synthetic ``` +### The UI guide for synthetic data generation + +YData synthetic has now a UI interface to guide you through the steps and inputs to generate structure tabular data. +The streamlit app is available form *v1.0.0* onwards, and supports the following flows: +- Train a synthesizer model +- Generate & profile synthetic data samples + +#### Installation + +```commandline +pip install ydata-syntehtic[streamlit] +``` +#### Quickstart +Use the code snippet below in a python file (Jupyter Notebooks are not supported): +```python +from ydata_synthetic import streamlit_app + +streamlit_app.run() +``` + +Or use the file streamlit_app.py that can be found in the [examples folder](https://github.com/ydataai/ydata-synthetic/tree/master/examples/streamlit_app.py). + +```commandline +python -m streamlit_app +``` + +The below models are supported: + - CGAN + - WGAN + - WGANGP + - DRAGAN + - CRAMER + - CTGAN + +[![Watch the video](assets/streamlit_app.png)](https://youtu.be/ep0PhwsFx0A) + ### Examples Here you can find usage examples of the package and models to synthesize tabular data. - Synthesizing the minority class with VanillaGAN on credit fraud dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/regular/gan_example.ipynb) - Time Series synthetic data generation with TimeGAN on stock dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/timeseries/TimeGAN_Synthetic_stock_data.ipynb) - - More examples are continously added and can be found in `/examples` directory. + - More examples are continuously added and can be found in `/examples` directory. ### Datasets for you to experiment Here are some example datasets for you to try with the synthesizers: @@ -51,7 +92,6 @@ Here are some example datasets for you to try with the synthesizers: #### Sequential datasets - [Stock data](https://github.com/ydataai/ydata-synthetic/tree/master/data) - ## Project Resources In this repository you can find the several GAN architectures that are used to create synthesizers: @@ -64,6 +104,7 @@ In this repository you can find the several GAN architectures that are used to c - [DRAGAN (On Convergence and stability of GANS)](https://arxiv.org/pdf/1705.07215.pdf) - [Cramer GAN (The Cramer Distance as a Solution to Biased Wasserstein Gradients)](https://arxiv.org/abs/1705.10743) - [CWGAN-GP (Conditional Wassertein GAN with Gradient Penalty)](https://cameronfabbri.github.io/papers/conditionalWGAN.pdf) + - [CTGAN (Conditional Tabular GAN)](https://arxiv.org/pdf/1907.00503.pdf) ### Sequential data - [TimeGAN](https://papers.nips.cc/paper/2019/file/c9efe5f26cd17ba6216bbe2a7d26d490-Paper.pdf) diff --git a/assets/streamlit_app.png b/assets/streamlit_app.png new file mode 100644 index 00000000..fd283c32 Binary files /dev/null and b/assets/streamlit_app.png differ diff --git a/examples/regular/streamlit app/.streamlit/config.toml b/examples/regular/streamlit app/.streamlit/config.toml deleted file mode 100644 index 45288de9..00000000 --- a/examples/regular/streamlit app/.streamlit/config.toml +++ /dev/null @@ -1,5 +0,0 @@ -[theme] -primaryColor="#040000" -backgroundColor="#770303" -secondaryBackgroundColor="#000000" -textColor="#f2f2f3" \ No newline at end of file diff --git a/examples/regular/streamlit app/README.md b/examples/regular/streamlit app/README.md deleted file mode 100644 index d5b4a2ab..00000000 --- a/examples/regular/streamlit app/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# Streamlit application to generate synthetic data using ydata-synthetic - -streamlit app to generate synthetic data - -This application takes a pre-processed dataset as input and outputs a synthetic dataset based on the given input parameters. This is made with open source libraries streamlit, ydata-synthetic and deployed on the streamlit cloud. - -## How to use - -1. Upload a pre-processed dataset. -2. Choose the numerical features and categorical features. -3. Choose all the training parameters appropriately. -4. Click the 'click here to start the training process' button. - -streamlit app to generate synthetic data - -Wait for the training to end. You will see a graph comparing the original data and synthetic data after training. -Please use less number of epochs to complete the training process quickly as this application is deployed on the community cloud of streamlit which has computational limits. - -## Contributing - -Find the application here in this link [![Open in Streamlit](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://share.streamlit.io/rajeshai/ydata-synthetic-streamlit/main/app.py) - -Feel free to contribute to this app by adding more features and optimizing its performance further. diff --git a/examples/regular/streamlit app/YData_logo.svg b/examples/regular/streamlit app/YData_logo.svg deleted file mode 100644 index c17ecdba..00000000 --- a/examples/regular/streamlit app/YData_logo.svg +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - - - - - - - - - - diff --git a/examples/regular/streamlit app/app.JPG b/examples/regular/streamlit app/app.JPG deleted file mode 100644 index f1cd6623..00000000 Binary files a/examples/regular/streamlit app/app.JPG and /dev/null differ diff --git a/examples/regular/streamlit app/app.gif b/examples/regular/streamlit app/app.gif deleted file mode 100644 index f652ec66..00000000 Binary files a/examples/regular/streamlit app/app.gif and /dev/null differ diff --git a/examples/regular/streamlit app/app.py b/examples/regular/streamlit app/app.py deleted file mode 100644 index f0465bd8..00000000 --- a/examples/regular/streamlit app/app.py +++ /dev/null @@ -1,113 +0,0 @@ -import os -import streamlit as st -import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns -from ydata_synthetic.synthesizers.regular import DRAGAN, CGAN, CRAMERGAN, WGAN_GP -from ydata_synthetic.synthesizers import ModelParameters, TrainParameters - -st.set_page_config(layout="wide",initial_sidebar_state="auto") -os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices' -def run(): - #global data_synn - st.sidebar.image('YData_logo.svg') - st.title('Generate synthetic data for a tabular classification dataset using [ydata-synthetic](https://github.com/ydataai/ydata-synthetic)') - st.markdown('This streamlit application can generate synthetic data for your dataset. Please read all the instructions in the sidebar before you start the process.') - data = st.file_uploader('Upload a preprocessed dataset in csv format') - st.sidebar.title('About') - st.sidebar.markdown('[ydata-synthetic](https://github.com/ydataai/ydata-synthetic) is an open-source library and is used to generate synthetic data mimicking the real world data.') - st.sidebar.header('What is synthetic data?') - st.sidebar.markdown('Synthetic data is artificially generated data that is not collected from real world events. It replicates the statistical components of real data without containing any identifiable information, ensuring individuals privacy.') - st.sidebar.header('Why Synthetic Data?') - st.sidebar.markdown('''Synthetic data can be used for many applications: -- Privacy -- Remove bias -- Balance datasets -- Augment datasets''') - - - st.sidebar.header('Steps to follow') - st.sidebar.markdown(''' -- Upload any preprocessed tabular classification dataset. -- Choose the parameters in the adjacent window appropriately. -- Since this is a demo, please choose less number of epochs for quick completion of training. -- After choosing all parameters, Click the button under the parameters to start training. -- After the training is complete, you will see a graph comparing both real data set and synthetic dataset. Categorical columns are used to compare. -- You will also see a button to download your synthetic dataset. Click that button to download your dataset.''') - - st.sidebar.markdown('''[![Repo](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/ydataai/ydata-synthetic)''',unsafe_allow_html=True) - - @st.cache - def train(df): - #models_dir = './cache' - gan_args = ModelParameters(batch_size=batch_size, - lr=learning_rate*0.001, - betas=(beta_1, beta_2), - noise_dim=noise_dim, - layers_dim=layer_dim) - - train_args = TrainParameters(epochs=epochs, - sample_interval=log_step) - synthesizer = model(gan_args, n_discriminator=3) - synthesizer.train(data, train_args, num_cols, cat_cols) - synthesizer.save('data_synth.pkl') - synthesizer = model.load('data_synth.pkl') - data_syn = synthesizer.sample(samples) - return data_syn - @st.cache - def convert_df(df): - return df.to_csv().encode('utf-8') - if data is not None: - data = pd.read_csv(data) - data.dropna(inplace=True) - st.header('Choose the parameters!!') - col1, col2, col3,col4 = st.columns(4) - with col1: - model = st.selectbox('Choose the GAN model', ['DRAGAN','CGAN','CRAMEGAN','WGAN_GP'],key=1) - if model=='DRAGAN': - model = DRAGAN - elif model=='CGAN': - model=CGAN - elif model=='CRAMEGAN': - model = CRAMERGAN - else: - model = WGAN_GP - num_cols = st.multiselect('Choose the numerical columns', data.columns,key=1) - cat_cols = st.multiselect('Choose categorical columns', [x for x in data.columns if x not in num_cols], key=2) - - with col2: - noise_dim = st.number_input('Select noise dimension', 0,200,128,1) - layer_dim = st.number_input('Select the layer dimension', 0,200,128,1) - batch_size = st.number_input('Select batch size', 0,500, 500,1) - - with col3: - log_step = st.number_input('Select sample interval', 0,200,100,1) - epochs = st.number_input('Select the number of epochs',0,50,2,1) - learning_rate = st.number_input('Select learning rate(x1e-3', 0.01, 0.1, 0.05, 0.01) - - with col4: - beta_1 = st.slider('Select first beta co-efficient', 0.0, 1.0, 0.5) - beta_2 = st.slider('Select second beta co-efficient', 0.0, 1.0, 0.9) - samples = st.number_input('Select the number of synthetic samples to be generated', 0, 400000, step=1000) - if st.button('Click here to start the training process'): - if data is not None: - st.write('Model Training is in progress. It may take a few minutes. Please wait for a while.') - data_synn = train(data) - st.success('Synthetic dataset with the given number of samples is generated!!') - st.subheader('Real Data vs Synthetic Data') - f , axes = plt.subplots(len(cat_cols),2, figsize=(20,25)) - f.suptitle('Real data vs Synthetic data') - for i, j in enumerate(cat_cols): - sns.countplot(x=j, data=data, ax = axes[i,0]) - sns.countplot(x=j, data=data_synn, ax = axes[i,1]) - st.pyplot(f) - st.download_button( - label="Download data as CSV", - data=convert_df(data_synn), - file_name='data_syn.csv', - mime='text/csv') - st.balloons() - else: - st.write('Upload a dataset to train!!') -if __name__== '__main__': - run() \ No newline at end of file diff --git a/examples/regular/streamlit app/requirements.txt b/examples/regular/streamlit app/requirements.txt deleted file mode 100644 index 87231716..00000000 --- a/examples/regular/streamlit app/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -pandas -matplotlib -numpy -seaborn -streamlit -ydata-synthetic diff --git a/examples/streamlit_app.py b/examples/streamlit_app.py new file mode 100644 index 00000000..e094889e --- /dev/null +++ b/examples/streamlit_app.py @@ -0,0 +1,7 @@ +""" + Python file example with the script to run ydata-synthetic streamlit app +""" +from ydata_synthetic import streamlit_app + +if __name__ == '__main__': + streamlit_app.run() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 71f321a2..48dc1793 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,4 @@ easydict==1.10 pmlb==1.0.* tqdm<5.0 typeguard==2.13.* -pytest==6.2.* \ No newline at end of file +pytest==6.2.* diff --git a/setup.py b/setup.py index 0d3aa01c..290de253 100644 --- a/setup.py +++ b/setup.py @@ -47,4 +47,13 @@ package_dir={'':'src'}, include_package_data=True, options={"bdist_wheel": {"universal": True}}, - install_requires=requirements) + install_requires=requirements, + extras_require={ + "streamlit": [ + "streamlit==0.18.1", + "typing-extensions==3.10.0", + "streamlit_pandas_profiling==0.1.3", + "ydata-profiling==4.0.0" + ], + }, + ) diff --git a/src/ydata_synthetic/streamlit_app/.streamlit/config.toml b/src/ydata_synthetic/streamlit_app/.streamlit/config.toml new file mode 100644 index 00000000..12f51c67 --- /dev/null +++ b/src/ydata_synthetic/streamlit_app/.streamlit/config.toml @@ -0,0 +1,3 @@ +[theme] +base="light" +primaryColor="#e32212" diff --git a/src/ydata_synthetic/streamlit_app/About.py b/src/ydata_synthetic/streamlit_app/About.py new file mode 100644 index 00000000..19f28fb3 --- /dev/null +++ b/src/ydata_synthetic/streamlit_app/About.py @@ -0,0 +1,83 @@ +""" + ydata-synthetic streamlit app landing page +""" +import streamlit as st + +def main(): + st.set_page_config( + page_title="YData Synthetic - Synthetic data generation streamlit_app", + page_icon="👋", + layout="wide" + ) + col1, col2 = st.columns([2, 4]) + + with col1: + st.image("https://assets.ydata.ai/oss/ydata-synthetic-_red.png", width=200) + + with col2: + st.title("Welcome to YData Synthetic!") + st.text("Your application for synthetic data generation!") + + st.markdown('[ydata-synthetic](https://github.com/ydataai/ydata-synthetic) is an open-source library and is used to generate synthetic data mimicking the real world data.') + st.header('What is synthetic data?') + st.markdown('Synthetic data is artificially generated data that is not collected from real-world events. It replicates the statistical components of real data containing no identifiable information, ensuring an individual’s privacy.') + st.header('Why Synthetic Data?') + st.markdown(''' + Synthetic data can be used for many applications: + - Privacy + - Remove bias + - Balance datasets + - Augment datasets''') + + # read the instructions in x/ + st.markdown('This *streamlit_app* application can generate synthetic data for your dataset. ' + 'Please read all the instructions in the sidebar before you start the process.') + + # read the instructions in x/ + st.subheader('Select & train a synthesizer') + #Add here the example text for the end users + + st.markdown(''' + `ydata-synthetic` streamlit app enables the training and generation of synthetic data from generative architectures. + The current app only provides support for the generation tabular data and for the following architectures: + - GAN + - WGAN + - WGANGP + - CTGAN + ''') + + #best practives for synthetic data generation + st.markdown(''' + ##### What you should ensure before training the synthesizer: + - Make sure your dataset has no missing data. + - If missing data is a problem, no worries. Check the article and this article. + - Make sure you choose the right number of epochs and batch_size considering your dataset shape. + - The choice of these 2 parameters highly affects the results you may get. + - Make sure that you've the right data types selected. + - Only numerical and categorical values are supported. + - In case date , datetime, or text is available in the dataset, the columns should be preprocessed before the model training.''') + + st.markdown('The trained synthesizer is saved to `*.trained_synth.pkl*` by default.') + + st.subheader('Generate & compare synthetic samples') + + st.markdown(''' + The ydata-synthetic app experience allows you to: + - Generate as many samples as you want based on the provided input + - Generate a profile for the generated synthetic samples + - Save the generated samples to a local directory''') + + # guidelines for sampling and + st.markdown(''' + ##### What you should ensure before generating synthetic samples: + - If no model file path is provided, the default location `.trained_synth.pkl` is assumed. + - Always choose the correct type of data, that corresponds to the trained model in order to avoid loading errors.''') + + st.subheader('Coming soon') + st.markdown(''' + - Support for time-series models: TimeGAN + - Integrate more advanced settings for CTGAN + - Side-by-side comparison real vs synthetic data sample with `ydata-profiling`''') + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/ydata_synthetic/streamlit_app/__init__.py b/src/ydata_synthetic/streamlit_app/__init__.py new file mode 100644 index 00000000..aa617462 --- /dev/null +++ b/src/ydata_synthetic/streamlit_app/__init__.py @@ -0,0 +1,3 @@ +from ydata_synthetic.streamlit_app.run import run + +## \ No newline at end of file diff --git a/src/ydata_synthetic/streamlit_app/pages/1_Train_a_synthesizer.py b/src/ydata_synthetic/streamlit_app/pages/1_Train_a_synthesizer.py new file mode 100644 index 00000000..6251c3ad --- /dev/null +++ b/src/ydata_synthetic/streamlit_app/pages/1_Train_a_synthesizer.py @@ -0,0 +1,82 @@ +from typing import Union +import streamlit as st + +from ydata_synthetic.synthesizers import ModelParameters, TrainParameters +from ydata_synthetic.synthesizers.regular.model import Model + +from ydata_synthetic.streamlit_app.pages.functions.load_data import upload_file +from ydata_synthetic.streamlit_app.pages.functions.train import DataType, __CONDITIONAL_MODELS +from ydata_synthetic.streamlit_app.pages.functions.train import init_synth, advanced_setttings, training_parameters + +def get_available_models(type: Union[str, DataType]): + + dtype = DataType(type) + if dtype == DataType.TABULAR: + models_list = [e.value.upper() for e in Model if e.value not in ['cgan', 'cwgangp']] + else: + st.warning('Time-Series models are not yet supported .') + models_list = (['']) + return models_list + +def run(): + model_name= None + + df, num_cols, cat_cols = upload_file() + + if df is not None: + st.subheader("2. Select your synthesizer parameters") + + col_type, col_model = st.columns(2) + + with col_type: + datatype = st.selectbox('Select your data type', (DataType.TABULAR.value, )) + with col_model: + if datatype is not None: + models_list = get_available_models(type=datatype) + model_name = st.selectbox('Select your model', models_list) + + if model_name !='': + st.text("Select your synthesizer model parameters") + col1, col2 = st.columns(2) + with col1: + batch_size = st.number_input('Batch size', 0, 500, 500, 1) + + with col2: + lr = st.number_input('Learning rate', 0.01, 0.1, 0.05, 0.01) + + with st.expander('**More settings**'): + model_path = st.text_input("Saved trained model to path:", value="trained_synth.pkl") + noise_dim, layer_dim, beta_1, beta_2 = advanced_setttings() + + # Create the Train parameters + gan_args = ModelParameters(batch_size=batch_size, + lr=lr, + betas=(beta_1, beta_2), + noise_dim=noise_dim, + layers_dim=layer_dim) + + model = init_synth(datatype=datatype, modelname=model_name, model_parameters=gan_args) + + if model!=None: + st.text("Set your synthesizer training parameters") + #Get the training parameters + epochs, label_col = training_parameters(model_name, df.columns) + + train_args = TrainParameters(epochs=epochs) + + st.subheader("3. Train your synthesizer") + if st.button('Click here to start the training process'): + with st.spinner("Please wait while your synthesizer trains..."): + if label_col is not None: + model.fit(data=df, num_cols=num_cols, cat_cols=cat_cols, train_arguments=train_args, label_cols=label_col) + else: + model.fit(data=df, num_cols=num_cols, cat_cols=cat_cols, train_arguments=train_args) + + st.success('Synthesizer was trained succesfully!!') + + st.info(f"The trained model will be saved at {model_path}.") + + model.save(model_path) + +if __name__ == '__main__': + run() \ No newline at end of file diff --git a/src/ydata_synthetic/streamlit_app/pages/2_Generate_synthetic_data.py b/src/ydata_synthetic/streamlit_app/pages/2_Generate_synthetic_data.py new file mode 100644 index 00000000..4e57a871 --- /dev/null +++ b/src/ydata_synthetic/streamlit_app/pages/2_Generate_synthetic_data.py @@ -0,0 +1,41 @@ +import streamlit as st + +from ydata_synthetic.streamlit_app.pages.functions.train import DataType +from ydata_synthetic.streamlit_app.pages.functions.generate import load_model, generate_profile + +def run(): + st.subheader("Generate synthetic data from a trained model") + + col1, col2 = st.columns([4, 2]) + with col1: + input_path = st.text_input("Provide the path to a trained model", value="trained_synth.pkl") + with col2: + datatype = st.selectbox('Select your data type', (DataType.TABULAR.value,)) + datatype=DataType(datatype) + + col1, col2 = st.columns([4,2]) + with col1: + n_samples = st.number_input("Number of samples to generate", min_value=0, value=1000) + profile = st.checkbox("Generate synthetic data profiling?", value=False) + with col2: + sample_path = st.text_input("Synthetic samples file path", value='synthetic.csv') + + if st.button('Generate samples'): + #load a trained model + model = load_model(input_path=input_path, + datatype=datatype) + + st.success('Trained model was loaded. You can now generate synthetic samples') + + #sample synthetic data + synth_data = model.sample(n_samples) + st.write(synth_data) + + #save the synthetic data samples to a given path + synth_data.to_csv(sample_path) + + if profile: + generate_profile(df=synth_data) + +if __name__ == '__main__': + run() \ No newline at end of file diff --git a/src/ydata_synthetic/streamlit_app/pages/functions/__init__.py b/src/ydata_synthetic/streamlit_app/pages/functions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ydata_synthetic/streamlit_app/pages/functions/generate.py b/src/ydata_synthetic/streamlit_app/pages/functions/generate.py new file mode 100644 index 00000000..cb098119 --- /dev/null +++ b/src/ydata_synthetic/streamlit_app/pages/functions/generate.py @@ -0,0 +1,22 @@ +""" + Auxiliary functions for the synthetic data generation +""" +#passar o datatype para outro sítio?? +import pandas as pd +from ydata_profiling import ProfileReport +from streamlit_pandas_profiling import st_profile_report + +from ydata_synthetic.streamlit_app.pages.functions.train import DataType +from ydata_synthetic.synthesizers.regular import RegularSynthesizer +from ydata_synthetic.synthesizers.timeseries import TimeGAN + +def load_model(input_path: str, datatype: DataType): + if datatype == DataType.TABULAR: + model = RegularSynthesizer.load(input_path) + else: + model = TimeGAN.load(input_path) + return model + +def generate_profile(df: pd.DataFrame): + report = ProfileReport(df, title='Synthetic data profile', interactions=None) + st_profile_report(report) \ No newline at end of file diff --git a/src/ydata_synthetic/streamlit_app/pages/functions/load_data.py b/src/ydata_synthetic/streamlit_app/pages/functions/load_data.py new file mode 100644 index 00000000..7ff3051f --- /dev/null +++ b/src/ydata_synthetic/streamlit_app/pages/functions/load_data.py @@ -0,0 +1,25 @@ +import streamlit as st +import pandas as pd + +def upload_file(): + df = None + num_cols = None + cat_cols = None + + st.subheader("1. Select your dataset") + uploaded_file = st.file_uploader("Choose a file:") + + if uploaded_file is not None: + df = pd.read_csv(uploaded_file) + st.write(df) + + #add here more things for the mainpage + if df is not None: + col1, col2 = st.columns(2) + with col1: + num_cols = st.multiselect('Choose the numerical columns', df.columns, key=1) + with col2: + cat_cols = st.multiselect('Choose categorical columns', [x for x in df.columns if x not in num_cols], key=2) + + return df, num_cols, cat_cols + diff --git a/src/ydata_synthetic/streamlit_app/pages/functions/train.py b/src/ydata_synthetic/streamlit_app/pages/functions/train.py new file mode 100644 index 00000000..c2ede838 --- /dev/null +++ b/src/ydata_synthetic/streamlit_app/pages/functions/train.py @@ -0,0 +1,50 @@ +""" + Auxiliary functions for synthetic data training +""" +from enum import Enum +import streamlit as st + +from ydata_synthetic.synthesizers.regular import RegularSynthesizer +from ydata_synthetic.synthesizers.timeseries import TimeGAN +from ydata_synthetic.synthesizers import ModelParameters + +__MODEL_MAPPING = {'tabular': RegularSynthesizer, 'timeseries': TimeGAN} +__CONDITIONAL_MODELS = ['CGAN', 'CWGANGP'] + +class DataType(Enum): + TABULAR = 'tabular' + TIMESERIES = 'timeseries' + +def init_synth(datatype: str, modelname: str, model_parameters: ModelParameters, n_critic: int=1): + synth = __MODEL_MAPPING[datatype] + modelname = modelname.lower() + if modelname in ['wgan', 'cwgangp', 'wgangp']: + synth = synth(modelname=modelname, + model_parameters=model_parameters, + n_critic=n_critic) + else: + synth = synth(modelname=modelname, + model_parameters=model_parameters) + return synth + +def advanced_setttings(): + col1, col2 = st.columns(2) + with col1: + noise_dim = st.number_input('Select noise dimension', 0, 200, 128, 1) + layer_dim = st.number_input('Select the layer dimension', 0, 200, 128, 1) + with col2: + beta_1 = st.slider('Select first beta co-efficient', 0.0, 1.0, 0.5) + beta_2 = st.slider('Select second beta co-efficient', 0.0, 1.0, 0.9) + return noise_dim, layer_dim, beta_1, beta_2 + +def training_parameters(model_name:str, df_cols: list): + col1, col2 = st.columns([2, 4]) + with col1: + epochs = st.number_input('Epochs', min_value=0, value=100) + + if model_name in __CONDITIONAL_MODELS: + with col2: + label_col = st.multiselect('Choose the conditional cols:', df_cols) + else: + label_col=None + return epochs, label_col \ No newline at end of file diff --git a/src/ydata_synthetic/streamlit_app/run.py b/src/ydata_synthetic/streamlit_app/run.py new file mode 100644 index 00000000..5e416b4e --- /dev/null +++ b/src/ydata_synthetic/streamlit_app/run.py @@ -0,0 +1,15 @@ +""" + Logic to run streamlit app from python code +""" +import os +from streamlit import config as _config +from streamlit.web import bootstrap + +def run(): + dir_path = os.path.dirname(__file__) + file_path = os.path.join(dir_path, "About.py") + + _config.set_option("server.headless", True) + args = [] + + bootstrap.run(file_path,'',args, flag_options={}) \ No newline at end of file