Skip to content

Commit

Permalink
reset loss_values df on refit
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Oct 11, 2023
1 parent 44e8881 commit 199f932
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions deepecho/models/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ def fit_sequences(self, sequences, context_types, data_types):
pbar_description = 'Loss ({loss:.3f})'
iterator.set_description(pbar_description.format(loss=0))

# Reset loss_values dataframe
self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss'])

X_padded, seq_len = torch.nn.utils.rnn.pad_packed_sequence(X)
for epoch in iterator:
Y = self._model(X, C)
Expand Down

0 comments on commit 199f932

Please sign in to comment.