Skip to content

Commit

Permalink
Update progress bar for PAR fitting (+ save loss values) (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h authored Oct 11, 2023
1 parent 0bea7c4 commit bbf8e74
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
23 changes: 20 additions & 3 deletions deepecho/models/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(self, epochs=128, sample_size=1, cuda=True, verbose=True):

self.device = torch.device(device)
self.verbose = verbose
self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss'])

LOGGER.info('%s instance created', self)

Expand Down Expand Up @@ -321,9 +322,13 @@ def fit_sequences(self, sequences, context_types, data_types):
self._model = PARNet(self._data_dims, self._ctx_dims).to(self.device)
optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-3)

iterator = range(self.epochs)
iterator = tqdm(range(self.epochs), disable=(not self.verbose))
if self.verbose:
iterator = tqdm(iterator)
pbar_description = 'Loss ({loss:.3f})'
iterator.set_description(pbar_description.format(loss=0))

# Reset loss_values dataframe
self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss'])

X_padded, seq_len = torch.nn.utils.rnn.pad_packed_sequence(X)
for epoch in iterator:
Expand All @@ -333,8 +338,20 @@ def fit_sequences(self, sequences, context_types, data_types):
optimizer.zero_grad()
loss = self._compute_loss(X_padded[1:, :, :], Y_padded[:-1, :, :], seq_len)
loss.backward()

epoch_loss_df = pd.DataFrame({
'Epoch': [epoch],
'Loss': [loss.item()]
})
if not self.loss_values.empty:
self.loss_values = pd.concat(
[self.loss_values, epoch_loss_df]
).reset_index(drop=True)
else:
self.loss_values = epoch_loss_df

if self.verbose:
iterator.set_description(f'Epoch {epoch +1} | Loss {loss.item()}')
iterator.set_description(pbar_description.format(loss=loss.item()))

optimizer.step()

Expand Down
20 changes: 20 additions & 0 deletions tests/integration/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def test_basic(self):
model.fit_sequences(sequences, context_types, data_types)
model.sample_sequence([])

# Assert
assert set(model.loss_values.columns) == {'Epoch', 'Loss'}
assert len(model.loss_values) == 128

def test_conditional(self):
"""Test the ``PARModel`` with conditional sampling."""
sequences = [
Expand All @@ -60,6 +64,10 @@ def test_conditional(self):
model.fit_sequences(sequences, context_types, data_types)
model.sample_sequence([0])

# Assert
assert set(model.loss_values.columns) == {'Epoch', 'Loss'}
assert len(model.loss_values) == 128

def test_mixed(self):
"""Test the ``PARModel`` with mixed input data."""
sequences = [
Expand All @@ -85,6 +93,10 @@ def test_mixed(self):
model.fit_sequences(sequences, context_types, data_types)
model.sample_sequence([0])

# Assert
assert set(model.loss_values.columns) == {'Epoch', 'Loss'}
assert len(model.loss_values) == 128

def test_count(self):
"""Test the PARModel with datatype ``count``."""
sequences = [
Expand All @@ -110,6 +122,10 @@ def test_count(self):
model.fit_sequences(sequences, context_types, data_types)
model.sample_sequence([0])

# Assert
assert set(model.loss_values.columns) == {'Epoch', 'Loss'}
assert len(model.loss_values) == 128

def test_variable_length(self):
"""Test ``PARModel`` with variable data length."""
sequences = [
Expand All @@ -134,3 +150,7 @@ def test_variable_length(self):
model = PARModel()
model.fit_sequences(sequences, context_types, data_types)
model.sample_sequence([0])

# Assert
assert set(model.loss_values.columns) == {'Epoch', 'Loss'}
assert len(model.loss_values) == 128

0 comments on commit bbf8e74

Please sign in to comment.