-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathCTGAN.py
84 lines (65 loc) · 2.56 KB
/
CTGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from sklearn.preprocessing import MinMaxScaler
import ctgan
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import glob
import os
import sys
from time import time
# global plot_path
# Handle to the stdout stream that was active when this module was imported.
original = sys.stdout
def main(root, i, dopca=False, minmax=False):
    """Train a CTGAN synthesizer on one CSV file and save 400 generated rows.

    The CSV at ``root`` is loaded and its ``"Unnamed: 0"`` column is dropped.
    The remaining columns are optionally min-max scaled and/or reduced with
    PCA, then a ``ctgan.CTGANSynthesizer`` is fitted for 1000 epochs.
    Everything printed while fitting and sampling is captured in
    ``../OutputFiles/CTGAN/LossDetails/loss.txt``.

    Args:
        root: Path of the input CSV file.
        i: Identifier embedded in the output file name
            (``gen_sample_model{i}four.csv``).
        dopca: If true, fit on PCA-transformed data and map the generated
            samples back with ``pca.inverse_transform``.
        minmax: If true, scale the columns with ``MinMaxScaler`` before any
            PCA step.

    Returns:
        None. The generated samples are written to
        ``../OutputFiles/CTGAN/gendata/`` as a CSV file.
    """
    # Function-scope import so this block does not depend on a file-level one.
    from contextlib import redirect_stdout

    print("[INFO] Make nescessary directories")
    os.makedirs("../OutputFiles/CTGAN/plots/", exist_ok=True)
    os.makedirs("../OutputFiles/CTGAN/gendata/", exist_ok=True)
    os.makedirs("../OutputFiles/CTGAN/LossDetails/", exist_ok=True)

    model_path = root
    dest_path_gen = os.path.join("../OutputFiles/CTGAN/", "gendata/")
    plot_path = os.path.join("../OutputFiles/CTGAN/", "plots/")
    loss_path = os.path.join("../OutputFiles/CTGAN/", "LossDetails/")
    name = f"{dest_path_gen}/gen_sample_model{i}"
    print(model_path, dest_path_gen, plot_path, i, name, sep="\n", end="\n*****\n")

    print("[INFO] Preprocessing the data")
    data = pd.read_csv(model_path, engine='python')
    data = data.drop("Unnamed: 0", axis=1)
    # Remember the original column labels: PCA below replaces them, and they
    # are re-applied to the generated samples before saving.
    column_names = data.columns
    if minmax:
        scaler = MinMaxScaler()
        data = pd.DataFrame(scaler.fit_transform(data),
                            columns=list(column_names))
    if dopca:
        # A float n_components asks scikit-learn to keep enough components to
        # reach that fraction of explained variance.
        pca = PCA(0.9)
        data = pca.fit_transform(data)
        data = pd.DataFrame(data)

    print("[INFO] Training the model")
    model = ctgan.CTGANSynthesizer(
        batch_size=10, generator_dim=(128, 128), discriminator_dim=(128, 128))

    # Capture everything printed during training/sampling in the loss log.
    # The context managers guarantee that the file is closed and stdout is
    # restored even if fit() or sample() raises.
    # NOTE(review): the log name does not depend on `i` and the file is opened
    # in "w+" mode, so every call replaces the previous run's log.
    with open(f"{loss_path}/loss.txt", "w+") as loss_log, \
            redirect_stdout(loss_log):
        start = time()
        model.fit(data, epochs=1000)
        print("[INFO] Generate Samples")
        gen_sample4 = model.sample(400)
        if dopca:
            gen_sample4 = pca.inverse_transform(gen_sample4)
    print("Time required = ", time()-start)

    # NOTE(review): when minmax=True no scaler.inverse_transform is applied,
    # so the saved samples stay in the scaled space - confirm this is intended.
    gen_sample = pd.DataFrame(gen_sample4, columns=list(column_names))
    gen_sample.to_csv(name+"four.csv")
if __name__ == "__main__":
    # Script entry point: train one model per matching input file.
    # Each entry is (file-name suffix, dopca flag passed to main()).
    # "Data" files are trained with PCA, "Axes" files without; min-max
    # scaling is disabled for both.
    prefix = "Group"
    runs = (("Data.csv", True), ("Axes.csv", False))

    for suffix, dopca in runs:
        # input data
        allfiles = sorted(glob.glob(f"../InputFiles/{prefix}*{suffix}"))
        print(allfiles)
        for root in allfiles:
            print(f"---- {root} ----")
            # The group id is whatever the glob wildcard matched. Taking it
            # from the base name works with any path separator, and slicing
            # between prefix and suffix keeps multi-character ids intact
            # (e.g. "10" instead of just "1").
            filename = os.path.basename(root)
            group_id = filename[len(prefix):-len(suffix)]
            # NOTE(review): a "Data" file and an "Axes" file of the same group
            # yield the same id, and main() names its CSV only after that id,
            # so the later run overwrites the earlier output - confirm intent.
            main(root, group_id, dopca, False)
            print("-"*(len(root) + 10))