Skip to content

Commit

Permalink
Update and standardize dependencies (sdv-dev#126)
Browse files: browse the repository at this point in the history
* Update and standardize dependencies

* Increase test epochs

* Simplify tvae test
  • Loading branch information
csala authored Feb 22, 2021
1 parent 02f85b0 commit 4361974
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 87 deletions.
38 changes: 19 additions & 19 deletions conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,35 @@ build:

requirements:
host:
- numpy <2,>=1.17.4
- pandas <1.1.5,>=0.24
- pip
- python
- rdt >=0.2.7,<0.4
- scikit-learn <0.24,>=0.21
- pytorch <2,>=1.0
- torchvision <1,>=0.4.2
- packaging
- pytest-runner
- packaging
- python >=3.6,<3.9
- numpy >=1.17.4,<2
- pandas >=0.24,<1.1.5
- scikit-learn >=0.20,<1
- pytorch >=1.4,<2
- torchvision >=0.5.0,<1
- rdt >=0.2.7,<0.5
run:
- numpy <2,>=1.17.4
- pandas <1.1.5,>=0.24
- python
- rdt >=0.2.7,<0.4
- scikit-learn <0.24,>=0.21
- pytorch <2,>=1.0
- torchvision <1,>=0.4.2
- packaging
- python >=3.6,<3.9
- numpy >=1.17.4,<2
- pandas >=0.24,<1.1.5
- scikit-learn >=0.20,<1
- pytorch >=1.4,<2
- torchvision >=0.5.0,<1
- rdt >=0.2.7,<0.5

about:
home: "https://github.com/sdv-dev/CTGAN"
license: MIT
license_family: MIT
license_file:
license_file:
summary: "Conditional GAN for Tabular Data"
doc_url:
dev_url:
doc_url:
dev_url:

extra:
recipe-maintainers:
- sdv-dev
- sdv-dev
12 changes: 6 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
history = history_file.read()

install_requires = [
'torch<2,>=1.0',
'torchvision<1,>=0.4.2',
'scikit-learn<0.24,>=0.21',
'numpy<2,>=1.17.4',
'pandas<1.1.5,>=0.24',
'rdt>=0.2.7,<0.4',
'packaging',
'numpy>=1.17.4,<2',
'pandas>=0.24,<1.1.5',
'scikit-learn>=0.20,<1',
'torch>=1.4,<2',
'torchvision>=0.5.0,<1',
'rdt>=0.2.7,<0.5',
]

setup_requires = [
Expand Down
75 changes: 13 additions & 62 deletions tests/integration/test_tvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,76 +9,27 @@
model are not checked.
"""

import tempfile as tf

import numpy as np
import pandas as pd
from sklearn import datasets

from ctgan.synthesizers.tvae import TVAESynthesizer


def test_tvae_dataframe():
data = pd.DataFrame({
'continuous': np.random.random(1000),
'discrete': np.random.choice(['a', 'b'], 1000)
})
discrete_columns = ['discrete']
def test_tvae(tmpdir):
iris = datasets.load_iris()
data = pd.DataFrame(iris.data, columns=iris.feature_names)
data['class'] = pd.Series(iris.target).map(iris.target_names.__getitem__)

tvae = TVAESynthesizer(epochs=10)
tvae.fit(data, discrete_columns)

sampled = tvae.sample(100)

assert sampled.shape == (100, 2)
assert isinstance(sampled, pd.DataFrame)
assert set(sampled.columns) == {'continuous', 'discrete'}
assert set(sampled['discrete'].unique()) == {'a', 'b'}
tvae.fit(data, ['class'])


def test_tvae_numpy():
data = pd.DataFrame({
'continuous': np.random.random(1000),
'discrete': np.random.choice(['a', 'b'], 1000)
})
discrete_columns = [1]

tvae = TVAESynthesizer(epochs=10)
tvae.fit(data.values, discrete_columns)
path = str(tmpdir / 'test_tvae.pkl')
tvae.save(path)
tvae = TVAESynthesizer.load(path)

sampled = tvae.sample(100)

assert sampled.shape == (100, 2)
assert isinstance(sampled, np.ndarray)
assert set(np.unique(sampled[:, 1])) == {'a', 'b'}


def test_synthesizer_sample():
data = pd.DataFrame({
'discrete': np.random.choice(['a', 'b'], 100)
})
discrete_columns = ['discrete']

tvae = TVAESynthesizer(epochs=1)
tvae.fit(data, discrete_columns)

samples = tvae.sample(1000)
assert isinstance(samples, pd.DataFrame)


def test_save_load():
data = pd.DataFrame({
'continuous': np.random.random(100),
'discrete': np.random.choice(['a', 'b'], 100)
})
discrete_columns = ['discrete']

tvae = TVAESynthesizer(epochs=10)
tvae.fit(data, discrete_columns)

with tf.TemporaryDirectory() as temporary_directory:
tvae.save(temporary_directory + "test_tvae.pkl")
tvae = TVAESynthesizer.load(temporary_directory + "test_tvae.pkl")

sampled = tvae.sample(1000)
assert set(sampled.columns) == {'continuous', 'discrete'}
assert set(sampled['discrete'].unique()) == {'a', 'b'}
assert sampled.shape == (100, 5)
assert isinstance(sampled, pd.DataFrame)
assert set(sampled.columns) == set(data.columns)
assert set(sampled.dtypes) == set(data.dtypes)

0 comments on commit 4361974

Please sign in to comment.