From 3fe7aaaff03b2b9d378b3b4dcb5a76d53a5e3731 Mon Sep 17 00:00:00 2001 From: tsugumi-sys Date: Mon, 8 Jan 2024 01:15:48 +0900 Subject: [PATCH] fix var-annotated mypy error --- pipelines/trainer.py | 53 +++++++++++++++++++++++++++----------------- pyproject.toml | 1 - 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/pipelines/trainer.py b/pipelines/trainer.py index ecf06bb..20afaaa 100644 --- a/pipelines/trainer.py +++ b/pipelines/trainer.py @@ -1,5 +1,5 @@ import os -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional, TypedDict, cast import pandas as pd import torch @@ -13,6 +13,12 @@ from pipelines.utils.early_stopping import EarlyStopping +class TrainingMetrics(TypedDict): + train_loss: List[float] + validation_loss: List[float] + validation_accuracy: List[float] + + class Trainer(BaseRunner): def __init__( self, @@ -41,7 +47,7 @@ def __init__( if not metrics_filename.endswith(".csv"): raise ValueError("`save_metrics_filename` should be end with `.csv`") self.metrics_filename = metrics_filename - self._training_metrics = { + self._training_metrics: TrainingMetrics = { "train_loss": [], "validation_loss": [], "validation_accuracy": [], @@ -51,18 +57,18 @@ def run(self) -> None: for epoch in range(1, self.train_epochs + 1): self.__train() self.__validation() - training_metrics = self.__latest_training_metrics() + training_metric = self.__latest_training_metric() if epoch % 10 == 0: print( f"Epoch: {epoch}, Training loss: " "{.8f}, Validation loss: {.8f}, Validation Accuracy: {.8f}".format( - training_metrics["train_loss"], - training_metrics["validation_loss"], - training_metrics["validation_accuracy"], + training_metric["train_loss"], + training_metric["validation_loss"], + training_metric["validation_accuracy"], ) ) - self.early_stopping(training_metrics["validation_loss"], self.model) + self.early_stopping(training_metric["validation_loss"], self.model) if self.early_stopping.early_stop is True: print(f"Early stopped at epoch {epoch}") break @@ -70,7 +76,7 @@ def run(self) -> None: self.__save_metrics() @property - def training_metrics(self) -> Dict[str, List[float]]: + def training_metrics(self) -> TrainingMetrics: return self._training_metrics def __train(self): @@ -88,7 +94,7 @@ def __train(self): train_loss += loss.item() - self.__log_metrics({"train_loss": train_loss / len(self.train_dataloader)}) + self.__log_metric(train_loss=train_loss / len(self.train_dataloader)) def __validation(self): valid_loss, valid_acc = 0, 0 @@ -102,19 +108,26 @@ def __validation(self): valid_loss += loss.item() valid_acc += acc.item() dataset_length = len(self.valid_dataloader) - self.__log_metrics( - { - "validation_loss": valid_loss / dataset_length, - "validation_accuracy": valid_acc / dataset_length, - } + self.__log_metric( + validation_loss=valid_loss / dataset_length, + validation_accuracy=valid_acc / dataset_length, ) - def __log_metrics(self, res: Dict[str, float]): - for key, val in res.items(): - self._training_metrics[key].append(val) - - def __latest_training_metrics(self) -> Dict[str, float]: - return {k: v[-1] for k, v in self._training_metrics.items()} + def __log_metric( + self, + train_loss: Optional[float] = None, + validation_loss: Optional[float] = None, + validation_accuracy: Optional[float] = None, + ): + if train_loss is not None: + self._training_metrics["train_loss"].append(train_loss) + if validation_loss is not None: + self._training_metrics["validation_loss"].append(validation_loss) + if validation_accuracy is not None: + self._training_metrics["validation_accuracy"].append(validation_accuracy) + + def __latest_training_metric(self) -> Dict[str, float]: + return {k: cast(List[float], v)[-1] for k, v in self._training_metrics.items()} def __save_metrics(self) -> None: pd.DataFrame(self._training_metrics).to_csv( diff --git a/pyproject.toml b/pyproject.toml index 2530b4a..494f692 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,4 +30,3 @@ split-on-trailing-comma = true namespace_packages = true ignore_missing_imports = true python_version = "3.11" -disable_error_code = ["var-annotated"]