From e311f964fdb3888df2d8253e633c59ca9b4470e0 Mon Sep 17 00:00:00 2001 From: Dylan Flaute Date: Thu, 11 Jul 2019 11:33:37 -0400 Subject: [PATCH 1/2] Doc multigpu and propagate data path. --- references/detection/train.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/references/detection/train.py b/references/detection/train.py index c8fb01ab1da..5f928859adf 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -1,3 +1,11 @@ +"""PyTorch Detection Training. + +To run in a multi-gpu environment, use the distributed launcher:: + + python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \ + train.py ... --world-size $NGPU + +""" import datetime import os import time @@ -18,10 +26,10 @@ import transforms as T -def get_dataset(name, image_set, transform): +def get_dataset(name, image_set, transform, data_path): paths = { - "coco": ('/datasets01/COCO/022719/', get_coco, 91), - "coco_kp": ('/datasets01/COCO/022719/', get_coco_kp, 2) + "coco": (data_path, get_coco, 91), + "coco_kp": (data_path, get_coco_kp, 2) } p, ds_fn, num_classes = paths[name] @@ -46,8 +54,8 @@ def main(args): # Data loading code print("Loading data") - dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True)) - dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False)) + dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path) + dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path) print("Creating data loaders") if args.distributed: @@ -125,7 +133,8 @@ def main(args): if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='PyTorch Detection Training') + parser = argparse.ArgumentParser( + description=__doc__) parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset') parser.add_argument('--dataset', default='coco', help='dataset') From 562f530c01f038b3efb0a467a85ce0f3b201cf23 Mon Sep 17 00:00:00 2001 From: Dylan Flaute Date: Thu, 11 Jul 2019 11:38:20 -0400 Subject: [PATCH 2/2] Use raw doc because of backslash. --- references/detection/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/detection/train.py b/references/detection/train.py index 5f928859adf..7152f293b0f 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -1,4 +1,4 @@ -"""PyTorch Detection Training. +r"""PyTorch Detection Training. To run in a multi-gpu environment, use the distributed launcher::