diff --git a/generator/generative_model.py b/generator/generative_model.py index 516c9c2..a754138 100644 --- a/generator/generative_model.py +++ b/generator/generative_model.py @@ -7,9 +7,9 @@ from generator.options import Options -class GenerativeModelWrapper: +class GenerativeModel: """ - A wrapper class for generative models with a scikit-learn-like API. + A wrapper class for generative models. """ def __init__(self, model_name: str, model_params: dict = None): @@ -41,7 +41,6 @@ def _initialize_model(self): "acgan": ACGAN, "diffusion_ts": Diffusion_TS, "diffcharge": DDPM, - # Add other models as needed } if self.model_name in model_dict: model_class = model_dict[self.model_name] @@ -49,7 +48,7 @@ def _initialize_model(self): else: raise ValueError(f"Model {self.model_name} not recognized.") - def fit(self, X, y=None): + def fit(self, X): """ Train the model on the given dataset.