ref: organize args 1/n (#3435)
* ref: organize args 1/n
williamFalcon committed Sep 10, 2020
1 parent 0b5f417 commit 49290a5
Showing 6 changed files with 73 additions and 55 deletions.
2 changes: 1 addition & 1 deletion benchmarks/test_parity.py
@@ -107,7 +107,7 @@ def lightning_loop(cls_model, num_runs=10, num_epochs=10):
        )
        trainer.fit(model)

-        final_loss = trainer.running_loss.last().item()
+        final_loss = trainer.train_loop.running_loss.last().item()
        errors.append(final_loss)

        time_end = time.perf_counter()
53 changes: 53 additions & 0 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -1,12 +1,65 @@
from pytorch_lightning import accelerators
import os
import torch
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities import rank_zero_warn


class AcceleratorConnector:

    def __init__(self, trainer):
        self.trainer = trainer

    def on_trainer_init(self, num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus):
        self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
        self.trainer.on_tpu = self.trainer.tpu_cores is not None

        self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None

        if num_processes != 1 and distributed_backend != "ddp_cpu":
            rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
        self.trainer.num_processes = num_processes

        # override with environment flag
        gpus = os.environ.get('PL_TRAINER_GPUS', gpus)

        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.trainer.gpus = self.trainer.tuner.pick_multiple_gpus(gpus)
        else:
            self.trainer.gpus = gpus

        self.trainer.data_parallel_device_ids = device_parser.parse_gpu_ids(self.trainer.gpus)
        self.trainer.root_gpu = device_parser.determine_root_gpu_device(self.trainer.data_parallel_device_ids)
        self.trainer.root_device = torch.device("cpu")

        self.trainer.on_gpu = True if (self.trainer.data_parallel_device_ids and torch.cuda.is_available()) else False

        # tpu state flags
        self.trainer.use_tpu = False
        self.trainer.tpu_local_core_rank = None
        self.trainer.tpu_global_core_rank = None

        # distributed backend choice
        self.trainer.distributed_backend = distributed_backend
        self.trainer.set_distributed_mode(distributed_backend)

        # override dist backend when using tpus
        if self.trainer.on_tpu:
            self.trainer.distributed_backend = 'tpu'
            self.trainer.init_tpu()

        # init flags for SLURM+DDP to work
        self.trainer.world_size = 1
        self.trainer.interactive_ddp_procs = []
        self.trainer.configure_slurm_ddp(self.trainer.num_nodes)
        self.trainer.node_rank = self.trainer.determine_ddp_node_rank()
        self.trainer.local_rank = self.trainer.determine_local_rank()
        self.trainer.global_rank = 0

        # NVIDIA setup
        self.trainer.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids)

    def select_accelerator(self):
        # SLURM ddp
        use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
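For readers skimming the diff, the shape of the new connector is the interesting part: the Trainer composes a small helper object, delegates a block of its `__init__` logic to it, and the helper writes the derived flags back onto the trainer. Below is a minimal, self-contained sketch of that pattern, not part of the commit; `_Connector`, `_Trainer` and the lone `gpus` flag are illustrative stand-ins rather than the real Lightning API.

```python
# Minimal sketch of the delegation pattern introduced by this commit.
# _Connector and _Trainer are illustrative stand-ins, not Lightning classes.

class _Connector:
    def __init__(self, trainer):
        # keep a back-reference so derived flags can be written onto the trainer
        self.trainer = trainer

    def on_trainer_init(self, gpus):
        # the connector owns the parsing/validation logic ...
        self.trainer.gpus = gpus
        # ... and publishes the derived state as plain trainer attributes
        self.trainer.on_gpu = bool(gpus)


class _Trainer:
    def __init__(self, gpus=None):
        self.connector = _Connector(self)
        # a block of __init__ becomes a single delegated call
        self.connector.on_trainer_init(gpus)


trainer = _Trainer(gpus=2)
assert trainer.on_gpu and trainer.gpus == 2
```

The real `AcceleratorConnector.on_trainer_init` above does the same thing for the full set of accelerator flags (TPU cores, GPU ids, distributed backend, SLURM state), which is why the matching block disappears from `trainer.py` further down.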
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -1323,7 +1323,7 @@ def get_progress_bar_dict(self):
            Dictionary with the items to be displayed in the progress bar.
        """
        # call .item() only once but store elements without graphs
-        running_train_loss = self.trainer.running_loss.mean()
+        running_train_loss = self.trainer.train_loop.running_loss.mean()
        avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN')
        tqdm_dict = {'loss': '{:.3f}'.format(avg_training_loss)}

64 changes: 14 additions & 50 deletions pytorch_lightning/trainer/trainer.py
@@ -38,7 +38,6 @@
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import TrainerState, trainer_state
-from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
@@ -186,7 +185,6 @@ def __init__(

        # training bookeeping
        self.total_batch_idx = 0
-        self.running_loss = TensorRunningAccum(window_length=20)
        self.batch_idx = 0
        self.num_training_batches = 0
        self.num_val_batches = []
@@ -220,6 +218,9 @@ def __init__(
        self._default_root_dir = default_root_dir or os.getcwd()
        self._weights_save_path = weights_save_path or self._default_root_dir

+        # -------------------------------
+        # CALLBACK INITS
+        # -------------------------------
        # init callbacks
        self.callbacks = callbacks or []

@@ -260,15 +261,18 @@ def __init__(
            raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
        self.track_grad_norm = float(track_grad_norm)

-        self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
-        self.on_tpu = self.tpu_cores is not None
-
-        self.tpu_id = self.tpu_cores[0] if isinstance(self.tpu_cores, list) else None
-
-        if num_processes != 1 and distributed_backend != "ddp_cpu":
-            rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
-        self.num_processes = num_processes
+        # init accelerator related flags
+        self.accelerator_connector.on_trainer_init(
+            num_processes,
+            tpu_cores,
+            distributed_backend,
+            auto_select_gpus,
+            gpus
+        )

+        # -------------------
+        # CONTINUE
+        # -------------------
        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
@@ -313,46 +317,6 @@ def __init__(
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

-        # override with environment flag
-        gpus = os.environ.get('PL_TRAINER_GPUS', gpus)
-
-        # for gpus allow int, string and gpu list
-        if auto_select_gpus and isinstance(gpus, int):
-            self.gpus = self.tuner.pick_multiple_gpus(gpus)
-        else:
-            self.gpus = gpus
-
-        self.data_parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
-        self.root_gpu = device_parser.determine_root_gpu_device(self.data_parallel_device_ids)
-        self.root_device = torch.device("cpu")
-
-        self.on_gpu = True if (self.data_parallel_device_ids and torch.cuda.is_available()) else False
-
-        # tpu state flags
-        self.use_tpu = False
-        self.tpu_local_core_rank = None
-        self.tpu_global_core_rank = None
-
-        # distributed backend choice
-        self.distributed_backend = distributed_backend
-        self.set_distributed_mode(distributed_backend)
-
-        # override dist backend when using tpus
-        if self.on_tpu:
-            self.distributed_backend = 'tpu'
-            self.init_tpu()
-
-        # init flags for SLURM+DDP to work
-        self.world_size = 1
-        self.interactive_ddp_procs = []
-        self.configure_slurm_ddp(self.num_nodes)
-        self.node_rank = self.determine_ddp_node_rank()
-        self.local_rank = self.determine_local_rank()
-        self.global_rank = 0
-
-        # NVIDIA setup
-        self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

        self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

        # logging
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -36,6 +36,7 @@ def __init__(self, trainer):
        self.checkpoint_accumulator = None
        self.accumulated_loss = None
        self._teardown_already_run = False
+        self.running_loss = TensorRunningAccum(window_length=20)

    @property
    def num_optimizers(self):
@@ -503,7 +504,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                    self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

                    # calculate running loss for display
-                    self.trainer.running_loss.append(
+                    self.running_loss.append(
                        self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches
                    )

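Since the `running_loss` accumulator now lives on the training loop, a quick reminder of what `TensorRunningAccum` does may help: it keeps a sliding window of the last `window_length` values and exposes `last()` and `mean()`, which is exactly what the benchmark, progress-bar and lr-finder hunks read. A small usage sketch follows; it is illustrative and not part of the commit, and it assumes the `pytorch_lightning.trainer.supporters` import path shown in the line removed from `trainer.py` above.

```python
import torch

# Sliding-window accumulator behind the displayed running loss; the import
# path matches the line this commit removes from trainer.py. Usage is illustrative.
from pytorch_lightning.trainer.supporters import TensorRunningAccum

running_loss = TensorRunningAccum(window_length=20)
for step_loss in (0.9, 0.7, 0.5):
    # the training loop appends one (accumulated) loss per optimizer step
    running_loss.append(torch.tensor(step_loss))

print(running_loss.last().item())  # most recent value, what lr_finder reads
print(running_loss.mean().item())  # window average, what the progress bar shows
```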
4 changes: 2 additions & 2 deletions pytorch_lightning/tuner/lr_finder.py
@@ -376,7 +376,7 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id
        if self.progress_bar:
            self.progress_bar.update()

-        current_loss = trainer.running_loss.last().item()
+        current_loss = trainer.train_loop.running_loss.last().item()
        current_step = trainer.global_step + 1  # remove the +1 in 1.0

        # Avg loss (loss with momentum) + smoothing
@@ -445,7 +445,7 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id
        if self.progress_bar:
            self.progress_bar.update()

-        current_loss = trainer.running_loss.last().item()
+        current_loss = trainer.train_loop.running_loss.last().item()
        current_step = trainer.global_step + 1  # remove the +1 in 1.0

        # Avg loss (loss with momentum) + smoothing
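Taken together, the hunks above all make the same change of access path: code outside the training loop now reads the accumulator as `trainer.train_loop.running_loss` instead of `trainer.running_loss`. As a hedged sketch, a user callback would do the same; the `on_train_batch_end` signature below mirrors the 0.9-era hook shown in the `lr_finder.py` hunks and is not part of this commit.

```python
from pytorch_lightning import Callback


class RunningLossMonitor(Callback):
    """Illustrative callback: print the smoothed running loss each batch."""

    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        # new attribute path after this refactor: trainer.train_loop.running_loss
        loss = trainer.train_loop.running_loss.last()
        if loss is not None:
            print(f"batch {batch_idx}: running loss {loss.item():.3f}")
```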
