diff --git a/eeggan/Generate_Samples.py b/eeggan/Generate_Samples.py index 7d4a34c..1f2e655 100644 --- a/eeggan/Generate_Samples.py +++ b/eeggan/Generate_Samples.py @@ -5,9 +5,9 @@ import pandas as pd import torch -from helpers import system_inputs -from helpers.trainer import Trainer -from nn_architecture.models import TtsGenerator, TtsGeneratorFiltered +from eeggan.helpers import system_inputs +from eeggan.helpers.trainer import Trainer +from eeggan.nn_architecture.models import TtsGenerator, TtsGeneratorFiltered def generate_samples(argv = []): diff --git a/eeggan/Train_Gan.py b/eeggan/Train_Gan.py index 626f33b..de27887 100644 --- a/eeggan/Train_Gan.py +++ b/eeggan/Train_Gan.py @@ -5,12 +5,12 @@ import torch import torch.multiprocessing as mp -from helpers.trainer import Trainer -from helpers.get_master import find_free_port -from helpers.ddp_training import run, DDPTrainer -from nn_architecture.models import TtsDiscriminator, TtsGenerator, TtsGeneratorFiltered -from helpers.dataloader import Dataloader -from helpers import system_inputs +from eeggan.helpers.trainer import Trainer +from eeggan.helpers.get_master import find_free_port +from eeggan.helpers.ddp_training import run, DDPTrainer +from eeggan.nn_architecture.models import TtsDiscriminator, TtsGenerator, TtsGeneratorFiltered +from eeggan.helpers.dataloader import Dataloader +from eeggan.helpers import system_inputs """Implementation of the training process of a GAN for the generation of synthetic sequential data. diff --git a/eeggan/Visualize_Gan.py b/eeggan/Visualize_Gan.py index 0a7bee0..07f6643 100644 --- a/eeggan/Visualize_Gan.py +++ b/eeggan/Visualize_Gan.py @@ -7,11 +7,11 @@ import numpy as np import torch -from helpers import system_inputs -from nn_architecture import models -from helpers.dataloader import Dataloader -from helpers.visualize_pca import visualization_dim_reduction -from helpers.visualize_spectogram import plot_spectogram, plot_fft_hist +from eeggan.helpers import system_inputs +from eeggan.nn_architecture import models +from eeggan.helpers.dataloader import Dataloader +from eeggan.helpers.visualize_pca import visualization_dim_reduction +from eeggan.helpers.visualize_spectogram import plot_spectogram, plot_fft_hist class PlotterGanTraining: """This class is used to read samples from a csv-file and plot them. diff --git a/eeggan/helpers/ddp_training.py b/eeggan/helpers/ddp_training.py index 2609a9f..4a83f14 100644 --- a/eeggan/helpers/ddp_training.py +++ b/eeggan/helpers/ddp_training.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -import helpers.trainer as trainer -from helpers.dataloader import Dataloader +import eeggan.helpers.trainer as trainer +from eeggan.helpers.dataloader import Dataloader class DDPTrainer(trainer.Trainer): diff --git a/eeggan/helpers/trainer.py b/eeggan/helpers/trainer.py index 848c851..133ef0d 100644 --- a/eeggan/helpers/trainer.py +++ b/eeggan/helpers/trainer.py @@ -3,8 +3,8 @@ import torch import numpy as np -from nn_architecture import losses, models -from nn_architecture.losses import WassersteinGradientPenaltyLoss as Loss +from eeggan.nn_architecture import losses, models +from eeggan.nn_architecture.losses import WassersteinGradientPenaltyLoss as Loss # https://machinelearningmastery.com/how-to-implement-wasserstein-loss-for-generative-adversarial-networks/ # For implementation of Wasserstein-GAN see link above diff --git a/eeggan/nn_architecture/models.py b/eeggan/nn_architecture/models.py index 0504641..800637d 100644 --- a/eeggan/nn_architecture/models.py +++ b/eeggan/nn_architecture/models.py @@ -5,7 +5,7 @@ from scipy import signal import numpy as np -from nn_architecture.ttsgan_components import * +from eeggan.nn_architecture.ttsgan_components import * # insert here all different kinds of generators and discriminators