Skip to content

Commit

Permalink
[Enhancement] Add xpu option (#1815)
Browse files Browse the repository at this point in the history
  • Loading branch information
ykkk2333 authored Mar 2, 2022
1 parent 4ef3fb4 commit 4cad8e7
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def parse_args():
help='The option of train profiler. If profiler_options is not None, the train ' \
'profiler is enabled. Refer to the paddleseg/utils/train_profiler.py for details.'
)
parser.add_argument(
'--device',
dest='device',
help='Device place to be set, which can be GPU, XPU, CPU',
default='gpu',
type=str)

return parser.parse_args()

Expand All @@ -137,8 +143,13 @@ def main(args):
['-' * 48])
logger.info(info)

place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
'GPUs used'] else 'cpu'
if args.device == 'gpu' and env_info[
'Paddle compiled with cuda'] and env_info['GPUs used']:
place = 'gpu'
elif args.device == 'xpu' and paddle.is_compiled_with_xpu():
place = 'xpu'
else:
place = 'cpu'

paddle.set_device(place)
if not args.cfg:
Expand Down

0 comments on commit 4cad8e7

Please sign in to comment.