Merge pull request #59 from tsugumi-sys/fix/var-annotated-mypy
fix var-annotated mypy error
tsugumi-sys authored Jan 7, 2024
2 parents cce82c5 + 3fe7aaa commit c7656a9
Showing 2 changed files with 33 additions and 21 deletions.
53 changes: 33 additions & 20 deletions pipelines/trainer.py
@@ -1,5 +1,5 @@
 import os
-from typing import Callable, Dict, List
+from typing import Callable, Dict, List, Optional, TypedDict, cast
 
 import pandas as pd
 import torch
@@ -13,6 +13,12 @@
 from pipelines.utils.early_stopping import EarlyStopping
 
 
+class TrainingMetrics(TypedDict):
+    train_loss: List[float]
+    validation_loss: List[float]
+    validation_accuracy: List[float]
+
+
 class Trainer(BaseRunner):
     def __init__(
         self,
@@ -41,7 +47,7 @@ def __init__(
         if not metrics_filename.endswith(".csv"):
             raise ValueError("`save_metrics_filename` should be end with `.csv`")
         self.metrics_filename = metrics_filename
-        self._training_metrics = {
+        self._training_metrics: TrainingMetrics = {
             "train_loss": [],
             "validation_loss": [],
             "validation_accuracy": [],
@@ -51,26 +57,26 @@ def run(self) -> None:
         for epoch in range(1, self.train_epochs + 1):
             self.__train()
             self.__validation()
-            training_metrics = self.__latest_training_metrics()
+            training_metric = self.__latest_training_metric()
             if epoch % 10 == 0:
                 print(
                     f"Epoch: {epoch}, Training loss: "
                     "{.8f}, Validation loss: {.8f}, Validation Accuracy: {.8f}".format(
-                        training_metrics["train_loss"],
-                        training_metrics["validation_loss"],
-                        training_metrics["validation_accuracy"],
+                        training_metric["train_loss"],
+                        training_metric["validation_loss"],
+                        training_metric["validation_accuracy"],
                     )
                 )
 
-            self.early_stopping(training_metrics["validation_loss"], self.model)
+            self.early_stopping(training_metric["validation_loss"], self.model)
             if self.early_stopping.early_stop is True:
                 print(f"Early stopped at epoch {epoch}")
                 break
 
         self.__save_metrics()
 
     @property
-    def training_metrics(self) -> Dict[str, List[float]]:
+    def training_metrics(self) -> TrainingMetrics:
         return self._training_metrics
 
     def __train(self):
@@ -88,7 +94,7 @@ def __train(self):
 
             train_loss += loss.item()
 
-        self.__log_metrics({"train_loss": train_loss / len(self.train_dataloader)})
+        self.__log_metric(train_loss=train_loss / len(self.train_dataloader))
 
     def __validation(self):
         valid_loss, valid_acc = 0, 0
@@ -102,19 +108,26 @@ def __validation(self):
             valid_loss += loss.item()
             valid_acc += acc.item()
         dataset_length = len(self.valid_dataloader)
-        self.__log_metrics(
-            {
-                "validation_loss": valid_loss / dataset_length,
-                "validation_accuracy": valid_acc / dataset_length,
-            }
+        self.__log_metric(
+            validation_loss=valid_loss / dataset_length,
+            validation_accuracy=valid_acc / dataset_length,
         )
 
-    def __log_metrics(self, res: Dict[str, float]):
-        for key, val in res.items():
-            self._training_metrics[key].append(val)
-
-    def __latest_training_metrics(self) -> Dict[str, float]:
-        return {k: v[-1] for k, v in self._training_metrics.items()}
+    def __log_metric(
+        self,
+        train_loss: Optional[float] = None,
+        validation_loss: Optional[float] = None,
+        validation_accuracy: Optional[float] = None,
+    ):
+        if train_loss is not None:
+            self._training_metrics["train_loss"].append(train_loss)
+        if validation_loss is not None:
+            self._training_metrics["validation_loss"].append(validation_loss)
+        if validation_accuracy is not None:
+            self._training_metrics["validation_accuracy"].append(validation_accuracy)
+
+    def __latest_training_metric(self) -> Dict[str, float]:
+        return {k: cast(List[float], v)[-1] for k, v in self._training_metrics.items()}
 
     def __save_metrics(self) -> None:
         pd.DataFrame(self._training_metrics).to_csv(
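Note: the sketch below is not part of the diff; it is a minimal, self-contained illustration (with simplified, hypothetical names such as `Recorder`) of the pattern this change addresses. Assigning a dict of empty lists to an attribute without an annotation leaves mypy unable to infer the element types, which is what the `var-annotated` error code reports; annotating the attribute with the `TrainingMetrics` TypedDict resolves it, and the `cast` mirrors the one in `__latest_training_metric`, since values obtained from a TypedDict's `.items()` are typed as `object`.

from typing import Dict, List, TypedDict, cast


class TrainingMetrics(TypedDict):
    train_loss: List[float]
    validation_loss: List[float]
    validation_accuracy: List[float]


class Recorder:
    def __init__(self) -> None:
        # Unannotated version that mypy flags:
        #     self._metrics = {"train_loss": [], "validation_loss": [], "validation_accuracy": []}
        #     error: Need type annotation for "_metrics"  [var-annotated]
        # Annotating with the TypedDict gives every key a concrete value type:
        self._metrics: TrainingMetrics = {
            "train_loss": [],
            "validation_loss": [],
            "validation_accuracy": [],
        }

    def latest(self) -> Dict[str, float]:
        # .items() on a TypedDict yields values typed as `object`, so indexing
        # them requires casting back to List[float], as the PR does.
        return {k: cast(List[float], v)[-1] for k, v in self._metrics.items()}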
1 change: 0 additions & 1 deletion pyproject.toml
@@ -30,4 +30,3 @@ split-on-trailing-comma = true
 namespace_packages = true
 ignore_missing_imports = true
 python_version = "3.11"
-disable_error_code = ["var-annotated"]
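With the annotation added in trainer.py, the project no longer needs to silence the check globally, so the `disable_error_code` entry is dropped from the mypy configuration. For reference, a hypothetical snippet (not from the repository) showing what the re-enabled `var-annotated` check flags and the minimal annotation that satisfies it, assuming standard mypy behavior:

from typing import List


class Accumulator:
    def __init__(self) -> None:
        self.raw = []                  # flagged: Need type annotation for "raw"  [var-annotated]
        self.totals: List[float] = []  # accepted: the annotation supplies the element type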
