From 396de72245cc49f3f9a237e6b815d17a7630b67c Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 13 Sep 2023 13:22:10 -0700 Subject: [PATCH] fix dupliate gpu check --- demo/nvflare/vertical/custom/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/demo/nvflare/vertical/custom/trainer.py b/demo/nvflare/vertical/custom/trainer.py index efe3207341c6..b6c3855ef10f 100644 --- a/demo/nvflare/vertical/custom/trainer.py +++ b/demo/nvflare/vertical/custom/trainer.py @@ -83,9 +83,8 @@ def _do_training(self, fl_ctx: FLContext): 'eval_metric': 'auc', } if self._use_gpus: - if self._use_gpus: - self.log_info(fl_ctx, f'Training with GPU {rank}') - param['device'] = f"cuda:{rank}" + self.log_info(fl_ctx, f'Training with GPU {rank}') + param['device'] = f"cuda:{rank}" # specify validations set to watch performance watchlist = [(dtest, "eval"), (dtrain, "train")]