diff --git a/train.py b/train.py index 461e2990..b9c7c196 100644 --- a/train.py +++ b/train.py @@ -126,7 +126,7 @@ def main(): checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=args.save_last_network, every_n_epochs=args.network_save_period, save_top_k=-1) trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], logger=tb_logger) - main_device = trainer.root_device if trainer.root_gpu is None else 'cuda:' + str(trainer.root_gpu) + main_device = trainer.root_device if trainer.strategy.root_device.index is None else 'cuda:' + str(trainer.strategy.root_device.index) nnue.to(device=main_device)