Remove BreakEpochException #2759

Merged
3 changes: 1 addition & 2 deletions composer/core/__init__.py
@@ -19,7 +19,7 @@
 from composer.core.serializable import Serializable
 from composer.core.state import State
 from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
-from composer.core.types import JSON, Batch, BreakEpochException, Dataset, MemoryFormat, PyTorchScheduler, TrainerMode
+from composer.core.types import JSON, Batch, Dataset, MemoryFormat, PyTorchScheduler, TrainerMode
 
 __all__ = [
     'Algorithm',
@@ -46,6 +46,5 @@
     'JSON',
     'MemoryFormat',
     'TrainerMode',
-    'BreakEpochException',
     'validate_eval_automicrobatching',
 ]
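The import line and the `__all__` entry have to move together: `from composer.core import *` resolves names through `__all__`, so a leftover entry for a name that is no longer imported fails at import time. A minimal standalone sketch of the failure mode (hypothetical module `mymod`, not composer code):

# mymod.py
from math import sqrt

__all__ = ['sqrt', 'cbrt']  # 'cbrt' is listed but never defined in this module

# client.py
from mymod import *  # raises AttributeError: module 'mymod' has no attribute 'cbrt'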
10 changes: 1 addition & 9 deletions composer/core/types.py
@@ -21,7 +21,7 @@
 
 from composer.utils import StringEnum
 
-__all__ = ['Batch', 'PyTorchScheduler', 'JSON', 'MemoryFormat', 'TrainerMode', 'BreakEpochException']
+__all__ = ['Batch', 'PyTorchScheduler', 'JSON', 'MemoryFormat', 'TrainerMode']
 
 Batch = Any
 
@@ -37,14 +37,6 @@
 JSON = Union[str, float, int, None, List['JSON'], Dict[str, 'JSON']]
 
 
-class BreakEpochException(Exception):
-    """Raising this exception will immediately end the current epoch.
-
-    If you're wondering whether you should use this, the answer is no.
-    """
-    pass
-
-
 class TrainerMode(StringEnum):
     """Enum to represent which mode the Trainer is in.
 
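With the exception gone, the way to end training partway through an epoch is to bound it declaratively; as the trainer.py hunk below shows, the loop breaks out of the dataloader on its own once max_duration is reached mid-epoch. A minimal sketch, assuming a ComposerModel and a dataloader are already constructed (model and train_dl are placeholders):

from composer import Trainer

trainer = Trainer(
    model=model,                # placeholder: any ComposerModel
    train_dataloader=train_dl,  # placeholder: any training dataloader
    max_duration='100ba',       # stop after 100 batches, even mid-epoch
)
trainer.fit()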
262 changes: 129 additions & 133 deletions composer/trainer/trainer.py
@@ -35,10 +35,9 @@
 from torchmetrics import Metric
 
 from composer.callbacks import CheckpointSaver, OptimizerMonitor
-from composer.core import (Algorithm, AlgorithmPass, Batch, BreakEpochException, Callback, DataSpec, Engine, Evaluator,
-                           Event, Precision, PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode,
-                           ensure_data_spec, ensure_evaluator, ensure_time, get_precision_context,
-                           validate_eval_automicrobatching)
+from composer.core import (Algorithm, AlgorithmPass, Batch, Callback, DataSpec, Engine, Evaluator, Event, Precision,
+                           PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec,
+                           ensure_evaluator, ensure_time, get_precision_context, validate_eval_automicrobatching)
 from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS, DeviceTPU
 from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MosaicMLLogger, ProgressBarLogger,
                               RemoteUploaderDownloader, WandBLogger)
@@ -2019,143 +2018,140 @@ def _train_loop(self) -> None:
         last_wct = datetime.datetime.now()
 
         while self.state.timestamp < self.state.max_duration:
-            try:
-                if int(self.state.timestamp.batch_in_epoch) == 0:
-                    self.engine.run_event(Event.EPOCH_START)
-                    self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})
-
-                dataloader = self.state.dataloader
-                if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
-                    dataloader.sampler.set_epoch(int(self.state.timestamp.epoch))
-
-                for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
-                    # Spin dataloader forward unless dataloader handles internally with dataset_resumption
-                    if self.spin_dataloaders and 'train' not in self.state.dataset_resumption and batch_idx < int(
-                            self.state.timestamp.batch_in_epoch):
-                        # Restore the RNG state immediately before the next batch is yielded from the dataloader
-                        if batch_idx + 1 == int(self.state.timestamp.batch_in_epoch) and self._rng_state is not None:
-                            reproducibility.load_rng_state(self._rng_state)
-                            self._rng_state = None
-                        continue
-
-                    self.state.batch = self.state.device.batch_to_device(self.state.batch)
-                    self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
-                    rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
-                    rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
-
-                    if self.state.deepspeed_enabled:
-                        self.state.batch = _fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision)
-
-                    self.engine.run_event(Event.AFTER_DATALOADER)
-
-                    self.engine.run_event(Event.BATCH_START)
-
-                    # Log time values
-                    self.logger.log_metrics({
-                        'time/batch': self.state.timestamp.batch.value,
-                        'time/sample': self.state.timestamp.sample.value,
-                        'time/batch_in_epoch': self.state.timestamp.batch_in_epoch.value,
-                        'time/sample_in_epoch': self.state.timestamp.sample_in_epoch.value,
-                    })
-                    if rank_num_tokens > 0:
-                        self.logger.log_metrics({'time/token': self.state.timestamp.token.value})
-                        self.logger.log_metrics({'time/token_in_epoch': self.state.timestamp.token_in_epoch.value})
-
-                    total_loss_dict = self._train_batch(use_grad_scaling)
-
-                    if use_grad_scaling:
-                        self.state.scaler.update()
-
-                    # total_loss_dict can be None if gradient scaling failed
-                    if total_loss_dict is not None:
-                        map_collection(total_loss_dict, dist.all_reduce)
-                        total_loss_dict = {
-                            k: loss.cpu().item() / dist.get_world_size() for k, loss in total_loss_dict.items()
-                        }
-                        self.state.total_loss_dict = total_loss_dict
-                        self.logger.log_metrics(total_loss_dict)
-
-                    # The scheduler step.step() and compute_and_log_metrics() are going to be included in the
-                    # next batch's wall clock time. The time accumulation must be done here so schedulers
-                    # have the latest timing information
-
-                    now = datetime.datetime.now()
-
-                    batch_time = now - last_wct
-
-                    total_num_samples, total_num_tokens, batch_time = self._accumulate_time_across_ranks(
-                        rank_num_samples,
-                        rank_num_tokens,
-                        batch_time,
-                    )
-
-                    # `now` is actually in the past, but want to include the time it takes to perform this reduction
-                    last_wct = now
-
-                    if self._scheduler_step_frequency == TimeUnit.BATCH:
-                        for scheduler in self.state.schedulers:
-                            scheduler.step()
-
-                    if self.state.train_metrics is not None:
-                        self._compute_and_log_metrics(
-                            dataloader_label='train',
-                            metrics=self.state.train_metrics,
-                        )
-
-                    self.state.previous_timestamp = self.state.timestamp
-                    self.state.timestamp = self.state.timestamp.to_next_batch(
-                        samples=total_num_samples,
-                        tokens=total_num_tokens,
-                        duration=batch_time,
-                    )
-
-                    self.engine.run_event(Event.BATCH_END)
-
-                    # Pause the timing during evaluation
-                    # Evaluation time is tracked separately in state.eval_timestamp
-                    duration = datetime.datetime.now() - last_wct
-                    self._run_evaluators(Event.BATCH_END)
-                    last_wct = datetime.datetime.now() - duration
-
-                    self.engine.run_event(Event.BATCH_CHECKPOINT)
-
-                    if self.state.timestamp >= self.state.max_duration:
-                        # If max_duration is specified in batches, samples, or tokens, and
-                        # and the max_duration is reached mid-epoch, then break out of the dataloader
-                        # to finish the epoch early and finish training.
-                        finished_epoch_early = True
-                        break
-
-                if not finished_epoch_early or self.state.dataloader_len == self.state.timestamp.batch_in_epoch:
-                    # Trigger the epoch end events if the dataloader was exhausted.
-                    # This happens if the "break" did not trigger above, or if it
-                    # did (e.g. duration specified in samples/batches/tokens), but it is still
-                    # the end of the dataloader (i.e. next(dataloader) would raise StopIteration)
-                    if self.state.train_metrics is not None:
-                        self.state.train_metrics = self._ensure_metrics_device_and_dtype(self.state.train_metrics)
-                        self._compute_and_log_metrics(
-                            dataloader_label='train',
-                            metrics=self.state.train_metrics,
-                        )
-
-                    if self._scheduler_step_frequency == TimeUnit.EPOCH:
-                        for scheduler in self.state.schedulers:
-                            scheduler.step()
-
-                    self.state.previous_timestamp = self.state.timestamp
-                    self.state.timestamp = self.state.timestamp.to_next_epoch()
-
-                    self.engine.run_event(Event.EPOCH_END)
-
-                    # Pause the timing during evaluation
-                    # Evaluation time is tracked separately in state.eval_timestamp
-                    duration = datetime.datetime.now() - last_wct
-                    self._run_evaluators(Event.EPOCH_END)
-                    last_wct = datetime.datetime.now() - duration
-
-                    self.engine.run_event(Event.EPOCH_CHECKPOINT)
-            except BreakEpochException:
-                log.info(f'Skipping the rest of Epoch {int(self.state.timestamp.epoch)}')
+            if int(self.state.timestamp.batch_in_epoch) == 0:
+                self.engine.run_event(Event.EPOCH_START)
+                self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})
+
+            dataloader = self.state.dataloader
+            if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
+                dataloader.sampler.set_epoch(int(self.state.timestamp.epoch))
+
+            for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
+                # Spin dataloader forward unless dataloader handles internally with dataset_resumption
+                if self.spin_dataloaders and 'train' not in self.state.dataset_resumption and batch_idx < int(
+                        self.state.timestamp.batch_in_epoch):
+                    # Restore the RNG state immediately before the next batch is yielded from the dataloader
+                    if batch_idx + 1 == int(self.state.timestamp.batch_in_epoch) and self._rng_state is not None:
+                        reproducibility.load_rng_state(self._rng_state)
+                        self._rng_state = None
+                    continue
+
+                self.state.batch = self.state.device.batch_to_device(self.state.batch)
+                self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
+                rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
+                rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
+
+                if self.state.deepspeed_enabled:
+                    self.state.batch = _fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision)
+
+                self.engine.run_event(Event.AFTER_DATALOADER)
+
+                self.engine.run_event(Event.BATCH_START)
+
+                # Log time values
+                self.logger.log_metrics({
+                    'time/batch': self.state.timestamp.batch.value,
+                    'time/sample': self.state.timestamp.sample.value,
+                    'time/batch_in_epoch': self.state.timestamp.batch_in_epoch.value,
+                    'time/sample_in_epoch': self.state.timestamp.sample_in_epoch.value,
+                })
+                if rank_num_tokens > 0:
+                    self.logger.log_metrics({'time/token': self.state.timestamp.token.value})
+                    self.logger.log_metrics({'time/token_in_epoch': self.state.timestamp.token_in_epoch.value})
+
+                total_loss_dict = self._train_batch(use_grad_scaling)
+
+                if use_grad_scaling:
+                    self.state.scaler.update()
+
+                # total_loss_dict can be None if gradient scaling failed
+                if total_loss_dict is not None:
+                    map_collection(total_loss_dict, dist.all_reduce)
+                    total_loss_dict = {
+                        k: loss.cpu().item() / dist.get_world_size() for k, loss in total_loss_dict.items()
+                    }
+                    self.state.total_loss_dict = total_loss_dict
+                    self.logger.log_metrics(total_loss_dict)
+
+                # The scheduler step.step() and compute_and_log_metrics() are going to be included in the
+                # next batch's wall clock time. The time accumulation must be done here so schedulers
+                # have the latest timing information
+
+                now = datetime.datetime.now()
+
+                batch_time = now - last_wct
+
+                total_num_samples, total_num_tokens, batch_time = self._accumulate_time_across_ranks(
+                    rank_num_samples,
+                    rank_num_tokens,
+                    batch_time,
+                )
+
+                # `now` is actually in the past, but want to include the time it takes to perform this reduction
+                last_wct = now
+
+                if self._scheduler_step_frequency == TimeUnit.BATCH:
+                    for scheduler in self.state.schedulers:
+                        scheduler.step()
+
+                if self.state.train_metrics is not None:
+                    self._compute_and_log_metrics(
+                        dataloader_label='train',
+                        metrics=self.state.train_metrics,
+                    )
+
+                self.state.previous_timestamp = self.state.timestamp
+                self.state.timestamp = self.state.timestamp.to_next_batch(
+                    samples=total_num_samples,
+                    tokens=total_num_tokens,
+                    duration=batch_time,
+                )
+
+                self.engine.run_event(Event.BATCH_END)
+
+                # Pause the timing during evaluation
+                # Evaluation time is tracked separately in state.eval_timestamp
+                duration = datetime.datetime.now() - last_wct
+                self._run_evaluators(Event.BATCH_END)
+                last_wct = datetime.datetime.now() - duration
+
+                self.engine.run_event(Event.BATCH_CHECKPOINT)
+
+                if self.state.timestamp >= self.state.max_duration:
+                    # If max_duration is specified in batches, samples, or tokens, and
+                    # and the max_duration is reached mid-epoch, then break out of the dataloader
+                    # to finish the epoch early and finish training.
+                    finished_epoch_early = True
+                    break
+
+            if not finished_epoch_early or self.state.dataloader_len == self.state.timestamp.batch_in_epoch:
+                # Trigger the epoch end events if the dataloader was exhausted.
+                # This happens if the "break" did not trigger above, or if it
+                # did (e.g. duration specified in samples/batches/tokens), but it is still
+                # the end of the dataloader (i.e. next(dataloader) would raise StopIteration)
+                if self.state.train_metrics is not None:
+                    self.state.train_metrics = self._ensure_metrics_device_and_dtype(self.state.train_metrics)
+                    self._compute_and_log_metrics(
+                        dataloader_label='train',
+                        metrics=self.state.train_metrics,
+                    )
+
+                if self._scheduler_step_frequency == TimeUnit.EPOCH:
+                    for scheduler in self.state.schedulers:
+                        scheduler.step()
+
+                self.state.previous_timestamp = self.state.timestamp
+                self.state.timestamp = self.state.timestamp.to_next_epoch()
+
+                self.engine.run_event(Event.EPOCH_END)
+
+                # Pause the timing during evaluation
+                # Evaluation time is tracked separately in state.eval_timestamp
+                duration = datetime.datetime.now() - last_wct
+                self._run_evaluators(Event.EPOCH_END)
+                last_wct = datetime.datetime.now() - duration
+
+                self.engine.run_event(Event.EPOCH_CHECKPOINT)
 
         # Log final time values
         self.logger.log_metrics({
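One detail worth pulling out of the hunk above: training wall-clock time is "paused" across _run_evaluators by shifting last_wct forward by however long evaluation took, so evaluation never inflates the next batch's measured duration. A standalone sketch of the pattern in plain Python (the sleeps stand in for real work; this is not composer API):

import datetime
import time

def run_evaluators():
    time.sleep(0.25)  # stand-in for evaluation work

last_wct = datetime.datetime.now()
time.sleep(0.1)  # stand-in for training one batch

# Capture how much wall-clock time has accrued before evaluation starts.
duration = datetime.datetime.now() - last_wct
run_evaluators()
# Shift the reference point forward by the pre-eval duration so the
# evaluators' runtime is excluded from the next measurement.
last_wct = datetime.datetime.now() - duration

# The next `datetime.datetime.now() - last_wct` now spans only training time.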