increased line length
gsheni committed Mar 25, 2024 · 1 parent e4957e7 · commit 42f595f
Showing 8 changed files with 68 additions and 203 deletions.
4 changes: 1 addition & 3 deletions deepecho/demo.py
@@ -9,6 +9,4 @@
 
 def load_demo():
     """Load the demo DataFrame."""
-    return pd.read_csv(
-        os.path.join(_DATA_PATH, 'demo.csv'), parse_dates=['date']
-    )
+    return pd.read_csv(os.path.join(_DATA_PATH, 'demo.csv'), parse_dates=['date'])
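For context, a minimal usage sketch (not part of the commit) of what the reformatted one-liner loads; it assumes the package is installed locally:

    from deepecho.demo import load_demo

    df = load_demo()           # demo DataFrame bundled with the package
    print(df['date'].dtype)    # datetime64[ns], thanks to parse_dates=['date']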
19 changes: 5 additions & 14 deletions deepecho/models/base.py
@@ -100,9 +100,7 @@ def _get_data_types(data, data_types, columns):
                 elif kind == 'M':
                     dtypes_list.append('datetime')
                 else:
-                    error = (
-                        f'Unsupported data_type for column {column}: {dtype}'
-                    )
+                    error = f'Unsupported data_type for column {column}: {dtype}'
                     raise ValueError(error)
 
         return dtypes_list
@@ -146,14 +144,11 @@ def fit(
                 such as integer values or datetimes.
         """
         if not entity_columns and segment_size is None:
-            raise TypeError(
-                'If the data has no `entity_columns`, `segment_size` must be given.'
-            )
+            raise TypeError('If the data has no `entity_columns`, `segment_size` must be given.')
         if segment_size is not None and not isinstance(segment_size, int):
             if sequence_index is None:
                 raise TypeError(
-                    '`segment_size` must be of type `int` if '
-                    'no `sequence_index` is given.'
+                    '`segment_size` must be of type `int` if ' 'no `sequence_index` is given.'
                 )
             if data[sequence_index].dtype.kind != 'M':
                 raise TypeError(
@@ -176,9 +171,7 @@
         self._data_columns.remove(sequence_index)
 
         data_types = self._get_data_types(data, data_types, self._data_columns)
-        context_types = self._get_data_types(
-            data, data_types, self._context_columns
-        )
+        context_types = self._get_data_types(data, data_types, self._context_columns)
         sequences = assemble_sequences(
             data,
             self._entity_columns,
@@ -236,9 +229,7 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
         """
         if context is None:
             if num_entities is None:
-                raise TypeError(
-                    'Either context or num_entities must be not None'
-                )
+                raise TypeError('Either context or num_entities must be not None')
 
             context = self._context_values.sample(num_entities, replace=True)
             context = context.reset_index(drop=True)
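A hedged sketch (not part of the commit) of how the `fit` validation reformatted above behaves; `PARModel` stands in for any DeepEcho model subclass, and the data is a placeholder:

    import pandas as pd

    from deepecho import PARModel

    data = pd.DataFrame({'value': [1.0, 2.0, 3.0, 4.0]})
    model = PARModel(epochs=1)

    # No `entity_columns` and no `segment_size`: fit refuses the data.
    try:
        model.fit(data)
    except TypeError as error:
        print(error)  # If the data has no `entity_columns`, `segment_size` must be given.

    # An integer `segment_size` is enough when there is no `sequence_index`.
    model.fit(data, segment_size=2)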
72 changes: 18 additions & 54 deletions deepecho/models/basic_gan.py
@@ -16,9 +16,7 @@ def _expand_context(data, context):
     return torch.cat(
         [
             data,
-            context.unsqueeze(0).expand(
-                data.shape[0], context.shape[0], context.shape[1]
-            ),
+            context.unsqueeze(0).expand(data.shape[0], context.shape[0], context.shape[1]),
         ],
         dim=2,
     )
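The collapsed call reads as a single broadcasting step; a standalone sketch of the same trick (dimension names are illustrative):

    import torch

    seq_len, batch, data_size, context_size = 5, 3, 4, 2
    data = torch.randn(seq_len, batch, data_size)
    context = torch.randn(batch, context_size)

    # Replicate each entity's context across every time step, then
    # concatenate it to the data along the feature dimension.
    expanded = context.unsqueeze(0).expand(seq_len, batch, context_size)
    out = torch.cat([data, expanded], dim=2)
    assert out.shape == (seq_len, batch, data_size + context_size)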
@@ -52,9 +50,7 @@ class BasicGenerator(torch.nn.Module):
             Device to which this Module is associated to.
     """
 
-    def __init__(
-        self, context_size, latent_size, hidden_size, data_size, device
-    ):
+    def __init__(self, context_size, latent_size, hidden_size, data_size, device):
         super().__init__()
         self.latent_size = latent_size
         self.rnn = torch.nn.GRU(context_size + latent_size, hidden_size)
@@ -251,31 +247,21 @@ def _analyze_data(self, sequences, context_types, data_types):
             - Index map and dimensions for the context.
             - Index map and dimensions for the data.
         """
-        sequence_lengths = np.array([
-            len(sequence['data'][0]) for sequence in sequences
-        ])
+        sequence_lengths = np.array([len(sequence['data'][0]) for sequence in sequences])
         self._max_sequence_length = np.max(sequence_lengths)
-        self._fixed_length = (
-            sequence_lengths == self._max_sequence_length
-        ).all()
+        self._fixed_length = (sequence_lengths == self._max_sequence_length).all()
 
         # Concatenate all the context sequences together
         context = []
         for column in range(len(context_types)):
-            context.append([
-                sequence['context'][column] for sequence in sequences
-            ])
+            context.append([sequence['context'][column] for sequence in sequences])
 
-        self._context_map, self._context_size = self._index_map(
-            context, context_types
-        )
+        self._context_map, self._context_size = self._index_map(context, context_types)
 
         # Concatenate all the data sequences together
         data = []
         for column in range(len(data_types)):
-            data.append(
-                sum([sequence['data'][column] for sequence in sequences], [])
-            )
+            data.append(sum([sequence['data'][column] for sequence in sequences], []))
 
         self._data_map, self._data_size = self._index_map(data, data_types)

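A toy sketch (illustrative only) of the length bookkeeping the collapsed lines above perform, with `sequences` shaped the way `fit_sequences` receives them, one inner list per data column:

    import numpy as np

    # Two sequences, one data column each, both of length 3.
    sequences = [
        {'context': [], 'data': [[1, 2, 3]]},
        {'context': [], 'data': [[4, 5, 6]]},
    ]

    sequence_lengths = np.array([len(sequence['data'][0]) for sequence in sequences])
    max_sequence_length = np.max(sequence_lengths)
    fixed_length = (sequence_lengths == max_sequence_length).all()
    print(max_sequence_length, fixed_length)  # 3 True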
@@ -388,9 +374,7 @@ def _tensor_to_data(self, tensor):
             for row in range(sequence_length):
                 if column_type in ('continuous', 'count'):
                     round_value = column_type == 'count'
-                    value = self._denormalize(
-                        tensor, row, properties, round_value=round_value
-                    )
+                    value = self._denormalize(tensor, row, properties, round_value=round_value)
                 elif column_type in ('categorical', 'ordinal'):
                     value = self._one_hot_decode(tensor, row, properties)
                 else:
@@ -418,14 +402,10 @@ def _transform(self, data):
             if column_type in ('continuous', 'count'):
                 value_idx, missing_idx = properties['indices']
                 data[:, :, value_idx] = torch.tanh(data[:, :, value_idx])
-                data[:, :, missing_idx] = torch.sigmoid(
-                    data[:, :, missing_idx]
-                )
+                data[:, :, missing_idx] = torch.sigmoid(data[:, :, missing_idx])
             elif column_type in ('categorical', 'ordinal'):
                 indices = list(properties['indices'].values())
-                data[:, :, indices] = torch.nn.functional.softmax(
-                    data[:, :, indices]
-                )
+                data[:, :, indices] = torch.nn.functional.softmax(data[:, :, indices])
 
         return data

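For intuition, a self-contained sketch of the per-column activations applied above; the index values are hypothetical, and the sketch passes an explicit `dim` to `softmax` purely for clarity:

    import torch

    seq_len, batch, width = 4, 2, 5
    data = torch.randn(seq_len, batch, width)

    value_idx, missing_idx = 0, 1  # hypothetical continuous column: value + missing flag
    indices = [2, 3, 4]            # hypothetical one-hot span of a categorical column

    data[:, :, value_idx] = torch.tanh(data[:, :, value_idx])          # squash values to (-1, 1)
    data[:, :, missing_idx] = torch.sigmoid(data[:, :, missing_idx])   # missing-value probability
    data[:, :, indices] = torch.nn.functional.softmax(data[:, :, indices], dim=-1)

    print(data[:, :, indices].sum(dim=-1))  # each categorical span sums to 1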
@@ -454,9 +434,7 @@ def _generate(self, context, sequence_length=None):
 
         return generated
 
-    def _discriminator_step(
-        self, discriminator, discriminator_opt, data_context, context
-    ):
+    def _discriminator_step(self, discriminator, discriminator_opt, data_context, context):
         real_scores = discriminator(data_context)
 
         fake = self._generate(context)
hidden_size=self._hidden_size,
).to(self._device)

generator_opt = torch.optim.Adam(
self._generator.parameters(), lr=self._gen_lr
)
discriminator_opt = torch.optim.Adam(
discriminator.parameters(), lr=self._dis_lr
)
generator_opt = torch.optim.Adam(self._generator.parameters(), lr=self._gen_lr)
discriminator_opt = torch.optim.Adam(discriminator.parameters(), lr=self._dis_lr)

return discriminator, generator_opt, discriminator_opt

@@ -547,17 +521,11 @@ def fit_sequences(self, sequences, context_types, data_types):
         """
         self._analyze_data(sequences, context_types, data_types)
 
-        data = self._build_tensor(
-            self._data_to_tensor, sequences, 'data', dim=1
-        )
-        context = self._build_tensor(
-            self._context_to_tensor, sequences, 'context', dim=0
-        )
+        data = self._build_tensor(self._data_to_tensor, sequences, 'data', dim=1)
+        context = self._build_tensor(self._context_to_tensor, sequences, 'context', dim=0)
         data_context = _expand_context(data, context)
 
-        discriminator, generator_opt, discriminator_opt = (
-            self._build_fit_artifacts()
-        )
+        discriminator, generator_opt, discriminator_opt = self._build_fit_artifacts()
 
         iterator = range(self._epochs)
         if self._verbose:
@@ -579,9 +547,7 @@ def fit_sequences(self, sequences, context_types, data_types):
             if self._verbose:
                 d_loss = discriminator_score.item()
                 g_loss = generator_score.item()
-                iterator.set_description(
-                    f'Epoch {epoch + 1} | D Loss {d_loss} | G Loss {g_loss}'
-                )
+                iterator.set_description(f'Epoch {epoch + 1} | D Loss {d_loss} | G Loss {g_loss}')
 
     def sample_sequence(self, context, sequence_length=None):
         """Sample a single sequence conditioned on context.
@@ -596,9 +562,7 @@ def sample_sequence(self, context, sequence_length=None):
             A list of lists (data) corresponding to the types specified
             in data_types when fit was called.
         """
-        context = (
-            self._context_to_tensor(context).unsqueeze(0).to(self._device)
-        )
+        context = self._context_to_tensor(context).unsqueeze(0).to(self._device)
 
         with torch.no_grad():
             generated = self._generate(context, sequence_length)
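A hedged usage sketch of this sampling path (it assumes a fitted `BasicGANModel`; the context value is a placeholder for whatever context columns were used at fit time):

    # `model` is a fitted BasicGANModel; `sample_sequence` expects the raw
    # context values for one entity and returns one list per data column.
    sequence = model.sample_sequence(context=['A'], sequence_length=10)
    for column_values in sequence:
        print(len(column_values))  # one generated value per time step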
(The remaining 5 changed files are not shown here.)
