
Training & Testing Details

Junyong Lee edited this page Feb 4, 2022 · 2 revisions

Training & testing the network

Training

# multi GPU (with DistributedDataParallel) example
CUDA_VISIBLE_DEVICES=0,1,2,3 python -B -m torch.distributed.launch --nproc_per_node=4 --master_port=9000 run.py \
            --is_train \
            --mode IFAN \
            --config config_IFAN \
            --trainer trainer \
            --network IFAN \
            -b 2 \
            -th 8 \
            -dl \
            -ss \
            -dist

# resuming example (the trainer loads the checkpoint saved after epoch 100, and training resumes from epoch 101)
CUDA_VISIBLE_DEVICES=0,1,2,3 python -B -m torch.distributed.launch --nproc_per_node=4 --master_port=9000 run.py \
            ... \
            -th 8 \
            -r 100 \
            -ss \
            -dist

# single GPU (with DataParallel) example
CUDA_VISIBLE_DEVICES=0 python -B run.py \
            ... \
            -ss

Note:

  • The image loss (MSE) will be applied no matter what.
  • If IFAN is included in [mode], it will trigger both disparity and reblurring losses.
  • To apply the disparity and reblurring losses separately, do not include IFAN in [mode]. Instead, trigger each loss by including D (for the disparity loss) or R (for the reblurring loss) in [mode] (e.g., --mode my_net_D).
  • To train a network that takes dual-pixel stereo images as an input, dual should be included in the option [mode], and IFAN_dual should be specified for the option [network].
  • Options
    • --is_train: If it is specified, run.py will train the network. Default: False
    • --mode: The name of a model to train. A logging folder named after [mode] will be created at [LOG_ROOT]/IFAN_CVPR2021/[mode]/. Default: IFAN
    • --config: The name of a config file located at ./configs/[config].py. Default: None; the default should not be changed.
    • --trainer: The name of a trainer file located as ./models/trainers/[trainer].py. Default: trainer
    • --network: The name of a network file located as ./models/archs/[network].py. Default: IFAN
    • -b, --batch_size: The batch size. With multiple GPUs (DistributedDataParallel), this is the batch size per process, so the total batch size is nproc_per_node * b. Default: 8
    • -th, --thread_num: The number of threads (num_workers) for the data loader. Default: 8
    • -dl, --delete_log: Whether to delete the logs under [mode] (i.e., [LOG_ROOT]/IFAN_CVPR2021/[mode]/*). This option takes effect only when --is_train is specified. Default: False
    • -r, --resume: Resume training with the specified epoch (e.g., -r 100). Note that -dl should not be specified with this option.
    • -ss, --save_sample: Save sample images for both training and testing. Images will be saved in [LOG_ROOT]/IFAN_CVPR2021/[mode]/sample/. Default: False
    • -dist: Enables multi-processing with DistributedDataParallel. Default: False
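
The [mode] conventions and the batch-size rule in the notes above can be summarized in a short sketch. It is illustrative only: the function names losses_for_mode and total_batch_size are hypothetical, and splitting [mode] on underscores is an assumption about how the tags could be recognized, not necessarily how run.py or the trainer parses the string.

```python
def losses_for_mode(mode: str) -> dict:
    """Return which losses a [mode] string would enable under the rules above.

    Assumption: tags such as IFAN, D, and R appear as separate
    underscore-delimited tokens (e.g., my_net_D). The repository may
    detect them differently.
    """
    tokens = mode.split("_")
    has_ifan = "IFAN" in tokens
    return {
        "image": True,  # the image loss (MSE) is always applied
        "disparity": has_ifan or "D" in tokens,
        "reblurring": has_ifan or "R" in tokens,
    }


def total_batch_size(nproc_per_node: int, batch_size: int) -> int:
    """Total batch size with DistributedDataParallel: one batch per process."""
    return nproc_per_node * batch_size


print(losses_for_mode("IFAN"))      # all three losses enabled
print(losses_for_mode("my_net_D"))  # image + disparity only
print(total_batch_size(4, 2))       # the multi-GPU example above: 4 * 2 = 8
```

With the multi-GPU command shown earlier (--nproc_per_node=4 and -b 2), the total batch size is 8.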

Testing

CUDA_VISIBLE_DEVICES=0 python run.py --mode [mode] --data [DATASET]
# e.g., CUDA_VISIBLE_DEVICES=0 python run.py --mode IFAN --data DPDD
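
To evaluate one trained model on several datasets, the command above can be generated in a loop. The following is a minimal sketch that only prints the commands. The helper build_test_command is hypothetical, and it uses only the --mode and --data flags along with the benchmark dataset names this page lists for --data.

```python
import shlex

# Benchmark dataset names accepted by --data, as listed on this page.
DATASETS = ["DPDD", "RealDOF", "CUHK", "PixelDP"]


def build_test_command(mode: str, data: str) -> list:
    """Build the argument list for one evaluation run (hypothetical helper)."""
    return ["python", "run.py", "--mode", mode, "--data", data]


for data in DATASETS:
    cmd = build_test_command("IFAN", data)
    # Print rather than execute, so the sketch runs without the repository.
    print("CUDA_VISIBLE_DEVICES=0", shlex.join(cmd))
```

To actually launch the runs, each argument list could be passed to subprocess.run with CUDA_VISIBLE_DEVICES set in the environment.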

Note:

  • Specify only [mode] of the trained model. [config] doesn't have to be specified, as it will be automatically loaded.
  • Testing results will be saved in [LOG_ROOT]/IFAN_CVPR2021/[mode]/result/quanti_quali/[mode]_[epoch]/[data]/.
  • Options
    • --mode: The name of a model to test.
    • --data: The name of a dataset for evaluation: DPDD | RealDOF | CUHK | PixelDP | random. Default: DPDD
      • The data structure can be modified by the function set_eval_path(..) in ./configs/config.py.
      • random is for testing models with any images, which should be placed as [DATASET_ROOT]/random/*.[jpg|png].
    • --ckpt_name: Loads the checkpoint with the given file name under [LOG_ROOT]/IFAN_CVPR2021/[mode]/checkpoint/train/epoch/ckpt/ (e.g., python run.py --mode IFAN --data DPDD --ckpt_name IFAN_00100.pytorch).
    • --ckpt_abs_name: Loads the checkpoint at the given path (e.g., python run.py --mode IFAN --data DPDD --ckpt_abs_name ./ckpt/IFAN.pytorch).
    • --ckpt_epoch: Loads the checkpoint of the specified epoch (e.g., python run.py --mode IFAN --data DPDD --ckpt_epoch 100).
    • --ckpt_sc: Loads the checkpoint with the best validation score (e.g., python run.py --mode IFAN --data DPDD --ckpt_sc).
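
The directory templates in the notes above can be expanded with a small helper when locating results or checkpoints from a script. This is a sketch under stated assumptions: result_dir and checkpoint_path are hypothetical helpers, /logs stands in for your own [LOG_ROOT], and the epoch tag is passed as a string because the exact formatting of [epoch] in the result folder name is not specified here.

```python
from pathlib import Path

PROJECT = "IFAN_CVPR2021"  # project folder name used in the paths on this page


def result_dir(log_root: str, mode: str, epoch_tag: str, data: str) -> Path:
    """[LOG_ROOT]/IFAN_CVPR2021/[mode]/result/quanti_quali/[mode]_[epoch]/[data]/"""
    return (Path(log_root) / PROJECT / mode / "result" / "quanti_quali"
            / f"{mode}_{epoch_tag}" / data)


def checkpoint_path(log_root: str, mode: str, ckpt_name: str) -> Path:
    """[LOG_ROOT]/IFAN_CVPR2021/[mode]/checkpoint/train/epoch/ckpt/[ckpt_name]"""
    return (Path(log_root) / PROJECT / mode / "checkpoint" / "train"
            / "epoch" / "ckpt" / ckpt_name)


# "/logs" and the "00100" epoch tag are placeholders for illustration only.
print(result_dir("/logs", "IFAN", "00100", "DPDD").as_posix())
print(checkpoint_path("/logs", "IFAN", "IFAN_00100.pytorch").as_posix())
```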