Skip to content

Commit

Permalink
Fix root_gpu crash with pytorch-lightning >= 1.8
Browse files Browse the repository at this point in the history
AttributeError: `Trainer.root_gpu` was deprecated in
v1.6 and is no longer accessible as of v1.8. Please use
`Trainer.strategy.root_device.index` instead.

Support for Trainer.root_gpu was removed in:
Lightning-AI/pytorch-lightning#11994
  • Loading branch information
linrock authored and vondele committed Jan 1, 2023
1 parent 4061d45 commit 0bc0c78
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def main():
checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=args.save_last_network, every_n_epochs=args.network_save_period, save_top_k=-1)
trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], logger=tb_logger)

main_device = trainer.root_device if trainer.root_gpu is None else 'cuda:' + str(trainer.root_gpu)
main_device = trainer.root_device if trainer.strategy.root_device.index is None else 'cuda:' + str(trainer.strategy.root_device.index)

nnue.to(device=main_device)

Expand Down

0 comments on commit 0bc0c78

Please sign in to comment.