Skip to content

Commit

Permalink
add test script
Browse files Browse the repository at this point in the history
  • Loading branch information
westny committed May 17, 2024
1 parent 7fc5f01 commit 7e637dd
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 5 deletions.
101 changes: 101 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2024, Theodor Westny. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import warnings
import torch

from torch.multiprocessing import set_sharing_strategy
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import Logger, CSVLogger

from arguments import args
from preamble import load_config, import_from_module

# Trade a little matmul precision for TensorCore speed on Ampere+ GPUs.
torch.set_float32_matmul_precision('medium')
# Silence known-noisy Lightning warnings that don't affect testing.
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*Checkpoint directory*")

# Share tensors between dataloader workers via the filesystem instead of
# file descriptors (avoids "too many open files" with many workers).
set_sharing_strategy('file_system')

# Load configuration and resolve the model / datamodule / lit-module classes
# named in the config file via dynamic import.
config = load_config(args.config)
TorchModel = import_from_module(config["model"]["module"], config["model"]["class"])
LitDataModule = import_from_module(config["datamodule"]["module"], config["datamodule"]["class"])
LitModel = import_from_module(config["litmodule"]["module"], config["litmodule"]["class"])


def main(save_name: str) -> None:
    """Load a trained checkpoint and run the Lightning test loop on it.

    Args:
        save_name: Base name of the saved model (without the ``.ckpt``
            suffix), looked up under ``saved_models/<dataset>/``.

    Raises:
        NameError: If no checkpoint matching ``save_name`` is found.
    """
    ds = config["dataset"]
    path = os.path.join("saved_models", ds, save_name)

    # Check if a checkpoint exists: the plain name first, then the "-v1"
    # variant Lightning writes when the plain name was already taken.
    if os.path.exists(path + ".ckpt"):
        ckpt = path + ".ckpt"
    elif os.path.exists(path + "-v1.ckpt"):
        ckpt = path + "-v1.ckpt"
    else:
        raise NameError(f"Could not find model with name: {save_name}")

    # Determine the number of devices, and accelerator
    if torch.cuda.is_available() and args.use_cuda:
        devices, accelerator = -1, "auto"
    else:
        devices, accelerator = 1, "cpu"

    # Setup logger: disabled for dry runs (which also force the small
    # dataset) and when logging is switched off; CSV metrics otherwise.
    logger: bool | Logger
    if args.dry_run:
        logger = False
        args.small_ds = True
    elif not args.use_logger:
        logger = False
    else:
        logger = CSVLogger(save_dir=os.path.join("lightning_logs", ds), name=save_name)

    # Setup model
    net = TorchModel(config["model"])
    model = LitModel(net, config["training"])

    # Load checkpoint into model. map_location="cpu" makes GPU-trained
    # checkpoints loadable on CPU-only hosts; the Trainer moves the model
    # to the selected accelerator before testing.
    ckpt_dict = torch.load(ckpt, map_location="cpu")
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # presumably intentional to tolerate extra non-model state; confirm.
    model.load_state_dict(ckpt_dict["state_dict"], strict=False)

    # Setup datamodule (optionally overriding the dataset root from the CLI)
    if args.root:
        config["datamodule"]["root"] = args.root
    datamodule = LitDataModule(config["datamodule"], args)

    # Setup trainer
    trainer = Trainer(accelerator=accelerator, devices=devices, logger=logger)

    # Start testing
    trainer.test(model, datamodule=datamodule, verbose=True)


if __name__ == "__main__":
seed_everything(args.seed, workers=True)

ds_name = config["dataset"]
mdl_name = config["model"]["class"]
add_name = f"-{args.add_name}" if args.add_name else ""

full_save_name = f"{mdl_name}{add_name}-{ds_name}"

print('----------------------------------------------------')
print(f'\nGetting ready to test model: {full_save_name} \n')
print('----------------------------------------------------')

main(full_save_name)
18 changes: 13 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import time
import pathlib
import warnings
Expand Down Expand Up @@ -42,7 +42,7 @@

def main(save_name: str) -> None:
ds = config["dataset"]
ckpt_path = pathlib.Path(f"saved_models/{ds}/{save_name}.ckpt")
ckpt_path = pathlib.Path(os.path.join("saved_models", ds, f"{save_name}.ckpt"))

# Check if checkpoint exists and the overwrite flag is not set
if ckpt_path.exists() and not args.overwrite:
Expand Down Expand Up @@ -88,15 +88,16 @@ def main(save_name: str) -> None:
run_name = f"{save_name}_{time.strftime('%d-%m_%H:%M:%S')}"
logger = WandbLogger(project="dronalize", name=run_name)

# Setup model, datamodule and trainer
# Setup model
net = TorchModel(config["model"])
model = LitModel(net, config["training"])

# Setup datamodule
if args.root:
config["datamodule"]["root"] = args.root

datamodule = LitDataModule(config["datamodule"], args)

# Setup trainer
trainer = Trainer(max_epochs=config["training"]["epochs"],
logger=logger,
devices=devices,
Expand All @@ -113,8 +114,15 @@ def main(save_name: str) -> None:
if __name__ == "__main__":
seed_everything(args.seed, workers=True)

mdl_name = config["model"]["class"]
ds_name = config["dataset"]
full_save_name = f"Example{args.add_name}-{ds_name}"
add_name = f"-{args.add_name}" if args.add_name else ""

full_save_name = f"{mdl_name}{add_name}-{ds_name}"

if args.dry_run:
full_save_name += "-DEBUG"

print('----------------------------------------------------')
print(f'\nGetting ready to train model: {full_save_name} \n')
print('----------------------------------------------------')
Expand Down

0 comments on commit 7e637dd

Please sign in to comment.