From 4cad8e77eab547528a5abb908be4a0d811054dcf Mon Sep 17 00:00:00 2001 From: ykkk2333 <77383312+ykkk2333@users.noreply.github.com> Date: Wed, 2 Mar 2022 20:38:56 +0800 Subject: [PATCH] [Enhancement] Add xpu option (#1815) --- train.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 8b01bef614..35526683c6 100644 --- a/train.py +++ b/train.py @@ -120,6 +120,12 @@ def parse_args(): help='The option of train profiler. If profiler_options is not None, the train ' \ 'profiler is enabled. Refer to the paddleseg/utils/train_profiler.py for details.' ) + parser.add_argument( + '--device', + dest='device', + help='Device place to be set, which can be GPU, XPU, CPU', + default='gpu', + type=str) return parser.parse_args() @@ -137,8 +143,13 @@ def main(args): ['-' * 48]) logger.info(info) - place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[ - 'GPUs used'] else 'cpu' + if args.device == 'gpu' and env_info[ + 'Paddle compiled with cuda'] and env_info['GPUs used']: + place = 'gpu' + elif args.device == 'xpu' and paddle.is_compiled_with_xpu(): + place = 'xpu' + else: + place = 'cpu' paddle.set_device(place) if not args.cfg: