diff --git a/demo/nvflare/vertical/custom/trainer.py b/demo/nvflare/vertical/custom/trainer.py index efe3207341c6..b6c3855ef10f 100644 --- a/demo/nvflare/vertical/custom/trainer.py +++ b/demo/nvflare/vertical/custom/trainer.py @@ -83,9 +83,8 @@ def _do_training(self, fl_ctx: FLContext): 'eval_metric': 'auc', } if self._use_gpus: - if self._use_gpus: - self.log_info(fl_ctx, f'Training with GPU {rank}') - param['device'] = f"cuda:{rank}" + self.log_info(fl_ctx, f'Training with GPU {rank}') + param['device'] = f"cuda:{rank}" # specify validations set to watch performance watchlist = [(dtest, "eval"), (dtrain, "train")]