use single quotes
gsheni committed Mar 25, 2024
1 parent 88127ad commit 2473d50
Showing 13 changed files with 358 additions and 356 deletions.
14 changes: 7 additions & 7 deletions deepecho/__init__.py
@@ -1,16 +1,16 @@
 """Top-level package for DeepEcho."""

-__author__ = "DataCebo, Inc."
-__email__ = "info@sdv.dev"
-__version__ = "0.5.1.dev0"
-__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+__author__ = 'DataCebo, Inc.'
+__email__ = 'info@sdv.dev'
+__version__ = '0.5.1.dev0'
+__path__ = __import__('pkgutil').extend_path(__path__, __name__)

 from deepecho.demo import load_demo
 from deepecho.models.basic_gan import BasicGANModel
 from deepecho.models.par import PARModel

 __all__ = [
-    "load_demo",
-    "BasicGANModel",
-    "PARModel",
+    'load_demo',
+    'BasicGANModel',
+    'PARModel',
 ]
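The quote-style change in this file is purely cosmetic, so the package metadata and the public names exported in __all__ are unchanged. A quick sanity check (a hypothetical snippet, not part of the commit, assuming the package is installed):

import deepecho

print(deepecho.__version__)          # '0.5.1.dev0'
print(sorted(deepecho.__all__))      # ['BasicGANModel', 'PARModel', 'load_demo']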
4 changes: 2 additions & 2 deletions deepecho/demo.py
@@ -4,11 +4,11 @@

 import pandas as pd

-_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
+_DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')


 def load_demo():
     """Load the demo DataFrame."""
     return pd.read_csv(
-        os.path.join(_DATA_PATH, "demo.csv"), parse_dates=["date"]
+        os.path.join(_DATA_PATH, 'demo.csv'), parse_dates=['date']
     )
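For context, load_demo is the package's public demo loader: it reads the bundled demo.csv and parses its 'date' column as datetimes. A minimal usage sketch (hypothetical, not part of the commit, assuming the package is installed):

from deepecho import load_demo

demo = load_demo()
print(demo.dtypes)   # the 'date' column should appear as datetime64[ns]
print(demo.head())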
2 changes: 1 addition & 1 deletion deepecho/models/__init__.py
@@ -3,4 +3,4 @@
 from deepecho.models.basic_gan import BasicGANModel
 from deepecho.models.par import PARModel

-__all__ = ["PARModel", "BasicGANModel"]
+__all__ = ['PARModel', 'BasicGANModel']
44 changes: 22 additions & 22 deletions deepecho/models/base.py
@@ -29,19 +29,19 @@ def _validate(sequences, context_types, data_types):
         See `fit`.
         """
         dtypes = set([
-            "continuous",
-            "categorical",
-            "ordinal",
-            "count",
-            "datetime",
+            'continuous',
+            'categorical',
+            'ordinal',
+            'count',
+            'datetime',
         ])
         assert all(dtype in dtypes for dtype in context_types)
         assert all(dtype in dtypes for dtype in data_types)

         for sequence in sequences:
-            assert len(sequence["context"]) == len(context_types)
-            assert len(sequence["data"]) == len(data_types)
-            lengths = [len(x) for x in sequence["data"]]
+            assert len(sequence['context']) == len(context_types)
+            assert len(sequence['data']) == len(data_types)
+            lengths = [len(x) for x in sequence['data']]
             assert len(set(lengths)) == 1

     def fit_sequences(self, sequences, context_types, data_types):
@@ -93,15 +93,15 @@ def _get_data_types(data, data_types, columns):
             else:
                 dtype = data[column].dtype
                 kind = dtype.kind
-                if kind in "fiud":
-                    dtypes_list.append("continuous")
-                elif kind in "OSUb":
-                    dtypes_list.append("categorical")
-                elif kind == "M":
-                    dtypes_list.append("datetime")
+                if kind in 'fiud':
+                    dtypes_list.append('continuous')
+                elif kind in 'OSUb':
+                    dtypes_list.append('categorical')
+                elif kind == 'M':
+                    dtypes_list.append('datetime')
                 else:
                     error = (
-                        f"Unsupported data_type for column {column}: {dtype}"
+                        f'Unsupported data_type for column {column}: {dtype}'
                     )
                     raise ValueError(error)

@@ -147,18 +147,18 @@ def fit(
         """
         if not entity_columns and segment_size is None:
             raise TypeError(
-                "If the data has no `entity_columns`, `segment_size` must be given."
+                'If the data has no `entity_columns`, `segment_size` must be given.'
             )
         if segment_size is not None and not isinstance(segment_size, int):
             if sequence_index is None:
                 raise TypeError(
-                    "`segment_size` must be of type `int` if "
-                    "no `sequence_index` is given."
+                    '`segment_size` must be of type `int` if '
+                    'no `sequence_index` is given.'
                 )
-            if data[sequence_index].dtype.kind != "M":
+            if data[sequence_index].dtype.kind != 'M':
                 raise TypeError(
-                    "`segment_size` must be of type `int` if "
-                    "`sequence_index` is not a `datetime` column."
+                    '`segment_size` must be of type `int` if '
+                    '`sequence_index` is not a `datetime` column.'
                 )

             segment_size = pd.to_timedelta(segment_size)
@@ -237,7 +237,7 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
         if context is None:
             if num_entities is None:
                 raise TypeError(
-                    "Either context or num_entities must be not None"
+                    'Either context or num_entities must be not None'
                 )

             context = self._context_values.sample(num_entities, replace=True)
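The _get_data_types hunk above maps pandas dtype kinds onto DeepEcho's type names ('fiud' to continuous, 'OSUb' to categorical, 'M' to datetime). A small standalone sketch of that mapping, with made-up column names and values for illustration:

import pandas as pd

df = pd.DataFrame({
    'amount': [1.0, 2.5],                                   # kind 'f' -> continuous
    'label': ['a', 'b'],                                    # kind 'O' -> categorical
    'when': pd.to_datetime(['2024-01-01', '2024-01-02']),   # kind 'M' -> datetime
})

def infer(dtype):
    kind = dtype.kind
    if kind in 'fiud':
        return 'continuous'
    elif kind in 'OSUb':
        return 'categorical'
    elif kind == 'M':
        return 'datetime'
    raise ValueError(f'Unsupported data_type: {dtype}')

print({column: infer(df[column].dtype) for column in df.columns})
# {'amount': 'continuous', 'label': 'categorical', 'when': 'datetime'}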
80 changes: 40 additions & 40 deletions deepecho/models/basic_gan.py
@@ -174,22 +174,22 @@ def __init__(
         self._hidden_size = hidden_size

         if not cuda or not torch.cuda.is_available():
-            device = "cpu"
+            device = 'cpu'
         elif isinstance(cuda, str):
             device = cuda
         else:
-            device = "cuda"
+            device = 'cuda'

         self._device = torch.device(device)
         self._verbose = verbose

-        LOGGER.info("%s instance created", self)
+        LOGGER.info('%s instance created', self)

     def __repr__(self):
         """Return a representation of the class object."""
         return (
-            f"{self.__class__.__name__}(epochs={self._epochs}, latent_size={self._latent_size},"
-            f"hidden_size={self._hidden_size}, gen_lr={self._gen_lr}, dis_lr={self._dis_lr},"
+            f'{self.__class__.__name__}(epochs={self._epochs}, latent_size={self._latent_size},'
+            f'hidden_size={self._hidden_size}, gen_lr={self._gen_lr}, dis_lr={self._dis_lr},'
             f"cuda='{self._device}', verbose={self._verbose})"
         )

@@ -221,25 +221,25 @@ def _index_map(columns, types):
         mapping = {}
         for column, column_type in enumerate(types):
             values = columns[column]
-            if column_type in ("continuous", "count"):
+            if column_type in ('continuous', 'count'):
                 mapping[column] = {
-                    "type": column_type,
-                    "min": np.min(values),
-                    "max": np.max(values),
-                    "indices": (dimensions, dimensions + 1),
+                    'type': column_type,
+                    'min': np.min(values),
+                    'max': np.max(values),
+                    'indices': (dimensions, dimensions + 1),
                 }
                 dimensions += 2

-            elif column_type in ("categorical", "ordinal"):
+            elif column_type in ('categorical', 'ordinal'):
                 indices = {}
                 for value in set(values):
                     indices[value] = dimensions
                     dimensions += 1

-                mapping[column] = {"type": column_type, "indices": indices}
+                mapping[column] = {'type': column_type, 'indices': indices}

             else:
-                raise ValueError(f"Unsupported type: {column_type}")
+                raise ValueError(f'Unsupported type: {column_type}')

         return mapping, dimensions

@@ -252,7 +252,7 @@ def _analyze_data(self, sequences, context_types, data_types):
         - Index map and dimensions for the data.
         """
         sequence_lengths = np.array([
-            len(sequence["data"][0]) for sequence in sequences
+            len(sequence['data'][0]) for sequence in sequences
         ])
         self._max_sequence_length = np.max(sequence_lengths)
         self._fixed_length = (
@@ -263,7 +263,7 @@ def _analyze_data(self, sequences, context_types, data_types):
         context = []
         for column in range(len(context_types)):
             context.append([
-                sequence["context"][column] for sequence in sequences
+                sequence['context'][column] for sequence in sequences
             ])

         self._context_map, self._context_size = self._index_map(
@@ -274,7 +274,7 @@ def _analyze_data(self, sequences, context_types, data_types):
         data = []
         for column in range(len(data_types)):
             data.append(
-                sum([sequence["data"][column] for sequence in sequences], [])
+                sum([sequence['data'][column] for sequence in sequences], [])
             )

         self._data_map, self._data_size = self._index_map(data, data_types)
@@ -284,27 +284,27 @@
     @staticmethod
     def _normalize(tensor, value, properties):
         """Normalize the value between 0 and 1 and flag nans."""
-        value_idx, missing_idx = properties["indices"]
+        value_idx, missing_idx = properties['indices']
         if pd.isnull(value):
             tensor[value_idx] = 0.0
             tensor[missing_idx] = 1.0
         else:
-            column_min = properties["min"]
-            column_range = properties["max"] - column_min
+            column_min = properties['min']
+            column_range = properties['max'] - column_min
             offset = value - column_min
             tensor[value_idx] = 2.0 * offset / column_range - 1.0
             tensor[missing_idx] = 0.0

     @staticmethod
     def _denormalize(tensor, row, properties, round_value):
         """Denormalize previously normalized values, setting NaN values if necessary."""
-        value_idx, missing_idx = properties["indices"]
+        value_idx, missing_idx = properties['indices']
         if tensor[row, 0, missing_idx] > 0.5:
             return None

         normalized = tensor[row, 0, value_idx].item()
-        column_min = properties["min"]
-        column_range = properties["max"] - column_min
+        column_min = properties['min']
+        column_range = properties['max'] - column_min

         denormalized = (normalized + 1) * column_range / 2.0 + column_min
         if round_value:
@@ -315,14 +315,14 @@ def _denormalize(tensor, row, properties, round_value):
     @staticmethod
     def _one_hot_encode(tensor, value, properties):
         """Update the index that corresponds to the value to 1.0."""
-        value_index = properties["indices"][value]
+        value_index = properties['indices'][value]
         tensor[value_index] = 1.0

     @staticmethod
     def _one_hot_decode(tensor, row, properties):
         """Obtain the category that corresponds to the highest one-hot value."""
-        max_value = float("-inf")
-        for category, idx in properties["indices"].items():
+        max_value = float('-inf')
+        for category, idx in properties['indices'].items():
             value = tensor[row, 0, idx]
             if value > max_value:
                 max_value = value
@@ -332,10 +332,10 @@ def _one_hot_decode(tensor, row, properties):

     def _value_to_tensor(self, tensor, value, properties):
         """Update the tensor according to the value and properties."""
-        column_type = properties["type"]
-        if column_type in ("continuous", "count"):
+        column_type = properties['type']
+        if column_type in ('continuous', 'count'):
             self._normalize(tensor, value, properties)
-        elif column_type in ("categorical", "ordinal"):
+        elif column_type in ('categorical', 'ordinal'):
             self._one_hot_encode(tensor, value, properties)

         else:
@@ -381,17 +381,17 @@ def _tensor_to_data(self, tensor):

         data = [None] * len(self._data_map)
         for column, properties in self._data_map.items():
-            column_type = properties["type"]
+            column_type = properties['type']

             column_data = []
             data[column] = column_data
             for row in range(sequence_length):
-                if column_type in ("continuous", "count"):
-                    round_value = column_type == "count"
+                if column_type in ('continuous', 'count'):
+                    round_value = column_type == 'count'
                     value = self._denormalize(
                         tensor, row, properties, round_value=round_value
                     )
-                elif column_type in ("categorical", "ordinal"):
+                elif column_type in ('categorical', 'ordinal'):
                     value = self._one_hot_decode(tensor, row, properties)
                 else:
                     raise ValueError()  # Theoretically unreachable
@@ -414,15 +414,15 @@ def _build_tensor(self, transform, sequences, key, dim):

     def _transform(self, data):
         for properties in self._data_map.values():
-            column_type = properties["type"]
-            if column_type in ("continuous", "count"):
-                value_idx, missing_idx = properties["indices"]
+            column_type = properties['type']
+            if column_type in ('continuous', 'count'):
+                value_idx, missing_idx = properties['indices']
                 data[:, :, value_idx] = torch.tanh(data[:, :, value_idx])
                 data[:, :, missing_idx] = torch.sigmoid(
                     data[:, :, missing_idx]
                 )
-            elif column_type in ("categorical", "ordinal"):
-                indices = list(properties["indices"].values())
+            elif column_type in ('categorical', 'ordinal'):
+                indices = list(properties['indices'].values())
                 data[:, :, indices] = torch.nn.functional.softmax(
                     data[:, :, indices]
                 )
@@ -548,10 +548,10 @@ def fit_sequences(self, sequences, context_types, data_types):
         self._analyze_data(sequences, context_types, data_types)

         data = self._build_tensor(
-            self._data_to_tensor, sequences, "data", dim=1
+            self._data_to_tensor, sequences, 'data', dim=1
         )
         context = self._build_tensor(
-            self._context_to_tensor, sequences, "context", dim=0
+            self._context_to_tensor, sequences, 'context', dim=0
         )
         data_context = _expand_context(data, context)

@@ -580,7 +580,7 @@ def fit_sequences(self, sequences, context_types, data_types):
                 d_loss = discriminator_score.item()
                 g_loss = generator_score.item()
                 iterator.set_description(
-                    f"Epoch {epoch + 1} | D Loss {d_loss} | G Loss {g_loss}"
+                    f'Epoch {epoch + 1} | D Loss {d_loss} | G Loss {g_loss}'
                 )

     def sample_sequence(self, context, sequence_length=None):
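The _normalize / _denormalize pair changed above encodes each continuous or count value as two tensor slots: the value scaled from its column's (min, max) range into [-1, 1], plus a flag marking missing values. A minimal standalone sketch of that scheme (the property dict and numbers are made up for illustration):

import pandas as pd

properties = {'min': 10.0, 'max': 30.0, 'indices': (0, 1)}

def normalize(value, properties):
    value_idx, missing_idx = properties['indices']
    tensor = [0.0, 0.0]
    if pd.isnull(value):
        tensor[value_idx], tensor[missing_idx] = 0.0, 1.0
    else:
        column_range = properties['max'] - properties['min']
        tensor[value_idx] = 2.0 * (value - properties['min']) / column_range - 1.0
    return tensor

print(normalize(20.0, properties))   # [0.0, 0.0] -> midpoint of the range maps to 0
print(normalize(30.0, properties))   # [1.0, 0.0] -> the maximum maps to 1
print(normalize(None, properties))   # [0.0, 1.0] -> only the missing flag is set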
[Diffs for the remaining 8 changed files are not shown.]
