Skip to content

Commit

Permalink
feat: add Fabric Regular Synthesizer to Streamlit app (#252)
Browse files Browse the repository at this point in the history
* feat: Fabric Regular Synthesizer in Streamlit app

* feat: add ydata-sdk as requirement for streamlit

* feat: allow to overwrite default datatype for Fabric Regular Synthesizer

* fix: restore streamlit dependency

* feat: rename the SDK synthesizer, improve documentation

* fix: type exception
  • Loading branch information
aquemy authored Mar 22, 2023
1 parent 149e9ef commit dcdab7f
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 16 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
"streamlit==1.18.1",
"typing-extensions==3.10.0",
"streamlit_pandas_profiling==0.1.3",
"ydata-profiling==4.0.0"
"ydata-profiling==4.0.0",
"ydata-sdk>=0.2.1"
],
},
)
9 changes: 9 additions & 0 deletions src/ydata_synthetic/streamlit_app/About.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def main():
- WGAN
- WGANGP
- CTGAN
- **ydata-sdk Synthesizer**
''')

st.success('''In particular, **ydata-sdk Synthesizer** uses [`ydata-sdk`](https://docs.sdk.ydata.ai/) to leverage the state-of-the-art synthesizer model developed by YData.''')
st.info('''
Using **ydata-sdk Synthesizer** requires a valid token. The token is attached to a Fabric account.
In case you do not have an account, you can create one at https://ydata.ai/ydata-fabric-free-trial.
To obtain the token, please, login to https://fabric.ydata.ai.
The token is available on the homepage once you are connected.
''')

#best practives for synthetic data generation
Expand Down
76 changes: 67 additions & 9 deletions src/ydata_synthetic/streamlit_app/pages/1_Train_a_synthesizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Union
import os
import json
import streamlit as st

from ydata.sdk.synthesizers import RegularSynthesizer
from ydata.sdk.common.client import get_client

from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
from ydata_synthetic.synthesizers.regular.model import Model

Expand All @@ -12,7 +17,7 @@ def get_available_models(type: Union[str, DataType]):

dtype = DataType(type)
if dtype == DataType.TABULAR:
models_list = [e.value.upper() for e in Model if e.value not in ['cgan', 'cwgangp']]
models_list = [e.value.upper() for e in Model if e.value not in ['cgan', 'cwgangp']] + ['ydata-sdk Synthesizer']
else:
st.warning('Time-Series models are not yet supported .')
models_list = ([''])
Expand All @@ -35,7 +40,7 @@ def run():
models_list = get_available_models(type=datatype)
model_name = st.selectbox('Select your model', models_list)

if model_name !='':
if model_name not in ['', 'ydata-sdk Synthesizer']:
st.text("Select your synthesizer model parameters")
col1, col2 = st.columns(2)
with col1:
Expand All @@ -50,14 +55,14 @@ def run():

# Create the Train parameters
gan_args = ModelParameters(batch_size=batch_size,
lr=lr,
betas=(beta_1, beta_2),
noise_dim=noise_dim,
layers_dim=layer_dim)
lr=lr,
betas=(beta_1, beta_2),
noise_dim=noise_dim,
layers_dim=layer_dim)

model = init_synth(datatype=datatype, modelname=model_name, model_parameters=gan_args)

if model!=None:
if model != None:
st.text("Set your synthesizer training parameters")
#Get the training parameters
epochs, label_col = training_parameters(model_name, df.columns)
Expand All @@ -72,11 +77,64 @@ def run():
else:
model.fit(data=df, num_cols=num_cols, cat_cols=cat_cols, train_arguments=train_args)

st.success('Synthesizer was trained succesfully!!')

st.success('Synthesizer was trained succesfully!')
st.info(f"The trained model will be saved at {model_path}.")

model.save(model_path)



if model_name == 'ydata-sdk Synthesizer':
valid_token = False
st.text("Model parameters")
col1, col2 = st.columns(2)
with col1:
token = st.text_input("SDK Token", type="password")
os.environ['YDATA_TOKEN'] = token

