Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ruff for linting, remove flake8, remove isort, remove pylint #94

Merged
merged 7 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -79,30 +79,17 @@ install-develop: clean-build clean-pyc ## install the package in editable mode a

# LINT TARGETS

.PHONY: lint-deepecho
lint-deepecho: ## check style with flake8 and isort
flake8 deepecho
isort -c --recursive deepecho
pylint deepecho --rcfile=setup.cfg

.PHONY: lint-tests
lint-tests: ## check style with flake8 and isort
flake8 --ignore=D tests
isort -c --recursive tests

.PHONY: lint
lint: ## Run all code style checks
invoke lint
lint:
ruff check .
ruff format . --check

.PHONY: fix-lint
fix-lint: ## fix lint issues using autoflake, autopep8, and isort
find deepecho tests -name '*.py' | xargs autoflake --in-place --remove-all-unused-imports --remove-unused-variables
autopep8 --in-place --recursive --aggressive deepecho tests
isort --apply --atomic --recursive deepecho tests

fix-lint:
ruff check --fix .
ruff format .

# TEST TARGETS

.PHONY: test-unit
test-unit: ## run unit tests using pytest
invoke unit
Expand Down
33 changes: 25 additions & 8 deletions deepecho/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from deepecho.sequences import assemble_sequences


class DeepEcho():
class DeepEcho:
"""The base class for DeepEcho models."""

_verbose = True
Expand All @@ -28,7 +28,13 @@ def _validate(sequences, context_types, data_types):
data_types:
See `fit`.
"""
dtypes = set(['continuous', 'categorical', 'ordinal', 'count', 'datetime'])
dtypes = set([
'continuous',
'categorical',
'ordinal',
'count',
'datetime',
])
assert all(dtype in dtypes for dtype in context_types)
assert all(dtype in dtypes for dtype in data_types)

Expand Down Expand Up @@ -99,8 +105,15 @@ def _get_data_types(data, data_types, columns):

return dtypes_list

def fit(self, data, entity_columns=None, context_columns=None,
data_types=None, segment_size=None, sequence_index=None):
def fit(
self,
data,
entity_columns=None,
context_columns=None,
data_types=None,
segment_size=None,
sequence_index=None,
):
"""Fit the model to a dataframe containing time series data.

Args:
Expand Down Expand Up @@ -135,8 +148,7 @@ def fit(self, data, entity_columns=None, context_columns=None,
if segment_size is not None and not isinstance(segment_size, int):
if sequence_index is None:
raise TypeError(
'`segment_size` must be of type `int` if '
'no `sequence_index` is given.'
'`segment_size` must be of type `int` if ' 'no `sequence_index` is given.'
)
if data[sequence_index].dtype.kind != 'M':
raise TypeError(
Expand All @@ -161,7 +173,12 @@ def fit(self, data, entity_columns=None, context_columns=None,
data_types = self._get_data_types(data, data_types, self._data_columns)
context_types = self._get_data_types(data, data_types, self._context_columns)
sequences = assemble_sequences(
data, self._entity_columns, self._context_columns, segment_size, sequence_index)
data,
self._entity_columns,
self._context_columns,
segment_size,
sequence_index,
)

# Validate and fit
self._validate(sequences, context_types, data_types)
Expand Down Expand Up @@ -242,7 +259,7 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
# Reformat as a DataFrame
group = pd.DataFrame(
dict(zip(self._data_columns, sequence)),
columns=self._data_columns
columns=self._data_columns,
)
group[self._entity_columns] = entity_values
for column, value in zip(self._context_columns, context_values):
Expand Down
38 changes: 23 additions & 15 deletions deepecho/models/basic_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@


def _expand_context(data, context):
return torch.cat([
data,
context.unsqueeze(0).expand(data.shape[0], context.shape[0], context.shape[1])
], dim=2)
return torch.cat(
[
data,
context.unsqueeze(0).expand(data.shape[0], context.shape[0], context.shape[1]),
],
dim=2,
)


class BasicGenerator(torch.nn.Module):
Expand Down Expand Up @@ -65,7 +68,7 @@ def forward(self, context=None, sequence_length=None):
"""
latent = torch.randn(
size=(sequence_length, context.size(0), self.latent_size),
device=self.device
device=self.device,
)
latent = _expand_context(latent, context)

Expand Down Expand Up @@ -150,8 +153,16 @@ class BasicGANModel(DeepEcho):
_model_data_size = None
_generator = None

def __init__(self, epochs=1024, latent_size=32, hidden_size=16,
gen_lr=1e-3, dis_lr=1e-3, cuda=True, verbose=True):
def __init__(
self,
epochs=1024,
latent_size=32,
hidden_size=16,
gen_lr=1e-3,
dis_lr=1e-3,
cuda=True,
verbose=True,
):
self._epochs = epochs
self._gen_lr = gen_lr
self._dis_lr = dis_lr
Expand Down Expand Up @@ -211,7 +222,7 @@ def _index_map(columns, types):
'type': column_type,
'min': np.min(values),
'max': np.max(values),
'indices': (dimensions, dimensions + 1)
'indices': (dimensions, dimensions + 1),
}
dimensions += 2

Expand All @@ -221,10 +232,7 @@ def _index_map(columns, types):
indices[value] = dimensions
dimensions += 1

mapping[column] = {
'type': column_type,
'indices': indices
}
mapping[column] = {'type': column_type, 'indices': indices}

else:
raise ValueError(f'Unsupported type: {column_type}')
Expand Down Expand Up @@ -317,7 +325,7 @@ def _value_to_tensor(self, tensor, value, properties):
self._one_hot_encode(tensor, value, properties)

else:
raise ValueError() # Theoretically unreachable
raise ValueError() # Theoretically unreachable

def _data_to_tensor(self, data):
"""Convert the input data to the corresponding tensor.
Expand Down Expand Up @@ -370,7 +378,7 @@ def _tensor_to_data(self, tensor):
elif column_type in ('categorical', 'ordinal'):
value = self._one_hot_decode(tensor, row, properties)
else:
raise ValueError() # Theoretically unreachable
raise ValueError() # Theoretically unreachable

column_data.append(value)

Expand Down Expand Up @@ -412,7 +420,7 @@ def _truncate(self, generated):
end_flag = sequence[:, self._data_size]
if (end_flag == 1.0).any():
cut_idx = end_flag.detach().cpu().numpy().argmax()
sequence[cut_idx + 1:] = 0.0
sequence[cut_idx + 1 :] = 0.0

def _generate(self, context, sequence_length=None):
generated = self._generator(
Expand Down
Loading
Loading