diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index b21d54569e73d..b6c60bb1a7eee 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -90,7 +90,7 @@ def __init__(
         self.replace_sampler_ddp = replace_sampler_ddp
         self.deterministic = deterministic
         self.precision = precision
-        self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
+        self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
         self.amp_level = amp_level
         self.cluster_environment = cluster_environment
         self.is_slurm_managing_tasks = False
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 91c4867d949b1..0f516e2b0b046 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -182,18 +182,12 @@ def __load_weights_on_main_process(self) -> None:
     def xmp_spawn_kwargs(self):
         return {
             "args": (self.lightning_module.trainer, self.mp_queue),
-            "nprocs": len(self.parallel_devices),
-            "start_method": self.start_method
-        }
+            "nprocs": len(self.parallel_devices),
+            "start_method": self.start_method
+        }
 
     def start_training(self, trainer) -> None:
-        xmp.spawn(
-            self.new_process,
-            **self.xmp_spawn_kwargs
-        )
+        xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)
 
     def start_testing(self, trainer) -> None:
-        xmp.spawn(
-            self.new_process,
-            **self.xmp_spawn_kwargs
-        )
+        xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)
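For context, a minimal sketch of the pattern this diff consolidates: build the `xmp.spawn` keyword arguments once and reuse them for both the training and testing entry points. It relies only on the public torch_xla API (`xmp.spawn` takes a function plus `args`/`nprocs`/`start_method`); the `_worker` function, its message argument, and the core count below are illustrative placeholders, not Lightning code.

```python
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker(index: int, message: str) -> None:
    # xmp.spawn passes each process its ordinal first, then the shared args.
    print(f"TPU process {index}: {message}")


# Build the spawn kwargs once so training and testing share one definition,
# mirroring the xmp_spawn_kwargs property used in this diff. Note the torch_xla
# keyword is `nprocs`, not `nproc`.
spawn_kwargs = {
    "args": ("hello from the main process",),
    "nprocs": 8,              # e.g. one process per TPU core (placeholder)
    "start_method": "fork",   # TPU spawning in Lightning uses fork
}

if __name__ == "__main__":
    xmp.spawn(_worker, **spawn_kwargs)
```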