From 7e637dd2d33a1a9a64ebd1fb3bb15d92c662ba5b Mon Sep 17 00:00:00 2001
From: Theo West
Date: Fri, 17 May 2024 12:56:22 +0200
Subject: [PATCH] add test script

---
 test.py  | 101 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 train.py |  18 +++++++---
 2 files changed, 114 insertions(+), 5 deletions(-)
 create mode 100644 test.py

diff --git a/test.py b/test.py
new file mode 100644
index 0000000..e43cb2f
--- /dev/null
+++ b/test.py
@@ -0,0 +1,101 @@
+# Copyright 2024, Theodor Westny. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import warnings
+import torch
+
+from torch.multiprocessing import set_sharing_strategy
+from lightning.pytorch import Trainer, seed_everything
+from lightning.pytorch.loggers import Logger, CSVLogger
+
+from arguments import args
+from preamble import load_config, import_from_module
+
+torch.set_float32_matmul_precision('medium')
+warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
+warnings.filterwarnings("ignore", ".*Checkpoint directory*")
+
+set_sharing_strategy('file_system')
+
+# Load configuration and import modules
+config = load_config(args.config)
+TorchModel = import_from_module(config["model"]["module"], config["model"]["class"])
+LitDataModule = import_from_module(config["datamodule"]["module"], config["datamodule"]["class"])
+LitModel = import_from_module(config["litmodule"]["module"], config["litmodule"]["class"])
+
+
+def main(save_name: str) -> None:
+    ds = config["dataset"]
+    path = os.path.join("saved_models", ds, save_name)
+
+    # Check if a checkpoint exists
+    if os.path.exists(path + ".ckpt"):
+        ckpt = path + ".ckpt"
+    elif os.path.exists(path + "-v1.ckpt"):
+        ckpt = path + "-v1.ckpt"
+    else:
+        raise NameError(f"Could not find model with name: {save_name}")
+
+    # Determine the number of devices and the accelerator
+    if torch.cuda.is_available() and args.use_cuda:
+        devices, accelerator = -1, "auto"
+    else:
+        devices, accelerator = 1, "cpu"
+
+    # Setup logger
+    logger: bool | Logger
+    if args.dry_run:
+        logger = False
+        args.small_ds = True
+    elif not args.use_logger:
+        logger = False
+    else:
+        logger = CSVLogger(save_dir=os.path.join("lightning_logs", ds), name=save_name)
+
+    # Setup model
+    net = TorchModel(config["model"])
+    model = LitModel(net, config["training"])
+
+    # Load checkpoint into model
+    ckpt_dict = torch.load(ckpt)
+    model.load_state_dict(ckpt_dict["state_dict"], strict=False)
+
+    # Setup datamodule
+    if args.root:
+        config["datamodule"]["root"] = args.root
+    datamodule = LitDataModule(config["datamodule"], args)
+
+    # Setup trainer
+    trainer = Trainer(accelerator=accelerator, devices=devices, logger=logger)
+
+    # Start testing
+    trainer.test(model, datamodule=datamodule, verbose=True)
+
+
+if __name__ == "__main__":
+    seed_everything(args.seed, workers=True)
+
+    ds_name = config["dataset"]
+    mdl_name = config["model"]["class"]
+    add_name = f"-{args.add_name}" if args.add_name else ""
+
+    full_save_name = f"{mdl_name}{add_name}-{ds_name}"
+
+    print('----------------------------------------------------')
+    print(f'\nGetting ready to test model: {full_save_name} \n')
+    print('----------------------------------------------------')
+
+    main(full_save_name)
diff --git a/train.py b/train.py
index 0b7c93c..8b7916c 100644
--- a/train.py
+++ b/train.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import os
 import time
 import pathlib
 import warnings
@@ -42,7 +42,7 @@
 def main(save_name: str) -> None:
     ds = config["dataset"]
 
-    ckpt_path = pathlib.Path(f"saved_models/{ds}/{save_name}.ckpt")
+    ckpt_path = pathlib.Path(os.path.join("saved_models", ds, f"{save_name}.ckpt"))
 
     # Check if checkpoint exists and the overwrite flag is not set
     if ckpt_path.exists() and not args.overwrite:
@@ -88,15 +88,16 @@ def main(save_name: str) -> None:
         run_name = f"{save_name}_{time.strftime('%d-%m_%H:%M:%S')}"
         logger = WandbLogger(project="dronalize", name=run_name)
 
-    # Setup model, datamodule and trainer
+    # Setup model
     net = TorchModel(config["model"])
     model = LitModel(net, config["training"])
 
+    # Setup datamodule
    if args.root:
         config["datamodule"]["root"] = args.root
-
     datamodule = LitDataModule(config["datamodule"], args)
 
+    # Setup trainer
     trainer = Trainer(max_epochs=config["training"]["epochs"],
                       logger=logger,
                       devices=devices,
@@ -113,8 +114,15 @@ def main(save_name: str) -> None:
 if __name__ == "__main__":
     seed_everything(args.seed, workers=True)
 
+    mdl_name = config["model"]["class"]
     ds_name = config["dataset"]
-    full_save_name = f"Example{args.add_name}-{ds_name}"
+    add_name = f"-{args.add_name}" if args.add_name else ""
+
+    full_save_name = f"{mdl_name}{add_name}-{ds_name}"
+
+    if args.dry_run:
+        full_save_name += "-DEBUG"
+
     print('----------------------------------------------------')
     print(f'\nGetting ready to train model: {full_save_name} \n')
     print('----------------------------------------------------')
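
Notes (after the diff; ignored when the patch is applied):

Both scripts now derive the save name from the config as
"{model-class}{-add-name}-{dataset}". With hypothetical values, a model
class "Example", --add-name "v2", and dataset "highD", train.py writes
saved_models/highD/Example-v2-highD.ckpt and test.py resolves the same
path, falling back to the "-v1.ckpt" copy that Lightning's checkpoint
callback writes when the plain filename is already taken.

For context, a minimal sketch of what preamble.import_from_module is
assumed to do here; the actual helper lives in preamble.py and may
differ:

    import importlib

    def import_from_module(module_name: str, class_name: str):
        # Resolve a dotted module path and a class name from the config,
        # e.g. ("models.example", "Example"), to the class object itself.
        # Both example names are hypothetical.
        return getattr(importlib.import_module(module_name), class_name)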