Skip to content

Commit

Permalink
Add fixes for tests (#714)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Naren authored Aug 27, 2021
1 parent 47eb2aa commit 4046278
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pl_bolts/callbacks/verification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any:
input_array = deepcopy(input_array)

if isinstance(self.model, LightningModule):
input_array = self.model.transfer_batch_to_device(input_array, self.model.device)
input_array = self.model.transfer_batch_to_device(input_array, self.model.device, dataloader_idx=0)
else:
input_array = move_data_to_device(input_array, device=next(self.model.parameters()).device)

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/advantage_actor_critic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def cli_main() -> None:
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True)

seed_everything(123)
trainer = Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback)
trainer = Trainer.from_argparse_args(args, deterministic=True, callbacks=checkpoint_callback)
trainer.fit(model)


Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def cli_main():
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True)

seed_everything(123)
trainer = Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback)
trainer = Trainer.from_argparse_args(args, deterministic=True, callbacks=checkpoint_callback)

trainer.fit(model)

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/reinforce_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def cli_main():
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True)

seed_everything(123)
trainer = Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback)
trainer = Trainer.from_argparse_args(args, deterministic=True, callbacks=checkpoint_callback)
trainer.fit(model)


Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/vanilla_policy_gradient_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def cli_main():
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True)

seed_everything(123)
trainer = Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback)
trainer = Trainer.from_argparse_args(args, deterministic=True, callbacks=checkpoint_callback)
trainer.fit(model)


Expand Down

0 comments on commit 4046278

Please sign in to comment.