From 1e0c4a14d573536d95febdaa7bddf15080e80d9b Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 1 Jun 2022 21:37:51 +0800 Subject: [PATCH] fix(nyz): fix policy set device bug --- ding/policy/base_policy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index a30b080015..441d5e7f1a 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -75,7 +75,6 @@ def __init__( if len(set(self._enable_field).intersection(set(['learn']))) > 0: self._rank = get_rank() if self._cfg.learn.multi_gpu else 0 if self._cuda: - torch.cuda.set_device(self._rank % torch.cuda.device_count()) model.cuda() if self._cfg.learn.multi_gpu: bp_update_sync = self._cfg.learn.get('bp_update_sync', True) @@ -84,7 +83,6 @@ def __init__( else: self._rank = 0 if self._cuda: - torch.cuda.set_device(self._rank % torch.cuda.device_count()) model.cuda() self._model = model self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu'