Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a FlyteCallback #23759

Merged
merged 7 commits into from
May 30, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion src/transformers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import numpy as np

from . import __version__ as version
from .utils import flatten_dict, is_datasets_available, is_torch_available, logging
from .utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
from .utils.versions import importlib_metadata


Expand Down Expand Up @@ -146,6 +146,10 @@ def is_codecarbon_available():
return importlib.util.find_spec("codecarbon") is not None


def is_flytekit_available():
    """Return `True` when an import spec for the `flytekit` package can be found, `False` otherwise."""
    flytekit_spec = importlib.util.find_spec("flytekit")
    return flytekit_spec is not None


def hp_params(trial):
if is_optuna_available():
import optuna
Expand Down Expand Up @@ -1537,6 +1541,63 @@ def on_save(self, args, state, control, **kwargs):
self._clearml_task.update_output_model(artifact_path, iteration=state.global_step, auto_delete_file=False)


class FlyteCallback(TrainerCallback):
"""A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/).
NOTE: This callback only works within a Flyte task.

Args:
save_log_history (bool, optional, defaults to `True`):
peridotml marked this conversation as resolved.
Show resolved Hide resolved
Determines whether or not to save the training logs to the task's Flyte Deck.

sync_checkpoints (bool, optional, defaults to `True`):
peridotml marked this conversation as resolved.
Show resolved Hide resolved
When set to True, checkpoints are synced to Flyte and can be used to resume training in the case of an
interruption.

Example:
```python
# Note: This example skips over some setup steps for brevity.
from flytekit import current_context, task


@task
def train_hf_transformer():
cp = current_context().checkpoint
trainer = Trainer(..., callbacks=[FlyteCallback()])
output = trainer.train(resume_from_checkpoint=cp.restore())
```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be unindented by one level, and there should be an additional blank line between `Example:` and the three backticks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

"""

def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True):
    """Set up the callback and grab the checkpoint handle of the current Flyte context.

    Args:
        save_log_history: Stored on the instance as `save_log_history`.
        sync_checkpoints: Stored on the instance as `sync_checkpoints`.

    Raises:
        ImportError: If `is_flytekit_available()` reports that flytekit is not installed.
    """
    super().__init__()

    flytekit_missing = not is_flytekit_available()
    if flytekit_missing:
        raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.")

    # flytekit is only imported once the availability check above has passed.
    from flytekit import current_context

    checkpoint = current_context().checkpoint

    self.cp = checkpoint
    self.save_log_history = save_log_history
    self.sync_checkpoints = sync_checkpoints

def on_save(self, args, state, control, **kwargs):
    """Pass the checkpoint directory of the current global step to the Flyte checkpoint handle.

    Nothing happens unless `sync_checkpoints` is enabled and `state.is_world_process_zero` is truthy.
    The directory is `checkpoint-<global_step>` joined onto `args.output_dir`.
    """
    should_sync = self.sync_checkpoints and state.is_world_process_zero
    if not should_sync:
        return

    ckpt_dir = f"checkpoint-{state.global_step}"
    checkpoint_path = os.path.join(args.output_dir, ckpt_dir)

    logger.info(f"Saving checkpoint in {ckpt_dir} to Flyte. This may take time.")
    self.cp.save(checkpoint_path)

def on_train_end(self, args, state, control, **kwargs):
    """Publish `state.log_history` to a Flyte Deck titled "Log History".

    Nothing happens when `save_log_history` is falsy. If pandas is available the history is rendered
    as an HTML table through a `DataFrame`; otherwise a warning is logged and the plain `str()` of
    the history is used instead.
    """
    if not self.save_log_history:
        return

    from flytekit import Deck

    if is_pandas_available():
        # Import pandas only after confirming it is installed. Importing it unconditionally raised
        # ImportError on pandas-less environments and made the fallback branch below unreachable.
        import pandas as pd

        Deck("Log History", pd.DataFrame(state.log_history).to_html())
    else:
        logger.warning("Install pandas for optimal Flyte log formatting.")
        Deck("Log History", str(state.log_history))
peridotml marked this conversation as resolved.
Show resolved Hide resolved


INTEGRATION_TO_CALLBACK = {
"azure_ml": AzureMLCallback,
"comet_ml": CometCallback,
Expand All @@ -1547,6 +1608,7 @@ def on_save(self, args, state, control, **kwargs):
"codecarbon": CodeCarbonCallback,
"clearml": ClearMLCallback,
"dagshub": DagsHubCallback,
"flyte": FlyteCallback,
}


Expand Down