How to load this model directly to generate data after saving it #361

Closed
RedBlue01 opened this issue Apr 10, 2024 · 4 comments
Labels
question General question about the software resolution:resolved The issue was fixed, the question was answered, etc.

Comments

@RedBlue01

Environment details

  • CTGAN version: 1.11.0
  • Python version: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0] on linux
  • Operating System: Linux version 6.5.0-14-generic (buildd@lcy02-amd64-110) (x86_64-linux-gnu-gcc-12 (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0, GNU ld (GNU Binutils for Ubuntu) 2.38) #14~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Mon Nov 20 18:15:30 UTC 2

Problem description

This is my code. I want to fit the synthesizer for a fixed number of epochs, save the model, and then load that saved model directly to generate data. The attempt failed with an error.
`from sdv.single_table import CTGANSynthesizer
synthesizer = CTGANSynthesizer.load(
    filepath='/home/visitor/Huang/Analytical-Method/GAN/my_synthesizer_mini_e200NEW.pkl'
)
synthetic_data = synthesizer.sample(num_rows=10)
synthetic_data.to_csv('/home/visitor/Huang/Analytical-Method/GAN/synthetic_data.csv', index=False)

print(synthetic_data)
print('Done')`

What I already tried

I looked at the anaconda3/envs/AM/lib/python3.10/site-packages/sdv/data_processing/data_processor.py file, but my experience is limited and I could not work out how to fix the problem.
This is the error I currently get:

Traceback (most recent call last):
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 761, in _sample_with_progress_bar
    sampled = self._sample_in_batches(
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 692, in _sample_in_batches
    sampled_rows = self._sample_batch(
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 624, in _sample_batch
    sampled, num_valid = self._sample_rows(
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 563, in _sample_rows
    sampled = self._data_processor.reverse_transform(sampled)
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/data_processing/data_processor.py", line 827, in reverse_transform
    raise NotFittedError()
sdv.data_processing.errors.NotFittedError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/visitor/Huang/Analytical-Method/GAN/myCTGAN_Fit_mini load.py", line 5, in <module>
    synthetic_data = synthesizer.sample(num_rows=10)
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 800, in sample
    return self._sample_with_progress_bar(
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 770, in _sample_with_progress_bar
    handle_sampling_error(output_file_path == TMP_FILE_NAME, output_file_path, error)
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/utils.py", line 112, in handle_sampling_error
    raise type(sampling_error)(error_msg + '\n' + str(sampling_error))
sdv.data_processing.errors.NotFittedError: Error: Sampling terminated. Partial results are stored in a temporary file: .sample.csv.temp. This file will be overridden the next time you sample. Please rename the file if you wish to save these results.
RedBlue01 added the "new" and "question" labels on Apr 10, 2024
@npatki
Contributor

npatki commented Apr 16, 2024

Hi @RedBlue01, nice to meet you.

The error message seems to indicate that the synthesizer you are loading in was never fitted -- therefore, it is not possible to sample from it. Did you create the original synthesizer (saved as my_synthesizer_mini_e200NEW.pkl)? If so, could you share the code that went into creating that pkl file?

BTW instead of using the CTGAN library directly, I would highly recommend you move to the SDV library. You can access the CTGAN Synthesizer via SDV. Doing so will allow you to make use of additional features -- such as better data pre-processing, customizations such as constraints, and conditional sampling. Here is a tutorial that uses CTGAN via the SDV library.

npatki added the "under discussion" label and removed the "new" label on Apr 16, 2024
@RedBlue01
Author

RedBlue01 commented Apr 24, 2024

Hi @npatki, thank you very much for responding to this question, and I'm sorry for not replying until now.
Here is the code I used to create the original synthesizer:

import pandas as pd

data = pd.read_csv('/home/visitor/Huang/Analytical-Method/column_123after.csv', usecols=[0, 2])

from sdv.metadata import SingleTableMetadata
metadata=SingleTableMetadata()
metadata.detect_from_dataframe(data)
python_dict = metadata.to_dict()
print(data)
print(python_dict)

from sdv.single_table import CTGANSynthesizer
synthesizer = CTGANSynthesizer(
    metadata,  # required
    enforce_rounding=True,
    epochs=200,
    verbose=True
)
synthesizer.save(
    filepath='/home/visitor/Huang/Analytical-Method/GAN/my_synthesizer_e200NEW.pkl'
)

synthesizer.fit(data)
synthesizer.get_loss_values()

synthetic_data = synthesizer.sample(num_rows=10)

print(synthetic_data)
print('Done')

Thank you so much for what you have done. I have already installed SDV with "pip install sdv", and it is amazing work.

@npatki
Contributor

npatki commented Apr 24, 2024

Hi @RedBlue01, thanks for sharing your code.

The problem is that you are saving your synthesizer before fitting it. I would recommend saving the synthesizer after you call the fit function. The fitting process is where the machine learning happens, and you want the result of that process to be included in the saved file, so saving should happen after fitting.

synthesizer.fit(data)

synthesizer.save(
    filepath='/home/visitor/Huang/Analytical-Method/GAN/my_synthesizer_e200NEW.pkl'
)

Keep in mind that when you call save, you will save the state of the synthesizer at that point of time only, as a pkl file.
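
To illustrate that point in isolation, here is a minimal sketch using only the Python standard library. `ToySynthesizer`, `NotFittedError`, and `try_sample` are hypothetical stand-ins written for this example; they are not SDV classes and do not reflect how SDV is implemented internally. The sketch only shows that a pickle file holds whatever state the object had when save was called, so a file written before fit cannot be sampled from after loading.

```python
import os
import pickle
import random
import tempfile


class NotFittedError(Exception):
    """Raised when sampling from a model that has not been fitted."""


class ToySynthesizer:
    """Hypothetical stand-in for a synthesizer; NOT part of SDV."""

    def __init__(self):
        self.fitted = False
        self.values = []

    def fit(self, data):
        # "Learning" here is just remembering the observed values.
        self.values = list(data)
        self.fitted = True

    def sample(self, num_rows, seed=0):
        if not self.fitted:
            raise NotFittedError()
        rng = random.Random(seed)
        return [rng.choice(self.values) for _ in range(num_rows)]

    def save(self, filepath):
        # Pickle the attributes exactly as they are at the moment of the call.
        with open(filepath, "wb") as f:
            pickle.dump(self.__dict__, f)

    @classmethod
    def load(cls, filepath):
        # Rebuild an object from the attributes stored in the file.
        obj = cls()
        with open(filepath, "rb") as f:
            obj.__dict__.update(pickle.load(f))
        return obj


def try_sample(model, num_rows=3):
    """Return sampled rows, or None if the model was never fitted."""
    try:
        return model.sample(num_rows)
    except NotFittedError:
        return None


with tempfile.TemporaryDirectory() as tmp:
    early_path = os.path.join(tmp, "saved_before_fit.pkl")
    late_path = os.path.join(tmp, "saved_after_fit.pkl")

    model = ToySynthesizer()
    model.save(early_path)  # snapshot taken before fitting
    model.fit([1, 2, 3])
    model.save(late_path)   # snapshot taken after fitting

    before = try_sample(ToySynthesizer.load(early_path))
    after = try_sample(ToySynthesizer.load(late_path))

print("saved before fit:", before)
print("saved after fit: ", after)
```

The file written before fit loads without complaint but cannot be sampled from, which matches the NotFittedError in the traceback above. Fitting the in-memory object afterwards does not update a file that was already written.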

@RedBlue01
Author

Hi @npatki, thank you so much for your help. I have finally solved this problem, which had troubled me for a long time. I never thought the cause would be the order of save and fit.
The code works great. What amazing work; thank you and your team again for your work and dedication.

npatki added the "resolution:resolved" label and removed the "under discussion" label on Apr 25, 2024