with col2:
st.write("##")
try:
get_client()
st.text('✅ Valid')
valid_token = True
except Exception:
st.text('❌ Invalid')

if not valid_token:
st.error("""**ydata-sdk Synthesizer requires a valid token.**
In case you do not have an account, please, create one at https://ydata.ai/ydata-fabric-free-trial.
To obtain the token, please, login to https://fabric.ydata.ai.
The token is available on the homepage once you are connected.
""")


with st.expander('**More settings**'):
model_path = st.text_input("Saved trained model to path:", value="trained_synth.pkl")

st.subheader("3. Train your synthesizer")
if st.button('Click here to start the training process', disabled=not valid_token):
model = RegularSynthesizer()
with st.spinner("Please wait while your synthesizer trains..."):
dtypes = {}
for c in num_cols:
dtypes[c] = 'numerical'
for c in cat_cols:
dtypes[c] = 'categorical'
model.fit(X=df, dtypes=dtypes)

st.success('Synthesizer was trained succesfully!')
st.info(f"The trained model will be saved at {model_path}.")

model_data = {
'uid': model.uid,
'token': os.environ['YDATA_TOKEN']
}
with open(model_path, 'w') as outfile:
json.dump(model_data, outfile)




if __name__ == '__main__':
run()
Original file line number Diff line number Diff line change
@@ -1,18 +1,57 @@
import streamlit as st
import json
import os

from ydata.sdk.synthesizers import RegularSynthesizer
from ydata.sdk.common.client import get_client

from ydata_synthetic.streamlit_app.pages.functions.train import DataType
from ydata_synthetic.streamlit_app.pages.functions.generate import load_model, generate_profile

def run():
st.subheader("Generate synthetic data from a trained model")

from_SDK = False
model_data = {}
valid_token = False
col1, col2 = st.columns([4, 2])
with col1:
input_path = st.text_input("Provide the path to a trained model", value="trained_synth.pkl")
# Try to load as a JSON as SDK
try:
f = open(input_path)
model_data = json.load(f)
from_SDK = True
except:
pass

if from_SDK:
token = st.text_input("SDK Token", type="password", value=model_data.get('token'))
os.environ['YDATA_TOKEN'] = token


with col2:
datatype = st.selectbox('Select your data type', (DataType.TABULAR.value,))
datatype=DataType(datatype)

if from_SDK and 'YDATA_TOKEN' in os.environ:
st.write("##")
try:
get_client()
st.text('✅ Valid')
valid_token = True
except Exception:
st.text('❌ Invalid')

if from_SDK and 'token' in model_data and not valid_token:
st.warning("The token used during training is not valid anymore. Please, use a new token.")

if from_SDK and not valid_token:
st.error("""**ydata-sdk Synthesizer requires a valid token.**
In case you do not have an account, please, create one at https://ydata.ai/ydata-fabric-free-trial.
To obtain the token, please, login to https://fabric.ydata.ai.
The token is available on the homepage once you are connected.
""")

col1, col2 = st.columns([4,2])
with col1:
n_samples = st.number_input("Number of samples to generate", min_value=0, value=1000)
Expand All @@ -21,14 +60,18 @@ def run():
sample_path = st.text_input("Synthetic samples file path", value='synthetic.csv')

if st.button('Generate samples'):
#load a trained model
model = load_model(input_path=input_path,
datatype=datatype)
if from_SDK:
model = RegularSynthesizer.get(uid=model_data.get('uid'))

else:
model = load_model(input_path=input_path, datatype=datatype)

st.success('The model was properly loaded and is now ready to generate synthetic samples!')

st.success('Trained model was loaded. You can now generate synthetic samples')

#sample synthetic data
synth_data = model.sample(n_samples)
with st.spinner('Generating samples... This might take time.'):
synth_data = model.sample(n_samples)
st.write(synth_data)

#save the synthetic data samples to a given path
Expand Down

0 comments on commit dcdab7f

Please sign in to comment.